16 changes: 13 additions & 3 deletions QEfficient/generation/text_generation_inference.py
@@ -329,6 +329,7 @@ def cloud_ai_100_exec_kv(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -356,6 +357,8 @@ def cloud_ai_100_exec_kv(
next tokens. For Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
Decoding Draft Language Model and `return_pdfs`=False for regular model.
:include_guided_decoding (bool, default=False): If True, enables guided token-level filtering
during decoding. Only works when `include_sampler`=True.
sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the following keys:
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
@@ -394,6 +397,7 @@ def cloud_ai_100_exec_kv(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
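
For reference, a minimal, hypothetical caller-side sketch of the new flag. The `tokenizer`/`qpc_path`/`prompt` argument names, the full key set of `sampling_params` (beyond the keys listed in the docstring above), the `(batch_size, 1)` shapes, and the `random_numbers` width are assumptions based on the library's public API and the export changes later in this PR, not guarantees:

```python
# Hypothetical usage sketch only: model id, qpc path, key set, shapes, and the
# assumed max_top_k_ids width of random_numbers are illustrative assumptions.
import numpy as np
from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch_size, vocab_size = 1, len(tokenizer)

# Guided decoding mask: True marks tokens the sampler is allowed to emit.
allowed_ids = tokenizer(" yes no maybe", add_special_tokens=False).input_ids
token_bitmasks = np.zeros((batch_size, vocab_size), dtype=bool)
token_bitmasks[:, allowed_ids] = True

sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.1, dtype=np.float32),
    "presence_penalties": np.zeros((batch_size, 1), dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 50, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.zeros((batch_size, 1), dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 512).astype(np.float32),  # width = compiled max_top_k_ids (512 assumed)
    "token_bitmasks": token_bitmasks,  # consumed only when include_guided_decoding=True
}

exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="path/to/qpcs",          # placeholder path
    prompt="Answer with yes, no, or maybe:",
    include_sampler=True,
    include_guided_decoding=True,
    sampling_params=sampling_params,
)
```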

@@ -442,6 +446,7 @@ def __init__(
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
activate: bool = True,
) -> None:
@@ -451,6 +456,7 @@ def __init__(
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
self.include_guided_decoding = include_guided_decoding
self.sampling_params = sampling_params
self._qpc_path = qpc_path # Store qpc_path for later use

@@ -461,7 +467,9 @@ def __init__(

# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
session_inputs=set(self._session.input_names), include_sampler=include_sampler
session_inputs=set(self._session.input_names),
include_sampler=include_sampler,
include_guided_decoding=include_guided_decoding,
)

# Fetch the variables from the QPC
@@ -628,7 +636,7 @@ def prepare_decode_inputs(self):
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
@@ -795,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
@@ -1067,6 +1075,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
Expand All @@ -1082,6 +1091,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
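The two hunks in `prepare_decode_inputs` and `run_prefill` treat `token_bitmasks` as just another sampler op: when continuous batching is active, each decode slot gathers its own row of every sampler tensor via `batch_index`. A standalone illustration of that gather (not library code; the array contents are invented):

```python
# Standalone numpy illustration of the per-slot gather performed above; values are made up.
import numpy as np

vocab_size = 8
sampling_params = {
    "temperatures": np.array([[0.7], [1.0], [0.2]], dtype=np.float32),  # one row per queued request
    "token_bitmasks": np.eye(3, vocab_size, dtype=bool),                # one allowed-token mask per request
}
batch_index = np.array([[2], [0]])  # decode slots currently serving requests 2 and 0

decode_inputs = {
    op: sampling_params[op][batch_index.flatten()]  # row gather, mirroring decode_inputs[op] = ...
    for op in ("temperatures", "token_bitmasks")
}

print(decode_inputs["temperatures"].ravel())  # [0.2 0.7]
print(decode_inputs["token_bitmasks"].shape)  # (2, 8): one mask per active slot
```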
139 changes: 135 additions & 4 deletions QEfficient/transformers/models/modeling_auto.py
@@ -722,20 +722,38 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model, **kwargs):
def __init__(self, model, continuous_batching: bool = False, qaic_config: Optional[dict] = None, **kwargs):
"""
Initializes the language decoder component for multimodal models.

Parameters
----------
model : nn.Module
The full HuggingFace multimodal model from which the language decoder is extracted.
continuous_batching : bool, optional
If True, enables continuous batching mode for future compilation and execution.
This setting must be consistent across `from_pretrained` and `compile` calls. Default is False.
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
Only the following keys are supported by the text model of the dual QPC multimodal model:
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
- **include_guided_decoding** (bool): If True, enables guided token-level filtering
during decoding. Only works when include_sampler=True.
Additional keys will be ignored.
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
super().__init__(model, **kwargs)
self.model = model.get_qeff_language_decoder()
self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.continuous_batching = continuous_batching
self.model.qaic_config = qaic_config
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
# previous transform function.
self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)

def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
"""
@@ -759,10 +777,104 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
str
Path to the generated ONNX graph file for the language decoder.
"""
if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
inputs, output_names, dynamic_axes
)
return self._export(
inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
)

def get_sampling_inputs_and_outputs(
self,
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
):
"""
Updates the example inputs, output names, and dynamic axes to include
parameters relevant for on-device sampling during ONNX export.

Parameters
----------
example_inputs : Dict[str, torch.Tensor]
Current dictionary of example inputs.
output_names : List[str]
Current list of output names.
dynamic_axes : Dict[str, Dict[int, str]]
Current dictionary of dynamic axes configurations.

Returns
-------
Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
Updated example inputs, output names, and dynamic axes including
sampling-related parameters.
"""
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS

assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling"

logits_index = output_names.index("logits")
output_names[logits_index] = "next_tokens"

example_inputs["last_accepted_output_tokens"] = torch.zeros(
(bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
)
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}

example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_repetition_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_repetition_penalty_buffer_RetainedState")

example_inputs["repetition_penalties"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
)
dynamic_axes["repetition_penalties"] = {0: "batch_size"}

example_inputs["past_presence_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_presence_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_presence_penalty_buffer_RetainedState")

example_inputs["presence_penalties"] = (
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
)
dynamic_axes["presence_penalties"] = {0: "batch_size"}

example_inputs["temperatures"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
)
dynamic_axes["temperatures"] = {0: "batch_size"}

max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
dynamic_axes["top_ps"] = {0: "batch_size"}

example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

if self.model.qaic_config.get("include_guided_decoding", False):
example_inputs["token_bitmasks"] = torch.zeros(
(bs, self.model.language_model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["token_bitmasks"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes

def compile(
self,
compile_dir,
Expand Down Expand Up @@ -887,7 +999,7 @@ def __init__(
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs)
self.continuous_batching = continuous_batching
self.input_shapes, self.output_names = None, None

@@ -1556,6 +1668,8 @@ def __init__(
"""
if kwargs.pop("full_batch_size", None):
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
if kwargs.pop("qaic_config", None):
raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
super().__init__(model, **kwargs)

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
@@ -2155,6 +2269,14 @@ def from_pretrained(
If True, uses the dual QPC approach (vision encoder KV offloaded).
If False, uses the single QPC approach (entire model in one QPC).
If None, the default behavior of the internal classes is used (typically dual QPC).
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
Only the following keys are supported by the text model of the dual QPC multimodal model:
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
- **include_guided_decoding** (bool): If True, enables guided token-level filtering
during decoding. Only works when include_sampler=True.
Additional keys will be ignored.
**kwargs :
Additional arguments passed to HuggingFace's ``from_pretrained``.

@@ -2182,7 +2304,8 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

if qaic_config is not None:
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(
model,
@@ -2258,6 +2381,8 @@ def __init__(
- **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
For Speculative Decoding Target Language Models, this is always True.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
- **include_guided_decoding** (bool): If True, enables guided token-level filtering
during decoding. Only works when include_sampler=True.
**kwargs :
Additional keyword arguments passed to the base class constructor.

@@ -2360,6 +2485,8 @@ def from_pretrained(
and ``return_pdfs=False`` for regular model.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
The values provided in ``top_ks`` tensor must be less than this maximum limit.
- **include_guided_decoding** (bool): If True, enables guided token-level filtering
during decoding. Only works when include_sampler=True.

*args :
Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`.
@@ -2608,9 +2735,13 @@ def get_sampling_inputs_and_outputs(
example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

if self.model.qaic_config.get("include_guided_decoding", False):
example_inputs["token_bitmasks"] = torch.zeros((bs, self.model.config.vocab_size), dtype=torch.bool)
dynamic_axes["token_bitmasks"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes

def build_prefill_specialization(
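Taken together, the `from_pretrained` and `get_sampling_inputs_and_outputs` changes let the text model of a dual-QPC multimodal model be exported with on-device sampling and guided decoding. A hypothetical sketch of the entry point follows; the `QEFFAutoModelForImageTextToText` class name, the `kv_offload` flag, and the model id are assumptions about the public API rather than part of this diff:

```python
# Hypothetical sketch; class name, kv_offload flag, and model id are assumed, and only
# the qaic_config keys documented above are expected to be honored by the text model.
from QEfficient import QEFFAutoModelForImageTextToText

qaic_config = {
    "include_sampler": True,          # attach on-device sampling nodes to the language decoder
    "max_top_k_ids": 512,             # upper bound for values later passed via top_ks
    "include_guided_decoding": True,  # adds the token_bitmasks input; requires include_sampler
}

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",       # illustrative checkpoint
    kv_offload=True,                  # dual-QPC path; the single-QPC path currently rejects qaic_config
    qaic_config=qaic_config,
)
# export/compile/generate would follow the usual multimodal flow from here.
```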
4 changes: 4 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -289,6 +289,7 @@
QEffGrok1MultiHeadAttention,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
QEffInternDecoderWrapper,
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
@@ -392,6 +393,7 @@
QEffQwen2_5_VLModel,
QEffQwen2_5_VLTextModel,
QEffQwen2_5_VLVisionAttention,
QEffQwen_2_5_vl_DecoderWrapper,
QEffQwen_2_5_vl_ForConditionalGeneration,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -707,10 +709,12 @@ class SamplerTransform:
QEffGPTJForCausalLM,
QEffGraniteForCausalLM,
QEffGraniteMoeForCausalLM,
QEffInternDecoderWrapper,
QEffLlamaForCausalLM,
QEffMptForCausalLM,
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen_2_5_vl_DecoderWrapper,
}

@classmethod
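Registering `QEffInternDecoderWrapper` and `QEffQwen_2_5_vl_DecoderWrapper` here makes the InternVL and Qwen2.5-VL language decoders eligible for `SamplerTransform`. A rough, illustrative-only sketch of the allow-list gating pattern (the class and attribute names below are stand-ins, not the actual implementation):

```python
# Illustrative stand-in for an allow-list-gated transform; not the real SamplerTransform.
from typing import Optional, Tuple
import torch.nn as nn

class AllowListedSamplerTransform:
    # Only decoder classes registered here get sampler nodes attached; this PR adds the
    # InternVL and Qwen2.5-VL decoder wrappers to the real allow-list.
    _module_allow_list = {
        "QEffLlamaForCausalLM",
        "QEffInternDecoderWrapper",
        "QEffQwen_2_5_vl_DecoderWrapper",
    }

    @classmethod
    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
        if not (qaic_config or {}).get("include_sampler", False):
            return model, False  # sampling not requested: leave the graph untouched
        if type(model).__name__ not in cls._module_allow_list:
            return model, False  # decoder class not registered for sampler insertion
        # ...attach top-k/top-p/penalty (and optional token_bitmasks) logic here...
        return model, True
```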