38 | 38 | from ..models.modeling_utils import DecoderModelForCausalLM |
39 | 39 | from ..modules.decoder_layer import DecoderLayer |
40 | 40 | from ..speculative.drafter import Drafter |
| 41 | +from ..speculative.speculation_gate import SpeculationGate |
41 | 42 | from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem |
42 | 43 | from .guided_decoder import GuidedDecoder |
43 | 44 | from .handle_logits import HandleLogits |
@@ -207,6 +208,20 @@ def __init__(self, |
207 | 208 | self.num_fetch_requests = 0 |
208 | 209 | self.shutdown_event = threading.Event() |
209 | 210 |
| 211 | + # Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold) |
| 212 | + spec_config = getattr(self.model_engine, 'spec_config', None) |
| 213 | + self.acceptance_window = getattr( |
| 214 | + spec_config, 'acceptance_window', |
| 215 | + None) if spec_config is not None else None |
| 216 | + self.acceptance_length_threshold = getattr( |
| 217 | + spec_config, 'acceptance_length_threshold', |
| 218 | + None) if spec_config is not None else None |
| 219 | + self.speculation_permanently_disabled = False |
| 220 | + self.speculation_gate = None |
| 221 | + if self.acceptance_window and self.acceptance_length_threshold is not None: |
| 222 | + self.speculation_gate = SpeculationGate( |
| 223 | + self.acceptance_window, self.acceptance_length_threshold) |
| 224 | + |
210 | 225 | # response used data |
211 | 226 | self.response_lock = threading.Lock() |
212 | 227 | self.response_cv = threading.Condition(self.response_lock) |
@@ -969,15 +984,14 @@ def _prepare_and_schedule_batch(self): |
969 | 984 |
970 | 985 | if self.drafter is not None: |
971 | 986 | # Honor permanent disable flag based on rolling acceptance first |
972 | | - if getattr(self.model_engine, 'speculation_permanently_disabled', |
973 | | - False): |
| 987 | + if getattr(self, 'speculation_permanently_disabled', False): |
974 | 988 | self.use_spec_decode = False |
975 | 989 | else: |
976 | 990 | self.use_spec_decode = self.drafter.should_use_spec_decode( |
977 | 991 | self.active_requests, self.max_batch_size, |
978 | 992 | self.model_engine.max_num_tokens, |
979 | 993 | self.model_engine.spec_config.max_draft_len) |
980 | | - |
| 994 | + logger.debug(f"Use spec decode: {self.use_spec_decode}") |
981 | 995 | self.model_engine.enable_spec_decode = self.use_spec_decode |
982 | 996 |
983 | 997 | # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch, |
@@ -1911,24 +1925,29 @@ def _handle_responses(self): |
1911 | 1925 | new_responses.append((req_id, response)) |
1912 | 1926 |
1913 | 1927 | if request_done: |
1914 | | - if (self.model_engine.enable_spec_decode and |
1915 | | - not self.model_engine.speculation_permanently_disabled |
| 1928 | + if (self.model_engine.enable_spec_decode |
| 1929 | + and not self.speculation_permanently_disabled |
1916 | 1930 | and not request.is_dummy and not self.is_warmup): |
1917 | | - if self.model_engine.speculation_gate is not None: |
| 1931 | + if self.speculation_gate is not None: |
1918 | 1932 | # Response handling runs on multiple PP ranks. Only the last PP rank performs |
1919 | 1933 | # sampling; restrict rolling stat updates to it to avoid overcounting. |
1920 | 1934 | if (not getattr(self.dist, 'has_pp', |
1921 | 1935 | False)) or self.dist.is_last_pp_rank: |
1922 | 1936 | avg_decoded = getattr( |
1923 | 1937 | request, 'avg_decoded_tokens_per_iter', None) |
1924 | | - disabled_now, _ = self.model_engine.speculation_gate.record_avg_decoded( |
1925 | | - avg_decoded, |
1926 | | - request_id=getattr(request, 'py_request_id', |
1927 | | - None)) |
1928 | | - if disabled_now: |
1929 | | - # disable speculation permanently |
1930 | | - # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False |
1931 | | - self.model_engine.speculation_permanently_disabled = True |
| 1938 | + if avg_decoded is not None: |
| 1939 | + disabled_now, _ = self.speculation_gate.record_avg_decoded( |
| 1940 | + avg_decoded, |
| 1941 | + request_id=getattr(request, 'py_request_id', |
| 1942 | + None)) |
| 1943 | + if disabled_now: |
| 1944 | + # disable speculation permanently |
| 1945 | + # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False |
| 1946 | + self.speculation_permanently_disabled = True |
| 1947 | + else: |
| 1948 | + logger.debug( |
| 1949 | + f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter" |
| 1950 | + ) |
1932 | 1951 | if request.is_disagg_context_transmission_state: |
1933 | 1952 | self.ctx_in_transmission_requests.append(request) |
1934 | 1953 | else: |
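
For context, the diff only relies on two pieces of the `SpeculationGate` API: the `SpeculationGate(window, threshold)` constructor and `record_avg_decoded(avg_decoded, request_id=None)`, whose first return value signals that speculation should be turned off. The snippet below is a minimal sketch of that contract, assuming a plain rolling mean over the last `window` finished requests; the real class in `..speculative.speculation_gate` may track the statistic differently.

```python
from collections import deque


class SpeculationGate:
    """Sketch: rolling average of accepted tokens per iteration across
    finished requests; trips once the average falls below a threshold."""

    def __init__(self, window: int, threshold: float):
        self.window = window
        self.threshold = threshold
        self.samples = deque(maxlen=window)
        self.disabled = False

    def record_avg_decoded(self, avg_decoded_tokens_per_iter: float,
                           request_id=None):
        """Record one finished request's avg_decoded_tokens_per_iter.

        Returns (disabled_now, rolling_avg); rolling_avg is None until the
        window is full. Once tripped, further calls are no-ops.
        """
        if self.disabled:
            return False, None
        self.samples.append(avg_decoded_tokens_per_iter)
        if len(self.samples) < self.window:
            return False, None
        rolling_avg = sum(self.samples) / len(self.samples)
        if rolling_avg < self.threshold:
            self.disabled = True
            return True, rolling_avg
        return False, rolling_avg
```

Once `record_avg_decoded` reports `disabled_now`, the executor latches `speculation_permanently_disabled`, and `_prepare_and_schedule_batch` forces `use_spec_decode = False` on every subsequent iteration.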
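Because `PyExecutor.__init__` reads the two knobs with `getattr(spec_config, ..., None)` and only constructs the gate when both are set, any speculative-decoding config object that exposes them opts in to the rolling check. A hypothetical illustration (attribute names match what the diff reads; the values and the config class are made up):

```python
# Illustration only: values are invented, and how a real spec_config
# (e.g. an Eagle/MTP decoding config) surfaces these fields is not shown here.
class MySpecConfig:
    max_draft_len = 4
    acceptance_window = 256            # finished requests in the rolling window
    acceptance_length_threshold = 1.5  # disable speculation when the rolling
                                       # avg_decoded_tokens_per_iter drops below this
```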