From 54300eb87fb8a29e1aa81b9766897bb937b6bd53 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Wed, 21 May 2025 03:16:09 +0000 Subject: [PATCH 1/7] feat: Skip sampler for intermediate pp stages. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- .../_torch/distributed/communicator.py | 57 ++-------- tensorrt_llm/_torch/pyexecutor/py_executor.py | 103 +++++++++--------- tensorrt_llm/_torch/pyexecutor/sampler.py | 19 ++-- tensorrt_llm/_utils.py | 21 +++- 4 files changed, 93 insertions(+), 107 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index 82501d1fcd4..83eb7157495 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -1,14 +1,15 @@ import os from abc import ABC, abstractmethod -from typing import Optional, ValuesView +from typing import Optional import numpy as np import torch import torch.distributed as dist from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_broadcast, - mpi_comm, mpi_isend, mpi_recv, mpi_send, - torch_dtype_to_np) + mpi_comm, mpi_isend, mpi_isend_object, + mpi_recv, mpi_recv_object, mpi_send, + mpi_send_object) from tensorrt_llm.mapping import Mapping @@ -121,48 +122,14 @@ def recv(self, buf: np.ndarray, src, tag=0): # in-place recv numpy buffer return mpi_recv(buf, src, tag) - def isend_tensor(self, tensor: torch.Tensor, dest, tag=0): - return self.isend(tensor.numpy(), dest, tag) - - def recv_tensor(self, tensor: torch.Tensor, src, tag=0): - return self.recv(tensor.numpy(), src, tag) - - def isend_tensor_list(self, - tensor_list: ValuesView[torch.Tensor], - dest, - tag=0): - if len(tensor_list) == 0: - return None - elif len(tensor_list) == 1: - return self.isend_tensor(next(iter(tensor_list)), dest, tag) - - return self.isend( - np.concatenate([t.numpy().ravel() for t in tensor_list]), dest, tag) - - def recv_tensor_list(self, - tensor_list: ValuesView[torch.Tensor], - src, - tag=0): - if len(tensor_list) == 0: - return None - elif len(tensor_list) == 1: - return self.recv_tensor(next(iter(tensor_list)), src, tag) - - # Use the first tensor's dtype to infer the buffer dtype - numpy_dtype = torch_dtype_to_np(next(iter(tensor_list)).dtype) - # Prepare buffer to receive tensor_list - recv_buffer = np.empty(sum([t.numel() for t in tensor_list]), - dtype=numpy_dtype) - # Receive tensors - self.recv(recv_buffer, src, tag) - # Assign to tensor_list - offset = 0 - for t in tensor_list: - t.copy_( - torch.from_numpy(recv_buffer[offset:offset + - t.numel()]).reshape(t.shape)) - offset += t.numel() - return None + def send_object(self, obj, dest, tag=0): + mpi_send_object(obj, dest, tag) + + def isend_object(self, obj, dest, tag=0): + return mpi_isend_object(obj, dest, tag) + + def recv_object(self, src, tag=0): + return mpi_recv_object(src, tag) def create_tp_comm(self): new_group = mpi_comm().group.Incl(self.mapping.tp_group) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 92f54cc53c9..72612c11b2a 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -14,8 +14,6 @@ from itertools import chain from typing import Dict, List, Optional, Tuple, Union -import dill # nosec B403 -import numpy as np import torch from tensorrt_llm._utils import (global_mpi_rank, is_trace_enabled, nvtx_range, @@ -656,6 +654,15 @@ def 
_executor_loop_cleanup(self): self.response_cv.notify_all() self.shutdown_event.set() + def _need_return_logits(self, scheduled_requests: ScheduledRequests): + for req in scheduled_requests.context_requests: + if req.py_return_generation_logits: + return True + for req in scheduled_requests.generation_requests: + if req.py_return_generation_logits: + return True + return False + def _executor_loop_pp(self): torch.cuda.set_device(self.device_id) got_finish_signal = False @@ -716,8 +723,13 @@ def _executor_loop_pp(self): else: with torch.cuda.nvtx.range("_forward_step_last_pp"): batch_outputs = self._forward_step(scheduled_batch) + logits_host = None + if self._need_return_logits(scheduled_batch): + logits_host = batch_outputs["logits"].to( + "cpu", non_blocking=True) sample_state = self._sample_async( scheduled_batch, batch_outputs) + sample_state.logits_host = logits_host self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: @@ -737,24 +749,33 @@ def _executor_loop_pp(self): # Stage 2: Communicate new tokens for previous batch between ranks # send/recv chain: (pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2) - # last rank: sync decoder for previous microbatch to start new tokens comm chain. + # last rank: sync sampler for previous microbatch to start new tokens comm chain. # other ranks: send/recv tokens for next microbatch to allow overlap offset = -1 if self.dist.is_last_pp_rank else 1 prev_microbatch_id = (microbatch_id + offset) % self.num_micro_batches previous_batch = self.micro_batches[prev_microbatch_id] if previous_batch is not None: + sample_state = previous_batch.sample_state if not self.dist.is_last_pp_rank: torch.cuda.nvtx.range_push( "_handle_new_tokens_inter_pp") # Receive tokens from previous pp rank (w.r.t model forward direction) - self.dist.recv_tensor_list( - previous_batch.sample_state.host.values(), + ( + logits, + sample_state.log_probs, + sample_state.new_tensors_host, + ) = self.dist.recv_object( src=self.dist.prev_pp_rank, - tag=prev_microbatch_id) + tag=prev_microbatch_id, + ) + if logits is not None: + logits_host = torch.from_numpy(logits) + sample_state.logits_host = logits_host + sample_state.logits = logits_host.to(self.device_id) else: torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp") - previous_batch.sample_state.sampler_event.synchronize() + sample_state.sampler_event.synchronize() # Send tokens to next pp rank (w.r.t model forward direction) # Second last rank does not need to since last rank has original decoded tokens @@ -762,8 +783,14 @@ def _executor_loop_pp(self): if self.send_handles[prev_microbatch_id] is not None: self.send_handles[prev_microbatch_id].Wait() self.send_handles[ - prev_microbatch_id] = self.dist.isend_tensor_list( - previous_batch.sample_state.host.values(), + prev_microbatch_id] = self.dist.isend_object( + ( + sample_state.logits_host.numpy() + if sample_state.logits_host is not None else + None, + sample_state.log_probs, + sample_state.new_tensors_host, + ), dest=self.dist.next_pp_rank, tag=prev_microbatch_id) torch.cuda.nvtx.range_pop() @@ -896,7 +923,7 @@ def _executor_loop(self): self._update_requests(sample_state) if self.kv_cache_transceiver: - # For context only req in transmission, we reset the state since decoder might have changed it + # For context only req in transmission, we reset the state since sampler might have changed it for req in ctx_transmission_reqs: req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS @@ -1101,11 +1128,15 @@ def _process_previous_batch(self): 
@nvtx_range("_forward_step_inter_pp") def _forward_step_inter_pp(self, scheduled_batch) -> SampleState: - batch_outputs = self._forward_step(scheduled_batch) - sample_state = self._sample_async(scheduled_batch, batch_outputs) + self._forward_step(scheduled_batch) + sampler_event = torch.cuda.Event() + sampler_event.record() self._update_request_states(scheduled_batch) - sample_state.sampler_event.synchronize() - return sample_state + sampler_event.synchronize() + return SampleState( + scheduled_requests=scheduled_batch, + sampler_event=sampler_event, + ) def _update_new_active_requests_queue_latency(self, new_requests): if self.enable_iter_perf_stats and self.dist.rank == 0: @@ -1126,12 +1157,10 @@ def _broadcast_new_requests( """Broadcasts new_requests and optional Python-only metadata (`py_request_objects`) across pipeline stages. `py_request_objects` is a tuple of (attribute_name, {request_id: object}). """ - payloads = (new_requests, py_request_objects - ) if py_request_objects is not None else new_requests + payloads = (new_requests, py_request_objects) if not self.dist.has_pp: - result = self.dist.broadcast(payloads, root=0) - return result if isinstance(result, tuple) else (result, None) + return self.dist.broadcast(payloads, root=0) # broadcast within first tp group before send/recv chain to other tp groups if self.dist.tp_size > 1 and self.dist.is_first_pp_rank: @@ -1140,42 +1169,14 @@ def _broadcast_new_requests( # tag = [0, num_micro_batches - 1] used for new_tokens send/recv tag = self.num_micro_batches - # 1. send metadata: len(num_requests) and serialized buffer size - new_requests = payloads[0] if isinstance(payloads, tuple) else payloads - if self.dist.is_first_pp_rank and len(new_requests) > 0: - buf = np.array(bytearray(dill.dumps(payloads))) - buf_size = len(buf) - else: - buf, buf_size = None, 0 - metadata_arr = np.array([len(new_requests), buf_size]) - + # send payloads if not self.dist.is_first_pp_rank: - self.dist.recv(metadata_arr, self.dist.prev_pp_rank, tag) + payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag) if not self.dist.is_last_pp_rank: - self.dist.send(metadata_arr, self.dist.next_pp_rank, tag) - - # 2. send serialized buffer when new requests is not empty - num_new_requests = metadata_arr[0] - if num_new_requests > 0: - buf_size = metadata_arr[1] - if not self.dist.is_first_pp_rank: - buf = np.array(bytearray(buf_size)) - self.dist.recv(buf, self.dist.prev_pp_rank, tag) - - if not self.dist.is_last_pp_rank: - self.dist.send(buf, self.dist.next_pp_rank, tag) - - if not self.dist.is_first_pp_rank: - buf_data = dill.loads(buf.tobytes()) # nosec B301 - if isinstance(buf_data, tuple): - new_requests, py_request_objects = buf_data - else: - new_requests = buf_data - - assert len(new_requests) == num_new_requests + self.dist.send_object(payloads, self.dist.next_pp_rank, tag) - return new_requests, py_request_objects + return payloads @nvtx_range("_fetch_new_requests") def _fetch_new_requests(self): @@ -1979,7 +1980,7 @@ def _handle_cancelled_requests(self): self.active_requests = left_requests # enqueue the cancelled requests' responses as they are not - # active_requests and be discarded in the decoder loop. + # active_requests and be discarded in the sampler loop. 
self._enqueue_responses(cancelled_responses) @nvtx_range("_enqueue_responses") diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index e94258f2308..eb121bf7e30 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -37,6 +37,7 @@ class SampleState: scheduled_requests: ScheduledRequests logits: torch.Tensor = None + logits_host: torch.Tensor = None # Set when decode_async() has evaluated these to avoid computing again in update_requests() # log_probs[request_idx][token_idx] @@ -81,13 +82,14 @@ def update_requests(self, state: SampleState) -> None: request.state = LlmRequestState.GENERATION_COMPLETE # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) - logits = state.logits[idx] - if logits.ndim == 1: - # For BERT: Add axis to be compatible with LogitsStorage - # (LogitsStorage will interpret this dim as the prompt_len which - # is not relevant for outputting logits of encoder only model). - logits = logits.unsqueeze(0) - request.py_result.append_context_logits(logits) + if request.py_return_context_logits: + logits = state.logits[idx] + if logits.ndim == 1: + # For BERT: Add axis to be compatible with LogitsStorage + # (LogitsStorage will interpret this dim as the prompt_len which + # is not relevant for outputting logits of encoder only model). + logits = logits.unsqueeze(0) + request.py_result.append_context_logits(logits) def top_k_sampling_batch(logits, top_k=50): @@ -409,7 +411,8 @@ def update_one_request(self, request: LlmRequest, num_tokens = request.add_new_token(new_token, beam_idx) current_logits = logits[output_token_idx].unsqueeze(0) - request.py_result.append_generation_logits(current_logits) + if request.py_return_generation_logits: + request.py_result.append_generation_logits(current_logits) if request.py_return_log_probs: _, log_probs = greedy_search_sampling_batch(current_logits) request.py_result.append_log_probs([[{ diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index aa11a4c14fe..efe76e14b8b 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -523,6 +523,23 @@ def mpi_recv(buf, source, tag): return None +def mpi_send_object(obj, dest, tag=0): + if ENABLE_MULTI_DEVICE: + mpi_comm().send(obj, dest=dest, tag=tag) + + +def mpi_isend_object(obj, dest, tag=0): + if ENABLE_MULTI_DEVICE: + return mpi_comm().isend(obj, dest=dest, tag=tag) + return None + + +def mpi_recv_object(source, tag): + if ENABLE_MULTI_DEVICE: + return mpi_comm().recv(source=source, tag=tag) + return None + + def pad_vocab_size(vocab_size, tp_size): return int(math.ceil(vocab_size / tp_size) * tp_size) @@ -647,7 +664,6 @@ def trace_func(func): @wraps(func) def wrapper(*args, **kwargs): - import dill # nosec B403 def globaltrace(frame, why, arg): if why == "call": @@ -676,8 +692,7 @@ def localtrace(frame, why, arg): return localtrace ignoredirs = [ - os.path.dirname(package.__file__) - for package in [os, torch, trace, dill] + os.path.dirname(package.__file__) for package in [os, torch, trace] ] tracer = trace.Trace(trace=1, count=0, ignoredirs=ignoredirs) rank = global_mpi_rank() From 6bdecbb0905716c71834287c480ac409232ff185 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Wed, 21 May 2025 04:43:28 +0000 Subject: [PATCH 2/7] Address comment. 
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 72612c11b2a..aff1f34a941 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -663,6 +663,15 @@ def _need_return_logits(self, scheduled_requests: ScheduledRequests): return True return False + def _need_return_log_probs(self, scheduled_requests: ScheduledRequests): + for req in scheduled_requests.context_requests: + if req.py_return_log_probs: + return True + for req in scheduled_requests.generation_requests: + if req.py_return_log_probs: + return True + return False + def _executor_loop_pp(self): torch.cuda.set_device(self.device_id) got_finish_signal = False @@ -785,9 +794,12 @@ def _executor_loop_pp(self): self.send_handles[ prev_microbatch_id] = self.dist.isend_object( ( - sample_state.logits_host.numpy() - if sample_state.logits_host is not None else - None, + sample_state.logits_host.numpy() if + self._need_return_logits(scheduled_batch) or + (self._need_return_log_probs( + scheduled_batch) + and sample_state.log_probs is not None) + else None, sample_state.log_probs, sample_state.new_tensors_host, ), From dc075bef8a78be492cfc10ee0f2309f7568bbe08 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Thu, 22 May 2025 03:38:53 +0000 Subject: [PATCH 3/7] Address comment. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index aff1f34a941..349bb13aa13 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -656,7 +656,7 @@ def _executor_loop_cleanup(self): def _need_return_logits(self, scheduled_requests: ScheduledRequests): for req in scheduled_requests.context_requests: - if req.py_return_generation_logits: + if req.py_return_context_logits: return True for req in scheduled_requests.generation_requests: if req.py_return_generation_logits: From 5d703fc8074666a9b57430730b1bc518079dc6cc Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Fri, 23 May 2025 04:39:04 +0000 Subject: [PATCH 4/7] Fix CI error. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index eb121bf7e30..84f59f62c4e 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -32,7 +32,7 @@ def values(self): return vars(self).values() -@dataclass(frozen=True, kw_only=True) +@dataclass(kw_only=True) class SampleState: scheduled_requests: ScheduledRequests From d653814fadebff5e6008aa51c581badf8503b676 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Fri, 23 May 2025 05:12:34 +0000 Subject: [PATCH 5/7] Fix CI error. 
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 2 +- tensorrt_llm/_torch/speculative/mtp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 84f59f62c4e..191033bd3d9 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -455,7 +455,7 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors): sequence_lengths: torch.Tensor -@dataclass(frozen=True, kw_only=True) +@dataclass(kw_only=True) class SampleStateTRTLLM(SampleState): host: SampleStateTensorsHostTRTLLM device: SampleStateTensors diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 144633f4c40..098faee6867 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -20,7 +20,7 @@ class SampleStateTensorsMTP(SampleStateTensors): next_draft_tokens: torch.Tensor -@dataclass(frozen=True, kw_only=True) +@dataclass(kw_only=True) class SampleStateMTP(SampleState): device: SampleStateTensorsMTP host: SampleStateTensorsMTP From 7d8acd2f6a12b4fb20d3ab23af43ac7193ac5dbd Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Fri, 23 May 2025 05:41:00 +0000 Subject: [PATCH 6/7] Fix rebase error. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 349bb13aa13..86466f9c746 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -773,7 +773,7 @@ def _executor_loop_pp(self): ( logits, sample_state.log_probs, - sample_state.new_tensors_host, + sample_state.host, ) = self.dist.recv_object( src=self.dist.prev_pp_rank, tag=prev_microbatch_id, @@ -801,7 +801,7 @@ def _executor_loop_pp(self): and sample_state.log_probs is not None) else None, sample_state.log_probs, - sample_state.new_tensors_host, + sample_state.host, ), dest=self.dist.next_pp_rank, tag=prev_microbatch_id) From 783bcb0cffa1095b9b6b437cf4ff4b721f869a4a Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Fri, 23 May 2025 15:51:29 +0000 Subject: [PATCH 7/7] Fix CI error. 
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- tensorrt_llm/_torch/pyexecutor/sampler.py | 3 +++ tensorrt_llm/_torch/speculative/mtp.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 86466f9c746..c343f76273b 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1145,7 +1145,7 @@ def _forward_step_inter_pp(self, scheduled_batch) -> SampleState: sampler_event.record() self._update_request_states(scheduled_batch) sampler_event.synchronize() - return SampleState( + return self.sampler.SampleState( scheduled_requests=scheduled_batch, sampler_event=sampler_event, ) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 191033bd3d9..e7992009d72 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -51,6 +51,8 @@ class SampleState: class Sampler(ABC): + SampleState = SampleState + def setup_sampler_step(self, scheduled_requests: ScheduledRequests): pass @@ -463,6 +465,7 @@ class SampleStateTRTLLM(SampleState): class TRTLLMSampler(Sampler): MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding + SampleState = SampleStateTRTLLM def __init__( self, diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 098faee6867..6f5ed5bc7aa 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -221,6 +221,8 @@ class MTPSampler(TorchSampler): MTP sampler. """ + SampleState = SampleStateMTP + def __init__(self, max_seq_len: int, config: MTPConfig): super().__init__(max_seq_len, False) self.mapping = None
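
A note on the pickle-based MPI helpers introduced in PATCH 1/7 (mpi_send_object, mpi_isend_object, mpi_recv_object): they replace the previous raw numpy-buffer tensor-list protocol and forward arbitrary Python objects between pipeline ranks via mpi4py's lowercase send/isend/recv, which pickle the payload. The sketch below is a minimal two-rank illustration, assuming mpi4py is installed; the script name and the toy payload values are made up for the example and are not part of the patch.

    # Run with: mpirun -n 2 python send_recv_object_sketch.py   (filename is hypothetical)
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    TAG = 0

    if rank == 0:
        # The tuple layout mirrors what the PP loop exchanges between ranks:
        # (logits_host-or-None, log_probs, host-side tensors); the values here
        # are toy Python objects standing in for real tensors.
        payload = (None, [[0.1, 0.2]], {"new_tokens": [3, 7]})
        # Lowercase isend pickles the object; keep the handle and wait on it
        # before reusing the payload, as the executor does with send_handles.
        handle = comm.isend(payload, dest=1, tag=TAG)
        handle.wait()
    elif rank == 1:
        logits, log_probs, host = comm.recv(source=0, tag=TAG)
        print(f"rank {rank}: log_probs={log_probs}, host={host}")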
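
A note on the per-sampler SampleState class attribute added in PATCH 7/7: the base Sampler now exposes SampleState as a class attribute, concrete samplers (TRTLLMSampler, MTPSampler) override it with their own dataclass, and _forward_step_inter_pp constructs self.sampler.SampleState(...) so that intermediate pipeline-parallel stages return a correctly typed, mostly empty state without running the sampler. The standalone sketch below mirrors that pattern with simplified stand-in classes; field names and types are reduced for illustration and do not match the real TensorRT-LLM definitions exactly.

    from dataclasses import dataclass, field
    from typing import Any, Optional


    @dataclass
    class ScheduledRequests:
        # Stand-in for the real ScheduledRequests (context + generation batches).
        context_requests: list = field(default_factory=list)
        generation_requests: list = field(default_factory=list)


    @dataclass(kw_only=True)
    class SampleState:
        # Base state returned by the sampler; fields reduced for the sketch.
        scheduled_requests: ScheduledRequests
        sampler_event: Optional[Any] = None  # torch.cuda.Event in the real code
        logits: Optional[Any] = None
        logits_host: Optional[Any] = None
        log_probs: Optional[Any] = None
        host: Optional[Any] = None


    @dataclass(kw_only=True)
    class SampleStateTRTLLM(SampleState):
        # Sampler-specific extension carrying extra device-side state.
        device: Optional[Any] = None


    class Sampler:
        # Subclasses override this so callers can construct the matching
        # state type without knowing the concrete sampler class.
        SampleState = SampleState


    class TRTLLMSampler(Sampler):
        SampleState = SampleStateTRTLLM


    def forward_step_inter_pp(sampler: Sampler,
                              scheduled_batch: ScheduledRequests) -> SampleState:
        # Intermediate PP ranks skip sampling entirely: they only return an
        # empty, correctly typed state for the new-token comm chain to fill.
        return sampler.SampleState(scheduled_requests=scheduled_batch)


    state = forward_step_inter_pp(TRTLLMSampler(), ScheduledRequests())
    print(type(state).__name__)  # -> SampleStateTRTLLM

On the last PP rank the full sample_async() path still produces a populated state; only intermediate ranks take this shortcut, which is what lets the series skip the sampler there.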