5 changes: 5 additions & 0 deletions src/sentry/consumers/__init__.py
@@ -436,6 +436,11 @@ def ingest_transactions_options() -> list[click.Option]:
is_flag=True,
default=False,
),
click.Option(
["--max-memory-percentage", "max_memory_percentage"],
default=1.0,
help="Maximum memory usage of the Redis cluster in % (0.0-1.0) before the consumer backpressures.",
),
*multiprocessing_options(default_max_batch_size=100),
],
},
11 changes: 8 additions & 3 deletions src/sentry/processing/backpressure/memory.py
@@ -4,6 +4,7 @@

import rb
import requests
from redis import StrictRedis
from rediscluster import RedisCluster


@@ -47,7 +48,7 @@ def query_rabbitmq_memory_usage(host: str) -> ServiceMemory:
# Based on configuration, this could be:
# - a `rediscluster` Cluster (actually `RetryingRedisCluster`)
# - a `rb.Cluster` (client side routing cluster client)
Cluster = Union[RedisCluster, rb.Cluster]
Cluster = Union[RedisCluster, rb.Cluster, StrictRedis]


def get_memory_usage(node_id: str, info: Mapping[str, Any]) -> ServiceMemory:
@@ -68,12 +69,14 @@ def get_host_port_info(node_id: str, cluster: Cluster) -> NodeInfo:
# RedisCluster node mapping
node = cluster.connection_pool.nodes.nodes.get(node_id)
return NodeInfo(node["host"], node["port"])
else:
elif isinstance(cluster, rb.Cluster):
# rb.Cluster node mapping
node = cluster.hosts[node_id]
return NodeInfo(node.host, node.port)
except Exception:
return NodeInfo(None, None)
pass

return NodeInfo(None, None)


def iter_cluster_memory_usage(cluster: Cluster) -> Generator[ServiceMemory]:
@@ -83,6 +86,8 @@ def iter_cluster_memory_usage(cluster: Cluster) -> Generator[ServiceMemory]:
if isinstance(cluster, RedisCluster):
# `RedisCluster` returns these as a dictionary, with the node-id as key
cluster_info = cluster.info()
elif isinstance(cluster, StrictRedis):
cluster_info = {"main": cluster.info()}
else:
# rb.Cluster returns a promise with a dictionary with a _local_ node-id as key
with cluster.all() as client:
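
For reference, a minimal usage sketch of the new single-node branch added above, assuming a locally reachable Redis instance. This is illustrative only and not part of the diff:

    from redis import StrictRedis

    from sentry.processing.backpressure.memory import iter_cluster_memory_usage

    client = StrictRedis(host="localhost", port=6379)

    # With a plain StrictRedis client, the info() result is keyed as "main",
    # so this yields exactly one ServiceMemory entry for the single node.
    for usage in iter_cluster_memory_usage(client):
        print(usage)
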
38 changes: 24 additions & 14 deletions src/sentry/spans/buffer.py
@@ -64,14 +64,15 @@
from __future__ import annotations

import itertools
from collections.abc import MutableMapping, Sequence
from collections.abc import Generator, MutableMapping, Sequence
from typing import Any, NamedTuple

import rapidjson
from django.conf import settings
from django.utils.functional import cached_property
from sentry_redis_tools.clients import RedisCluster, StrictRedis

from sentry.processing.backpressure.memory import ServiceMemory, iter_cluster_memory_usage
from sentry.utils import metrics, redis

# SegmentKey is an internal identifier used by the redis buffer that is also
@@ -306,6 +307,27 @@ def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[

return trees

def record_stored_segments(self):
with metrics.timer("spans.buffer.get_stored_segments"):
with self.client.pipeline(transaction=False) as p:
for shard in self.assigned_shards:
key = self._get_queue_key(shard)
p.zcard(key)
Comment on lines +314 to +315 (Contributor):

Please consider measuring this in bytes rather than set cardinality, since spans can differ a lot in size.

One way to do this is to keep a sharded counter of the size of what you write: increment it when you add a batch of spans, and decrement it when you flush. Then you can rely on the amount of data in Redis rather than the number of spans.

Or, better, measure the amount of free memory in Redis the way backpressure does. Then you cannot go wrong with the signal that triggers backpressure.

Reply (Member, Author):

Done. I also had a conversation with Jan and ended up splitting this into two kinds of backpressure; see the updated PR description.
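
To make the reviewer's first suggestion concrete, here is a minimal sketch of a sharded byte counter kept next to the span queues. Every name in it (SpansBufferBytes, the span-buf:bytes:{shard} key layout) is hypothetical and not part of this PR; the merged change takes the second route instead and measures Redis memory directly via iter_cluster_memory_usage, exposed as get_memory_info above.

    from __future__ import annotations

    from redis import StrictRedis


    class SpansBufferBytes:
        """Hypothetical sharded byte counter for the span buffer (illustration only)."""

        def __init__(self, client: StrictRedis, assigned_shards: list[int]) -> None:
            self.client = client
            self.assigned_shards = assigned_shards

        def _key(self, shard: int) -> str:
            # Hypothetical key layout: one counter per shard.
            return f"span-buf:bytes:{shard}"

        def add(self, shard: int, payloads: list[bytes]) -> None:
            # Increment when a batch of spans is written to this shard's queue.
            self.client.incrby(self._key(shard), sum(len(p) for p in payloads))

        def subtract(self, shard: int, payloads: list[bytes]) -> None:
            # Decrement once the corresponding segments have been flushed to Kafka.
            self.client.decrby(self._key(shard), sum(len(p) for p in payloads))

        def total_bytes(self) -> int:
            # Approximate number of payload bytes currently buffered in Redis
            # across the shards assigned to this consumer.
            values = self.client.mget([self._key(s) for s in self.assigned_shards])
            return sum(int(v) for v in values if v is not None)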


result = p.execute()

assert len(result) == len(self.assigned_shards)

for shard_i, queue_size in zip(self.assigned_shards, result):
metrics.timing(
"spans.buffer.flush_segments.queue_size",
queue_size,
tags={"shard_i": shard_i},
)

def get_memory_info(self) -> Generator[ServiceMemory]:
return iter_cluster_memory_usage(self.client)

def flush_segments(self, now: int, max_segments: int = 0) -> dict[SegmentKey, FlushedSegment]:
cutoff = now

@@ -318,13 +340,11 @@ def flush_segments(self, now: int, max_segments: int = 0) -> dict[SegmentKey, FlushedSegment]:
p.zrangebyscore(
key, 0, cutoff, start=0 if max_segments else None, num=max_segments or None
)
p.zcard(key)
queue_keys.append(key)

result = iter(p.execute())
result = p.execute()

segment_keys: list[tuple[QueueKey, SegmentKey]] = []
queue_sizes = []

with metrics.timer("spans.buffer.flush_segments.load_segment_data"):
with self.client.pipeline(transaction=False) as p:
@@ -335,18 +355,8 @@ def flush_segments(self, now: int, max_segments: int = 0) -> dict[SegmentKey, FlushedSegment]:
segment_keys.append((queue_key, segment_key))
p.smembers(segment_key)

# ZCARD output
queue_sizes.append(next(result))

segments = p.execute()

for shard_i, queue_size in zip(self.assigned_shards, queue_sizes):
metrics.timing(
"spans.buffer.flush_segments.queue_size",
queue_size,
tags={"shard_i": shard_i},
)

return_segments = {}

num_has_root_spans = 0
23 changes: 13 additions & 10 deletions src/sentry/spans/consumers/process/factory.py
@@ -37,13 +37,15 @@ def __init__(
input_block_size: int | None,
output_block_size: int | None,
produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
max_memory_percentage: float = 1.0,
):
super().__init__()

# config
self.max_batch_size = max_batch_size
self.max_batch_time = max_batch_time
self.max_flush_segments = max_flush_segments
self.max_memory_percentage = max_memory_percentage
self.input_block_size = input_block_size
self.output_block_size = output_block_size
self.num_processes = num_processes
@@ -66,8 +68,9 @@ def create_with_partitions(

flusher = self._flusher = SpanFlusher(
buffer,
self.max_flush_segments,
self.produce_to_pipe,
max_flush_segments=self.max_flush_segments,
max_memory_percentage=self.max_memory_percentage,
produce_to_pipe=self.produce_to_pipe,
next_step=committer,
)

@@ -93,19 +96,19 @@ def create_with_partitions(
next_step=run_task,
)

# We use the produce timestamp to drive the clock for flushing, so that
# consumer backlogs do not cause segments to be flushed prematurely.
# The received timestamp in the span is too old for this purpose if
# Relay starts buffering, and we don't want that effect to propagate
# into this system.
def add_produce_timestamp_cb(message: Message[KafkaPayload]) -> tuple[int, KafkaPayload]:
def prepare_message(message: Message[KafkaPayload]) -> tuple[int, KafkaPayload]:
# We use the produce timestamp to drive the clock for flushing, so that
# consumer backlogs do not cause segments to be flushed prematurely.
# The received timestamp in the span is too old for this purpose if
# Relay starts buffering, and we don't want that effect to propagate
# into this system.
return (
int(message.timestamp.timestamp() if message.timestamp else time.time()),
message.payload,
)

add_timestamp = RunTask(
function=add_produce_timestamp_cb,
function=prepare_message,
next_step=batch,
)

@@ -133,7 +136,7 @@ def process_batch(
parent_span_id=val.get("parent_span_id"),
project_id=val["project_id"],
payload=payload.value,
is_segment_span=(val.get("parent_span_id") is None or val.get("is_remote")),
is_segment_span=bool(val.get("parent_span_id") is None or val.get("is_remote")),
)
spans.append(span)

77 changes: 75 additions & 2 deletions src/sentry/spans/consumers/process/flusher.py
@@ -1,19 +1,25 @@
import logging
import multiprocessing
import threading
import time
from collections.abc import Callable

import rapidjson
import sentry_sdk
from arroyo import Topic as ArroyoTopic
from arroyo.backends.kafka import KafkaPayload, KafkaProducer, build_kafka_configuration
from arroyo.processing.strategies.abstract import ProcessingStrategy
from arroyo.processing.strategies.abstract import MessageRejected, ProcessingStrategy
from arroyo.types import FilteredPayload, Message

from sentry.conf.types.kafka_definition import Topic
from sentry.spans.buffer import SpansBuffer
from sentry.utils import metrics
from sentry.utils.kafka_config import get_kafka_producer_cluster_options, get_topic_definition

MAX_PROCESS_RESTARTS = 10

logger = logging.getLogger(__name__)


class SpanFlusher(ProcessingStrategy[FilteredPayload | int]):
"""
@@ -33,15 +39,19 @@ def __init__(
self,
buffer: SpansBuffer,
max_flush_segments: int,
max_memory_percentage: float,
produce_to_pipe: Callable[[KafkaPayload], None] | None,
next_step: ProcessingStrategy[FilteredPayload | int],
):
self.buffer = buffer
self.max_flush_segments = max_flush_segments
self.max_memory_percentage = max_memory_percentage
self.next_step = next_step

self.stopped = multiprocessing.Value("i", 0)
self.redis_was_full = False
self.current_drift = multiprocessing.Value("i", 0)
self.should_backpressure = multiprocessing.Value("i", 0)

from sentry.utils.arroyo import _get_arroyo_subprocess_initializer

@@ -59,24 +69,30 @@ def __init__(
initializer,
self.stopped,
self.current_drift,
self.should_backpressure,
self.buffer,
self.max_flush_segments,
produce_to_pipe,
),
daemon=True,
)

self.process_restarts = 0

self.process.start()

@staticmethod
def main(
initializer: Callable | None,
stopped,
current_drift,
should_backpressure,
buffer: SpansBuffer,
max_flush_segments: int,
produce_to_pipe: Callable[[KafkaPayload], None] | None,
) -> None:
sentry_sdk.set_tag("sentry_spans_buffer_component", "flusher")

try:
if initializer:
initializer()
@@ -102,6 +118,10 @@ def produce(payload: KafkaPayload) -> None:
now = int(time.time()) + current_drift.value
flushed_segments = buffer.flush_segments(max_segments=max_flush_segments, now=now)

should_backpressure.value = len(flushed_segments) >= max_flush_segments * len(
buffer.assigned_shards
)

if not flushed_segments:
time.sleep(1)
continue
@@ -139,8 +159,61 @@ def poll(self) -> None:
self.next_step.poll()

def submit(self, message: Message[FilteredPayload | int]) -> None:
Review comment on submit() (Member):

Please add a short docstring that spells out the two conditions under which we return with backpressure. Let's also document in which of the two cases the flusher itself will stop, and why.
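
A possible shape for such a docstring, as an illustrative sketch only (the wording below is not taken from this PR):

    def submit(self, message: Message[FilteredPayload | int]) -> None:
        """Advance the flush clock and apply backpressure when the buffer falls behind.

        MessageRejected is raised in two cases:

        1. should_backpressure is set: every assigned shard returned a full
           batch of max_flush_segments segments, i.e. segments are piling up
           faster than the flusher drains them. Only the consumer pauses; the
           flusher keeps draining Redis.
        2. Redis memory usage exceeds max_memory_percentage. The drift is
           updated before this check, so while the consumer is paused the
           flush clock stops advancing and the flusher effectively stops too,
           preventing partial/fragmented segments from being produced while
           Redis is full.
        """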

# Note that submit is not actually a hot path. Its message payloads
# are mapped from *batches* of spans, and there are a handful of such
# messages per second at most. If anything, self.poll() might even be
# called more often than submit()
if not self.process.is_alive():
metrics.incr("sentry.spans.buffer.flusher_dead")
if self.process_restarts < MAX_PROCESS_RESTARTS:
self.process.start()
self.process_restarts += 1
else:
raise RuntimeError(
"flusher process has crashed.\n\nSearch for sentry_spans_buffer_component:flusher in Sentry to get the original error."
)

self.buffer.record_stored_segments()

# We pause insertion into Redis if the flusher is not making progress
# fast enough. We could backlog into Redis, but we assume, despite best
# efforts, it is still always going to be less durable than Kafka.
# Minimizing our Redis memory usage also makes COGS easier to reason
# about.
#
# should_backpressure is true if there are many segments to flush, but
# the flusher can't get all of them out.
if self.should_backpressure.value:
metrics.incr("sentry.spans.buffer.flusher.backpressure")
raise MessageRejected()

# We set the drift. The backpressure based on redis memory comes after.
# If Redis is full for a long time, the drift will grow into a large
# negative value, effectively pausing flushing as well.
if isinstance(message.payload, int):
self.current_drift.value = message.payload - int(time.time())
self.current_drift.value = drift = message.payload - int(time.time())
metrics.timing("sentry.spans.buffer.flusher.drift", drift)

# We also pause insertion into Redis if Redis is too full. In this case
# we cannot allow the flusher to progress either, as it would write
# partial/fragmented segments to buffered-segments topic. We have to
# wait until the situation is improved manually.
if self.max_memory_percentage < 1.0:
memory_infos = list(self.buffer.get_memory_info())
used = sum(x.used for x in memory_infos)
available = sum(x.available for x in memory_infos)
if available > 0 and used / available > self.max_memory_percentage:
if not self.redis_was_full:
logger.fatal("Pausing consumer due to Redis being full")
metrics.incr("sentry.spans.buffer.flusher.hard_backpressure")
self.redis_was_full = True
# Pause consumer if Redis memory is full. Because the drift is
# set before we emit backpressure, the flusher effectively
# stops as well. Alternatively we may simply crash the consumer
# but this would also trigger a lot of rebalancing.
raise MessageRejected()

self.redis_was_full = False
self.next_step.submit(message)

def terminate(self) -> None:
17 changes: 1 addition & 16 deletions tests/sentry/spans/consumers/process/test_consumer.py
@@ -1,4 +1,3 @@
import threading
from datetime import datetime

import rapidjson
@@ -8,16 +7,7 @@
from sentry.spans.consumers.process.factory import ProcessSpansStrategyFactory


class FakeProcess(threading.Thread):
"""
Pretend this is multiprocessing.Process
"""

def terminate(self):
pass


def test_basic(monkeypatch, request):
def test_basic(monkeypatch):
# Flush very aggressively to make test pass instantly
monkeypatch.setattr("time.sleep", lambda _: None)

@@ -61,11 +51,6 @@ def add_commit(offsets, force=False):
)
)

@request.addfinalizer
def _():
step.join()
fac.shutdown()

step.poll()
fac._flusher.current_drift.value = 9000 # "advance" our "clock"
