diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp
index 4fc9afe78680..3f4d71e9d7fc 100644
--- a/csrc/xpu/attention_xpu.cpp
+++ b/csrc/xpu/attention_xpu.cpp
@@ -1254,4 +1254,71 @@ void paged_attention_v2(
       query.scalar_type(), "paged_attention_xpu_v2_impl", [&] {
         CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t);
       });
-}
\ No newline at end of file
+}
+
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+
+void advance_step_ipex(int num_seqs, int num_queries, int block_size,
+                       torch::Tensor& input_tokens,       // type: long
+                       torch::Tensor& sampled_token_ids,  // type: long
+                       torch::Tensor& input_positions,    // type: long
+                       torch::Tensor& seq_lens,           // type: int
+                       torch::Tensor& slot_mapping,       // type: long
+                       torch::Tensor& block_tables) {
+  // std::cout << "advance step ipex get called!!!!!!" << std::endl;
+  sycl::queue& queue = vllm::xpu::vllmGetQueue();
+  // TODO: we might want to adjust this value
+  int num_blocks = 1024;
+  int num_threads = 32;
+  long* input_tokens_ptr = reinterpret_cast<long*>(input_tokens.data_ptr());
+  long const* sampled_token_ids_ptr = reinterpret_cast<long const*>(sampled_token_ids.data_ptr());
+  long* input_positions_ptr = reinterpret_cast<long*>(input_positions.data_ptr());
+  int* seq_lens_ptr = reinterpret_cast<int*>(seq_lens.data_ptr());
+  long* slot_mapping_ptr = reinterpret_cast<long*>(slot_mapping.data_ptr());
+  int const* block_tables_ptr = reinterpret_cast<int const*>(block_tables.data_ptr());
+  int64_t const block_tables_stride = block_tables.stride(0);
+  sycl::range<1> grid(num_blocks);
+  sycl::range<1> block(num_threads);
+  queue.submit([&](sycl::handler& cgh) {
+    cgh.parallel_for(
+        sycl::nd_range<1>(grid * block, block),
+        [=](sycl::nd_item<1> item_ct1) {
+          // constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+          int num_query_blocks = div_ceil(num_queries, num_threads);
+
+          int group = item_ct1.get_group(0);
+
+          if (group >= num_query_blocks) {
+            return;
+          }
+
+          int cur_query_id = group * num_threads + item_ct1.get_local_id(0);
+
+          if (cur_query_id >= num_queries) {
+            return;
+          }
+
+          input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+          int seq_len = seq_lens_ptr[cur_query_id];
+          int next_seq_len = seq_len + 1;
+          int next_input_pos = next_seq_len - 1;
+
+          // Update seq_lens
+          seq_lens_ptr[cur_query_id] = next_seq_len;
+          // Update input_positions
+          input_positions_ptr[cur_query_id] = next_input_pos;
+
+          int const* seq_block_tables_ptr =
+              block_tables_ptr + block_tables_stride * cur_query_id;
+          int block_index = next_input_pos / block_size;
+          int block_offset = next_input_pos % block_size;
+
+          int slot_num =
+              seq_block_tables_ptr[block_index] * block_size + block_offset;
+          // Update slot_mapping
+          slot_mapping_ptr[cur_query_id] = slot_num;
+        }
+    );
+  });
+}
diff --git a/csrc/xpu/pybind.cpp b/csrc/xpu/pybind.cpp
index 4e7f2fa6bd80..b70b2eba121a 100644
--- a/csrc/xpu/pybind.cpp
+++ b/csrc/xpu/pybind.cpp
@@ -69,6 +69,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       "reshape_and_cache",
       &reshape_and_cache,
       "Reshape the key and value tensors and cache them");
+
+  ops.def(
+      "advance_step_ipex",
+      &advance_step_ipex,
+      "Advance step function used by the multi-step scheduler"
+  );
 
   // Quant
   ops.def(
diff --git a/csrc/xpu/xpu_ops.h b/csrc/xpu/xpu_ops.h
index 6125b19ac80b..2dcc82dcbbc0 100644
--- a/csrc/xpu/xpu_ops.h
+++ b/csrc/xpu/xpu_ops.h
@@ -93,6 +93,15 @@ torch::Tensor marlin_gemm(
   TORCH_CHECK(false, "marlin_gemm is not supported on XPU.");
 }
 
+
+void advance_step_ipex(int num_seqs, int num_queries, int block_size,
+                       torch::Tensor& input_tokens,       // type: long
+                       torch::Tensor& sampled_token_ids,  // type: long
+                       torch::Tensor& input_positions,    // type: long
+                       torch::Tensor& seq_lens,           // type: int
+                       torch::Tensor& slot_mapping,       // type: long
+                       torch::Tensor& block_tables);
+
 torch::Tensor awq_dequantize(torch::Tensor _kernel,
                              torch::Tensor _scaling_factors,
                              torch::Tensor _zeros,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 49a198df045b..a46042a4f821 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -692,7 +692,6 @@ def __init__(
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers
         self.priority = priority
-        self.cached_request_output = None
 
     @property
diff --git a/vllm/worker/xpu_multi_step_model_runner.py b/vllm/worker/xpu_multi_step_model_runner.py
index 2bd1f84698a3..71e851dd5210 100644
--- a/vllm/worker/xpu_multi_step_model_runner.py
+++ b/vllm/worker/xpu_multi_step_model_runner.py
@@ -472,9 +472,25 @@ def _advance_step(self, model_input: XPUStatefulModelInput,
         attn_metadata = frozen_model_input.attn_metadata
         assert isinstance(attn_metadata, IpexAttnMetadata)
 
+        # Add one to self.seq_lens
         attn_metadata.advance_step(num_seqs, num_queries)
 
+        sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
+
+        # cloned_input_tokens = frozen_model_input.input_tokens.clone()
+        # cloned_sampled_token_ids = sampled_token_ids.clone()
+        # cloned_input_positions = frozen_model_input.input_positions.clone()
+        # cloned_seq_lens = attn_metadata.seq_lens_tensor.clone()
+        # cloned_slot_mappings = attn_metadata.slot_mapping.clone()
+        # cloned_block_tables = attn_metadata.block_tables.clone()
+
+        ############### New implementation ##############################
+        # import vllm._C.ops
+        # vllm._C.ops.advance_step_ipex(num_seqs, num_queries, self.block_size, frozen_model_input.input_tokens, sampled_token_ids, frozen_model_input.input_positions, attn_metadata.seq_lens_tensor, attn_metadata.slot_mapping, attn_metadata.block_tables)
+        # torch.xpu.synchronize()
+        # vllm._C.ops.advance_step_ipex(num_seqs, num_queries, self.block_size, cloned_input_tokens, cloned_sampled_token_ids, cloned_input_positions, cloned_seq_lens, cloned_slot_mappings, cloned_block_tables)  # refer ops.advance_step()
+        ##################### Original implementation ###################
         next_seq_len = attn_metadata.seq_lens_tensor + 1
         next_input_pos = next_seq_len - 1
         attn_metadata.seq_lens_tensor = next_seq_len
@@ -486,7 +502,6 @@ def _advance_step(self, model_input: XPUStatefulModelInput,
         attn_metadata.slot_mapping = slot_num.to(dtype=torch.long)
 
         tmp_input_tokens = frozen_model_input.input_tokens
-        sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
         if sampled_token_ids.dim() > 1 and sampled_token_ids.size(-1) == 1:
             sampled_token_ids = sampled_token_ids.squeeze(-1)
         tmp_input_tokens[:num_queries] = sampled_token_ids[:num_queries]
@@ -498,6 +513,7 @@ def _advance_step(self, model_input: XPUStatefulModelInput,
             input_positions=tmp_input_positions,
         )
 
+        # Reset seq_lens
         if frozen_model_input.seq_lens is not None:
             tmp_seq_lens = frozen_model_input.seq_lens
             tmp_seq_lens[:num_queries] = attn_metadata.seq_lens[:num_queries]
@@ -505,6 +521,12 @@ def _advance_step(self, model_input: XPUStatefulModelInput,
             frozen_model_input,
             seq_lens=tmp_seq_lens,
         )
+        # assert torch.equal(frozen_model_input.input_tokens, cloned_input_tokens)
+        # assert torch.equal(frozen_model_input.input_positions, cloned_input_positions)
+        # assert torch.equal(attn_metadata.slot_mapping, cloned_slot_mappings)
+        # assert torch.equal(attn_metadata.seq_lens_tensor, cloned_seq_lens)
+
+        # print("All checks passed")
 
         return model_input