
Commit 2cbf1e2

Stack-Attack, abrarsheikh, and zcin authored
Update metrics_utils for future global metrics aggregation in controller. (#55568)
## Why are these changes needed?

These changes modify the autoscaler metrics collection and aggregation functions in preparation for global aggregation in the controller.

## Related issue number

Partial for #46497
Required for #41135 #51905

<!-- For example: "Closes #1234" -->

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [x] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

---------

Signed-off-by: Kyle Robinson <[email protected]>
Signed-off-by: Kyle Robinson <[email protected]>
Signed-off-by: abrar <[email protected]>
Co-authored-by: Abrar Sheikh <[email protected]>
Co-authored-by: Cindy Zhang <[email protected]>
Co-authored-by: abrar <[email protected]>
1 parent 117a642 commit 2cbf1e2
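
To make the intended direction concrete, here is a rough sketch of how a controller could use the helpers this commit adds to aggregate metrics globally. The controller-side wiring is hypothetical and not part of this diff; only `InMemoryMetricsStore`, `merge_timeseries_dicts`, and `aggregate_avg` come from the code changed below.

```python
# Hypothetical sketch of controller-side global aggregation. Only the
# metrics_utils helpers come from this commit; the rest is illustrative.
from typing import Iterable, Optional, Tuple

from ray.serve._private.metrics_utils import (
    InMemoryMetricsStore,
    merge_timeseries_dicts,
)


def aggregate_replica_metrics(
    stores: Iterable[InMemoryMetricsStore], window_s: float = 1.0
) -> Tuple[Optional[float], int]:
    """Merge per-handle stores into one, then average across all replica keys."""
    merged = InMemoryMetricsStore()
    merged.data = merge_timeseries_dicts(*(s.data for s in stores), window_s=window_s)
    # aggregate_avg returns (mean, number_of_keys_that_reported); (None, 0) with no data.
    return merged.aggregate_avg(merged.data.keys())
```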

File tree

4 files changed: +623 -87 lines changed


python/ray/serve/_private/metrics_utils.py

Lines changed: 203 additions & 37 deletions
@@ -1,15 +1,28 @@
 import asyncio
 import bisect
 import logging
+import statistics
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Callable, DefaultDict, Dict, Hashable, List, Optional
+from itertools import chain
+from typing import (
+    Callable,
+    DefaultDict,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+)
 
 from ray.serve._private.constants import (
     METRICS_PUSHER_GRACEFUL_SHUTDOWN_TIMEOUT_S,
     SERVE_LOGGER_NAME,
 )
 
+QUEUED_REQUESTS_KEY = "queued"
+
 logger = logging.getLogger(SERVE_LOGGER_NAME)
 
 
@@ -152,7 +165,7 @@ def prune_keys_and_compact_data(self, start_timestamp_s: float):
 
     def _get_datapoints(
         self, key: Hashable, window_start_timestamp_s: float
-    ) -> List[float]:
+    ) -> List[TimeStampedValue]:
         """Get all data points given key after window_start_timestamp_s"""
 
         datapoints = self.data[key]
@@ -165,52 +178,205 @@
         )
         return datapoints[idx:]
 
-    def window_average(
-        self, key: Hashable, window_start_timestamp_s: float, do_compact: bool = True
+    def _aggregate_reduce(
+        self,
+        keys: Iterable[Hashable],
+        aggregate_fn: Callable[[Iterable[float]], float],
+    ) -> Tuple[Optional[float], int]:
+        """Reduce the entire set of timeseries values across the specified keys.
+
+        Args:
+            keys: Iterable of keys to aggregate across.
+            aggregate_fn: Function to apply across all float values, e.g., sum, max.
+
+        Returns:
+            A tuple of (float, int) where the first element is the aggregated value
+            and the second element is the number of valid keys used.
+            Returns (None, 0) if no valid keys have data.
+
+        Example:
+            Suppose the store contains:
+            >>> store = InMemoryMetricsStore()
+            >>> store.data.update({
+            ...     "a": [TimeStampedValue(0, 1.0), TimeStampedValue(1, 2.0)],
+            ...     "b": [],
+            ...     "c": [TimeStampedValue(0, 10.0)],
+            ... })
+
+            Using sum across keys:
+
+            >>> store._aggregate_reduce(keys=["a", "b", "c"], aggregate_fn=sum)
+            (13.0, 2)
+
+            Here:
+            - The aggregated value is 1.0 + 2.0 + 10.0 = 13.0
+            - Only keys "a" and "c" contribute values, so report_count = 2
+        """
+        valid_key_count = 0
+
+        def _values_generator():
+            """Generator that yields values from valid keys without storing them all in memory."""
+            nonlocal valid_key_count
+            for key in keys:
+                series = self.data.get(key, [])
+                if not series:
+                    continue
+
+                valid_key_count += 1
+                for timestamp_value in series:
+                    yield timestamp_value.value
+
+        # Create the generator and check if it has any values
+        values_gen = _values_generator()
+        try:
+            first_value = next(values_gen)
+        except StopIteration:
+            # No valid data found
+            return None, 0
+
+        # Apply aggregation to the generator (memory efficient)
+        aggregated_result = aggregate_fn(chain([first_value], values_gen))
+        return aggregated_result, valid_key_count
+
+    def get_latest(
+        self,
+        key: Hashable,
     ) -> Optional[float]:
-        """Perform a window average operation for metric `key`
+        """Get the latest value for a given key."""
+        if not self.data.get(key, None):
+            return None
+        return self.data[key][-1].value
+
+    def aggregate_min(
+        self,
+        keys: Iterable[Hashable],
+    ) -> Tuple[Optional[float], int]:
+        """Find the min value across all timeseries values at the specified keys.
 
         Args:
-            key: the metric name.
-            window_start_timestamp_s: the unix epoch timestamp for the
-                start of the window. The computed average will use all datapoints
-                from this timestamp until now.
-            do_compact: whether or not to delete the datapoints that's
-                before `window_start_timestamp_s` to save memory. Default is
-                true.
+            keys: Iterable of keys to aggregate across.
         Returns:
-            The average of all the datapoints for the key on and after time
-            window_start_timestamp_s, or None if there are no such points.
+            A tuple of (float, int) where the first element is the min across
+            all values found at `keys`, and the second is the number of valid
+            keys used to compute the min.
+            Returns (None, 0) if no valid keys have data.
         """
-        points_after_idx = self._get_datapoints(key, window_start_timestamp_s)
+        return self._aggregate_reduce(keys, min)
 
-        if do_compact:
-            self.data[key] = points_after_idx
+    def aggregate_max(
+        self,
+        keys: Iterable[Hashable],
+    ) -> Tuple[Optional[float], int]:
+        """Find the max value across all timeseries values at the specified keys.
 
-        if len(points_after_idx) == 0:
-            return
-        return sum(point.value for point in points_after_idx) / len(points_after_idx)
+        Args:
+            keys: Iterable of keys to aggregate across.
+        Returns:
+            A tuple of (float, int) where the first element is the max across
+            all values found at `keys`, and the second is the number of valid
+            keys used to compute the max.
+            Returns (None, 0) if no valid keys have data.
+        """
+        return self._aggregate_reduce(keys, max)
 
-    def max(
-        self, key: Hashable, window_start_timestamp_s: float, do_compact: bool = True
-    ):
-        """Perform a max operation for metric `key`.
+    def aggregate_sum(
+        self,
+        keys: Iterable[Hashable],
+    ) -> Tuple[Optional[float], int]:
+        """Sum the entire set of timeseries values across the specified keys.
 
         Args:
-            key: the metric name.
-            window_start_timestamp_s: the unix epoch timestamp for the
-                start of the window. The computed average will use all datapoints
-                from this timestamp until now.
-            do_compact: whether or not to delete the datapoints that's
-                before `window_start_timestamp_s` to save memory. Default is
-                true.
+            keys: Iterable of keys to aggregate across.
         Returns:
-            Max value of the data points for the key on and after time
-            window_start_timestamp_s, or None if there are no such points.
+            A tuple of (float, int) where the first element is the sum across
+            all values found at `keys`, and the second is the number of valid
+            keys used to compute the sum.
+            Returns (None, 0) if no valid keys have data.
         """
-        points_after_idx = self._get_datapoints(key, window_start_timestamp_s)
+        return self._aggregate_reduce(keys, sum)
 
-        if do_compact:
-            self.data[key] = points_after_idx
+    def aggregate_avg(
+        self,
+        keys: Iterable[Hashable],
+    ) -> Tuple[Optional[float], int]:
+        """Average the entire set of timeseries values across the specified keys.
 
-        return max((point.value for point in points_after_idx), default=None)
+        Args:
+            keys: Iterable of keys to aggregate across.
+        Returns:
+            A tuple of (float, int) where the first element is the mean across
+            all values found at `keys`, and the second is the number of valid
+            keys used to compute the mean.
+            Returns (None, 0) if no valid keys have data.
+        """
+        return self._aggregate_reduce(keys, statistics.mean)
+
+
+def _bucket_latest_by_window(
+    series: List[TimeStampedValue],
+    start: float,
+    window_s: float,
+) -> Dict[int, float]:
+    """
+    Map each window index -> latest value seen in that window.
+    Assumes series is sorted by timestamp ascending.
+    """
+    buckets: Dict[int, float] = {}
+    for p in series:
+        w = int((p.timestamp - start) // window_s)
+        buckets[w] = p.value  # overwrite keeps the latest within the window
+    return buckets
+
+
+def _merge_two_timeseries(
+    t1: List[TimeStampedValue], t2: List[TimeStampedValue], window_s: float
+) -> List[TimeStampedValue]:
+    """
+    Merge two ascending time series by summing values within a specified time window.
+    If multiple values fall within the same window in a series, the latest value is used.
+    The output contains one point per window that had at least one value, timestamped
+    at the window center.
+    """
+    if window_s <= 0:
+        raise ValueError(f"window_s must be positive, got {window_s}")
+
+    if not t1 and not t2:
+        return []
+
+    # Align windows so each output timestamp sits at the start of its window.
+    # start is snapped to window_s boundary for binning stability
+    earliest = min(x[0].timestamp for x in (t1, t2) if x)
+    start = earliest // window_s * window_s
+
+    b1 = _bucket_latest_by_window(t1, start, window_s)
+    b2 = _bucket_latest_by_window(t2, start, window_s)
+
+    windows = sorted(set(b1.keys()) | set(b2.keys()))
+
+    merged: List[TimeStampedValue] = []
+    for w in windows:
+        v = b1.get(w, 0.0) + b2.get(w, 0.0)
+        ts_start = start + w * window_s
+        merged.append(TimeStampedValue(timestamp=ts_start, value=v))
+    return merged
+
+
+def merge_timeseries_dicts(
+    *timeseries_dicts: DefaultDict[Hashable, List[TimeStampedValue]],
+    window_s: float,
+) -> DefaultDict[Hashable, List[TimeStampedValue]]:
+    """
+    Merge multiple time-series dictionaries, typically contained within
+    InMemoryMetricsStore().data. For the same key across stores, time series
+    are merged with a windowed sum, where each series keeps only its latest
+    value per window before summing.
+    """
+    merged: DefaultDict[Hashable, List[TimeStampedValue]] = defaultdict(list)
+    for timeseries_dict in timeseries_dicts:
+        for key, ts in timeseries_dict.items():
+            if key in merged:
+                merged[key] = _merge_two_timeseries(merged[key], ts, window_s)
+            else:
+                # Window the data, even if the key is unique.
+                merged[key] = _merge_two_timeseries(ts, [], window_s)
    return merged
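
For readers skimming the diff, a short usage sketch of the new aggregation surface follows. It mirrors the `_aggregate_reduce` doctest above with one extra store merged in; the keys, timestamps, and 1-second window are arbitrary illustration values, not code from this PR.

```python
# Illustrative use of the new helpers; values and window size are arbitrary.
from collections import defaultdict

from ray.serve._private.metrics_utils import (
    InMemoryMetricsStore,
    TimeStampedValue,
    merge_timeseries_dicts,
)

store = InMemoryMetricsStore()
store.data.update(
    {
        "a": [TimeStampedValue(0, 1.0), TimeStampedValue(1, 2.0)],
        "c": [TimeStampedValue(0, 10.0)],
    }
)

print(store.aggregate_sum(["a", "c"]))  # (13.0, 2)
print(store.aggregate_max(["a", "c"]))  # (10.0, 2)
print(store.aggregate_avg(["a", "c"]))  # (4.33..., 2): mean of 1.0, 2.0, 10.0
print(store.get_latest("a"))  # 2.0

# Windowed merge: each series keeps its latest value per 1s window, then
# overlapping windows are summed.
other = defaultdict(list, {"a": [TimeStampedValue(0.5, 5.0)]})
merged = merge_timeseries_dicts(store.data, other, window_s=1.0)
print(merged["a"])  # window 0 -> 1.0 + 5.0 = 6.0, window 1 -> 2.0
```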

python/ray/serve/_private/replica.py

Lines changed: 2 additions & 3 deletions
@@ -328,11 +328,10 @@ def record_request_metrics(self, *, route: str, latency_ms: float, was_error: bo
 
     def _push_autoscaling_metrics(self) -> Dict[str, Any]:
         look_back_period = self._autoscaling_config.look_back_period_s
+        self._metrics_store.prune_keys_and_compact_data(time.time() - look_back_period)
         self._controller_handle.record_autoscaling_metrics.remote(
             replica_id=self._replica_id,
-            window_avg=self._metrics_store.window_average(
-                self._replica_id, time.time() - look_back_period
-            ),
+            window_avg=self._metrics_store.aggregate_avg([self._replica_id])[0],
             send_timestamp=time.time(),
         )
 
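
A note on the `[0]` in the hunk above: `aggregate_avg` returns an `(average, report_count)` tuple, so the replica indexes element 0 to keep pushing a plain float (or `None`) as `window_avg`, the same shape the old `window_average` call produced. A minimal sketch with a made-up string key (the real code uses the replica's ID object):

```python
# aggregate_avg -> (mean, number_of_keys_with_data); only the mean is pushed.
from ray.serve._private.metrics_utils import InMemoryMetricsStore, TimeStampedValue

store = InMemoryMetricsStore()
replica_id = "replica-1"  # stand-in key; the replica uses self._replica_id
store.data[replica_id] = [TimeStampedValue(0, 3.0), TimeStampedValue(1, 5.0)]

window_avg = store.aggregate_avg([replica_id])[0]  # 4.0; None if no data yet
```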

python/ray/serve/_private/router.py

Lines changed: 9 additions & 7 deletions
@@ -42,7 +42,11 @@
     SERVE_LOGGER_NAME,
 )
 from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
-from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher
+from ray.serve._private.metrics_utils import (
+    QUEUED_REQUESTS_KEY,
+    InMemoryMetricsStore,
+    MetricsPusher,
+)
 from ray.serve._private.replica_result import ReplicaResult
 from ray.serve._private.request_router import PendingRequest, RequestRouter
 from ray.serve._private.request_router.pow_2_router import (
@@ -61,9 +65,6 @@
 logger = logging.getLogger(SERVE_LOGGER_NAME)
 
 
-QUEUED_REQUESTS_KEY = "queued"
-
-
 class RouterMetricsManager:
     """Manages metrics for the router."""
 
@@ -392,10 +393,11 @@ def _get_aggregated_requests(self):
         running_requests = dict()
         if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE and self.autoscaling_config:
             look_back_period = self.autoscaling_config.look_back_period_s
+            self.metrics_store.prune_keys_and_compact_data(
+                time.time() - look_back_period
+            )
             running_requests = {
-                replica_id: self.metrics_store.window_average(
-                    replica_id, time.time() - look_back_period
-                )
+                replica_id: self.metrics_store.aggregate_avg([replica_id])[0]
                 # If data hasn't been recorded yet, return current
                 # number of queued and ongoing requests.
                 or num_requests
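
One subtlety preserved from the old code in the hunk above: the `or num_requests` fallback triggers on any falsy average, so a replica with no recorded data (`None`) and a replica whose recorded average is `0.0` both fall back to the current request count. A tiny sketch of that `or` behavior with illustrative numbers:

```python
# Python's `or` falls through on falsy values, so both None and 0.0
# fall back to the live request count.
num_requests = 7

for window_avg in (None, 0.0, 2.5):
    print(window_avg or num_requests)  # 7, 7, 2.5
```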

0 commit comments