@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
8484    blocks  =  manager .allocate_slots (req0 , 55 ,
8585                                    len (computed_blocks .blocks ) *  16 ,
8686                                    computed_blocks )
87-     assert  blocks .get_block_ids () ==  [1 , 2 , 3 , 4 ]
87+     assert  blocks .get_block_ids () ==  [[ 1 , 2 , 3 , 4 ] ]
8888
8989    # Check full block metadata 
9090    parent_block_hash  =  None 
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
107107    req1  =  make_request ("1" , common_token_ids  +  unique_token_ids )
108108    computed_blocks , num_computed_tokens  =  manager .get_computed_blocks (req1 )
109109    assert  len (manager .req_to_block_hashes [req1 .request_id ]) ==  3 
110-     assert  computed_blocks .get_block_ids () ==  [1 , 2 , 3 ]
110+     assert  computed_blocks .get_block_ids () ==  [[ 1 , 2 , 3 ] ]
111111    assert  num_computed_tokens  ==  3  *  16 
112112    num_new_tokens  =  53  -  3  *  16 
113113    blocks  =  manager .allocate_slots (req1 , num_new_tokens ,
114114                                    len (computed_blocks .blocks ) *  16 ,
115115                                    computed_blocks )
116-     assert  blocks .get_block_ids () ==  [5 ]
116+     assert  blocks .get_block_ids () ==  [[ 5 ] ]
117117    for  block  in  computed_blocks .blocks :
118118        assert  block .ref_cnt  ==  2 
119119
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
141141    req2  =  make_request ("2" , common_token_ids  +  unique_token_ids )
142142    computed_blocks , num_computed_tokens  =  manager .get_computed_blocks (req2 )
143143    assert  len (manager .req_to_block_hashes [req2 .request_id ]) ==  3 
144-     assert  computed_blocks .get_block_ids () ==  [1 , 2 , 3 ]
144+     assert  computed_blocks .get_block_ids () ==  [[ 1 , 2 , 3 ] ]
145145    assert  num_computed_tokens  ==  3  *  16 
146146    num_new_tokens  =  53  -  3  *  16 
147147    blocks  =  manager .allocate_slots (req2 , num_new_tokens ,
148148                                    len (computed_blocks .blocks ) *  16 ,
149149                                    computed_blocks )
150-     assert  blocks .get_block_ids () ==  [6 ]
150+     assert  blocks .get_block_ids () ==  [[ 6 ] ]
151151
152152    # Although we only have 6 free blocks, we have 8 blocks in 
153153    # the free block queue due to lazy removal. 
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
171171                                    len (computed_blocks .blocks ) *  16 ,
172172                                    computed_blocks )
173173    # This block ID order also checks the eviction order. 
174-     assert  blocks .get_block_ids () ==  [7 , 8 , 9 , 10 , 4 , 5 , 6 , 3 , 2 , 1 ]
174+     assert  blocks .get_block_ids () ==  [[ 7 , 8 , 9 , 10 , 4 , 5 , 6 , 3 , 2 , 1 ] ]
175175    assert  manager .block_pool .free_block_queue .num_free_blocks  ==  0 
176176    assert  manager .block_pool .free_block_queue .free_list_head  is  None 
177177    assert  manager .block_pool .free_block_queue .free_list_tail  is  None 
@@ -208,7 +208,7 @@ def test_prefill_plp():
208208    blocks  =  manager .allocate_slots (req0 , 55 ,
209209                                    len (computed_blocks .blocks ) *  16 ,
210210                                    computed_blocks )
211-     assert  blocks .get_block_ids () ==  [1 , 2 , 3 , 4 ]
211+     assert  blocks .get_block_ids () ==  [[ 1 , 2 , 3 , 4 ] ]
212212    req0_block_hashes  =  [b .block_hash  for  b  in  blocks .blocks ]
213213
214214    # Check full block metadata 
@@ -233,13 +233,13 @@ def test_prefill_plp():
233233    req1  =  make_request ("1" , common_token_ids  +  unique_token_ids )
234234    computed_blocks , num_computed_tokens  =  manager .get_computed_blocks (req1 )
235235    assert  len (manager .req_to_block_hashes [req1 .request_id ]) ==  3 
236-     assert  computed_blocks .get_block_ids () ==  [1 , 2 , 3 ]
236+     assert  computed_blocks .get_block_ids () ==  [[ 1 , 2 , 3 ] ]
237237    assert  num_computed_tokens  ==  3  *  16 
238238    num_new_tokens  =  53  -  3  *  16 
239239    blocks  =  manager .allocate_slots (req1 , num_new_tokens ,
240240                                    len (computed_blocks .blocks ) *  16 ,
241241                                    computed_blocks )
242-     assert  blocks .get_block_ids () ==  [5 ]
242+     assert  blocks .get_block_ids () ==  [[ 5 ] ]
243243    for  block  in  computed_blocks .blocks :
244244        assert  block .ref_cnt  ==  2 
245245
@@ -277,11 +277,11 @@ def test_prefill_plp():
277277    block_ids  =  blocks .get_block_ids ()
278278    # Duplicate cached blocks have different ids but same hashes vs request #0 
279279    assert  [b .block_hash  for  b  in  blocks .blocks ] ==  req0_block_hashes 
280-     assert  block_ids  !=  [1 , 2 , 3 , 4 ]
280+     assert  block_ids  !=  [[ 1 , 2 , 3 , 4 ] ]
281281
282282    # Request #2 block hashes are valid since request #0 hashes are. 
283283    # Check block reference counts. 
284-     for  block_id  in  block_ids :
284+     for  block_id  in  block_ids [ 0 ] :
285285        assert  manager .block_pool .blocks [block_id ].ref_cnt  ==  1 
286286
287287    manager .free (req2 )
@@ -307,7 +307,7 @@ def test_decode():
307307    blocks  =  manager .allocate_slots (req0 , 55 ,
308308                                    len (computed_blocks .blocks ) *  16 ,
309309                                    computed_blocks )
310-     assert  blocks .get_block_ids () ==  [1 , 2 , 3 , 4 ]
310+     assert  blocks .get_block_ids () ==  [[ 1 , 2 , 3 , 4 ] ]
311311
312312    # Append slots without allocating a new block. 
313313    req0 .num_computed_tokens  =  55 
@@ -379,12 +379,12 @@ def test_evict():
379379    # Touch the first 2 blocks. 
380380    req2  =  make_request ("2" , list (range (2  *  16  +  3 )))
381381    computed_blocks , num_computed_tokens  =  manager .get_computed_blocks (req2 )
382-     assert  computed_blocks .get_block_ids () ==  [1 , 2 ]
382+     assert  computed_blocks .get_block_ids () ==  [[ 1 , 2 ] ]
383383    assert  num_computed_tokens  ==  2  *  16 
384384    blocks  =  manager .allocate_slots (req2 , 3 ,
385385                                    len (computed_blocks .blocks ) *  16 ,
386386                                    computed_blocks )
387-     assert  blocks .get_block_ids () ==  [10 ]
387+     assert  blocks .get_block_ids () ==  [[ 10 ] ]
388388    assert  manager .block_pool .free_block_queue .num_free_blocks  ==  7 
389389
390390
@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
625625    blocks  =  manager .allocate_slots (req0 , 59 ,
626626                                    len (computed_blocks .blocks ) *  16 ,
627627                                    computed_blocks )
628-     assert  blocks .get_block_ids () ==  [1 , 2 , 3 , 4 ]
628+     assert  blocks .get_block_ids () ==  [[ 1 , 2 , 3 , 4 ] ]
629629    req0 .num_computed_tokens  =  59 
630630
631631    # Append slots without allocating a new block. 
@@ -686,7 +686,7 @@ def test_cache_key_salting():
686686    blocks  =  manager .allocate_slots (req0 , 59 ,
687687                                    len (computed_blocks .blocks ) *  16 ,
688688                                    computed_blocks )
689-     assert  blocks .get_block_ids () ==  [1 , 2 , 3 , 4 ]
689+     assert  blocks .get_block_ids () ==  [[ 1 , 2 , 3 , 4 ] ]
690690    req0 .num_computed_tokens  =  59 
691691
692692    # Append slots without allocating a new block. 
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
797797    all_token_ids  =  full_block_token_ids  +  unique_token_ids 
798798    req0  =  make_request ("0" , all_token_ids )
799799    blocks  =  manager .allocate_slots (req0 , 55 )
800-     assert  blocks .get_block_ids () ==  [1 , 2 , 3 , 4 ]
800+     assert  blocks .get_block_ids () ==  [[ 1 , 2 , 3 , 4 ] ]
801801
802802    unique_token_ids  =  [4 ] *  7 
803803    all_token_ids  =  full_block_token_ids  +  unique_token_ids 
@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
808808    blocks  =  manager .allocate_slots (req1 , 7 ,
809809                                    len (computed_blocks .blocks ) *  16 ,
810810                                    computed_blocks )
811-     assert  blocks .get_block_ids () ==  [5 ]
811+     assert  blocks .get_block_ids () ==  [[ 5 ] ]
812812
813813    # Failed to reset prefix cache because some blocks are not freed yet. 
814814    assert  not  manager .reset_prefix_cache ()
0 commit comments