File tree Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -37,11 +37,10 @@ def swap_blocks(
3737 ) -> None :
3838 src_k_cache , src_v_cache = src_kv_cache
3939 dst_k_cache , dst_v_cache = dst_kv_cache
40+ src_indices , dst_indices = src_to_dst
41+ device = dst_k_cache .device
4042 torch .ops .xla .dynamo_set_buffer_donor_ (dst_k_cache , True )
4143 torch .ops .xla .dynamo_set_buffer_donor_ (dst_v_cache , True )
42-
43- device = dst_k_cache .device
44- src_indices , dst_indices = src_to_dst
4544 dst_k_cache [:, dst_indices ] = src_k_cache [:, src_indices ].to (device )
4645 dst_v_cache [:, dst_indices ] = src_v_cache [:, src_indices ].to (device )
4746
Original file line number Diff line number Diff line change @@ -156,14 +156,18 @@ def initialize_cache(
156156 self .tpu_cache = []
157157 tpu_cache_shape = self .model_runner .attn_backend .get_kv_cache_shape (
158158 num_gpu_blocks , self .block_size , num_kv_heads , head_size )
159+ cpu_cache_shape = self .model_runner .attn_backend .get_kv_cache_shape (
160+ num_cpu_blocks , self .block_size , num_kv_heads , head_size )
159161 for _ in range (num_layers ):
160162 tpu_k_cache = torch .zeros (tpu_cache_shape ,
161163 dtype = dtype ,
162164 device = self .device )
163165 tpu_v_cache = torch .zeros_like (tpu_k_cache )
164166 self .tpu_cache .append ((tpu_k_cache , tpu_v_cache ))
165- cpu_k_cache = torch .zeros_like (tpu_k_cache , device = "cpu" )
166- cpu_v_cache = torch .zeros_like (tpu_v_cache , device = "cpu" )
167+ cpu_k_cache = torch .zeros (cpu_cache_shape ,
168+ dtype = dtype ,
169+ device = "cpu" )
170+ cpu_v_cache = torch .zeros_like (cpu_k_cache )
167171 self .cpu_cache .append ((cpu_k_cache , cpu_v_cache ))
168172 self ._warmup_model ()
169173
You can’t perform that action at this time.
0 commit comments