| 10 | 10 | from vllm.worker.embedding_model_runner import ( | 
| 11 | 11 |     ModelInputForGPUWithPoolingMetadata) | 
| 12 | 12 | from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata | 
|  | 13 | +from vllm.worker.multi_step_model_runner import StatefulModelInput | 
| 13 | 14 |  | 
| 14 | 15 |  | 
| 15 | 16 | class MockAttentionBackend(AttentionBackend): | 
| @@ -154,3 +155,79 @@ def test_embedding_model_runner_input(): | 
| 154 | 155 |                        None) == getattr(attn_metadata, field.name, None) | 
| 155 | 156 |     # Pooling metadata is not broadcast. | 
| 156 | 157 |     assert received_model_input.pooling_metadata is None | 
|  | 158 | + | 
|  | 159 | + | 
|  | 160 | +def test_multi_step_model_runner_input(): | 
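|  |  | +    # Build minimal sampling and attention metadata to populate the frozen input. | 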
|  | 161 | +    sampling_metadata = SamplingMetadata( | 
|  | 162 | +        ["seq_group"], | 
|  | 163 | +        "selected_token_indices", | 
|  | 164 | +        "categorized_sample_indices", | 
|  | 165 | +        "num_prompts", | 
|  | 166 | +    ) | 
|  | 167 | +    attn_metadata = AttentionMetadata( | 
|  | 168 | +        num_prefills=1, | 
|  | 169 | +        num_prefill_tokens=2, | 
|  | 170 | +        num_decode_tokens=3, | 
|  | 171 | +        slot_mapping=torch.zeros(1), | 
|  | 172 | +    ) | 
|  | 173 | +    frozen_model_input = ModelInputForGPUWithSamplingMetadata( | 
|  | 174 | +        input_tokens=torch.ones(10), | 
|  | 175 | +        input_positions=torch.ones(10), | 
|  | 176 | +        sampling_metadata=sampling_metadata, | 
|  | 177 | +        attn_metadata=attn_metadata) | 
|  | 178 | + | 
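|  |  | +    # Wrap the frozen input together with the per-step multi-step state. | 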
|  | 179 | +    model_input = StatefulModelInput( | 
|  | 180 | +        frozen_model_input=frozen_model_input, | 
|  | 181 | +        is_last_step=True, | 
|  | 182 | +        is_first_multi_step=False, | 
|  | 183 | +        current_step=4, | 
|  | 184 | +        last_sampled_token_ids=torch.ones((10, 1)), | 
|  | 185 | +        is_multi_step=True, | 
|  | 186 | +        num_queries=8, | 
|  | 187 | +        num_seqs=5, | 
|  | 188 | +        cached_outputs=[], | 
|  | 189 | +    ) | 
|  | 190 | + | 
|  | 191 | +    assert isinstance(model_input, StatefulModelInput) | 
|  | 192 | + | 
|  | 193 | +    # Test round trip serialization. | 
|  | 194 | +    tensor_dict = model_input.as_broadcastable_tensor_dict() | 
|  | 195 | +    attn_backend = MockAttentionBackend() | 
|  | 196 | +    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict( | 
|  | 197 | +        tensor_dict, attn_backend=attn_backend)) | 
|  | 198 | + | 
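|  |  | +    # Pull out the reconstructed frozen input for field-by-field checks. | 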
|  | 199 | +    received_frozen_input = received_model_input.frozen_model_input | 
|  | 200 | + | 
|  | 201 | +    # Check that received copy has correct values. | 
|  | 202 | +    assert isinstance(received_model_input, StatefulModelInput) | 
|  | 203 | +    assert received_frozen_input.input_tokens is not None | 
|  | 204 | +    assert (received_frozen_input.input_tokens == | 
|  | 205 | +            frozen_model_input.input_tokens).all() | 
|  | 206 | +    assert received_frozen_input.input_positions is not None | 
|  | 207 | +    assert (received_frozen_input.input_positions == | 
|  | 208 | +            frozen_model_input.input_positions).all() | 
|  | 209 | +    assert received_frozen_input.multi_modal_kwargs is None | 
|  | 210 | +    assert (received_frozen_input.multi_modal_kwargs == | 
|  | 211 | +            frozen_model_input.multi_modal_kwargs) | 
|  | 212 | +    assert received_frozen_input.lora_requests is None | 
|  | 213 | +    assert (received_frozen_input.lora_requests == | 
|  | 214 | +            frozen_model_input.lora_requests) | 
|  | 215 | +    assert received_frozen_input.lora_mapping is None | 
|  | 216 | +    assert ( | 
|  | 217 | +        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping) | 
|  | 218 | +    for field in dataclasses.fields(AttentionMetadata): | 
|  | 219 | +        assert getattr(received_frozen_input.attn_metadata, field.name, | 
|  | 220 | +                       None) == getattr(attn_metadata, field.name, None) | 
|  | 221 | +    # For sampling metadata, only selected_token_indices is copied. | 
|  | 222 | +    assert (received_frozen_input.sampling_metadata.selected_token_indices == | 
|  | 223 | +            sampling_metadata.selected_token_indices) | 
|  | 224 | +    assert received_frozen_input.sampling_metadata.seq_groups is None | 
|  | 225 | + | 
|  | 226 | +    # Check that the non-frozen multi-step fields are preserved. | 
|  | 227 | +    assert received_model_input.is_last_step == model_input.is_last_step | 
|  | 228 | +    assert (received_model_input.is_first_multi_step == | 
|  | 229 | +            model_input.is_first_multi_step) | 
|  | 230 | +    assert received_model_input.current_step == model_input.current_step | 
|  | 231 | +    assert (received_model_input.last_sampled_token_ids == | 
|  | 232 | +            model_input.last_sampled_token_ids).all() | 
|  | 233 | +    assert received_model_input.is_multi_step == model_input.is_multi_step | 