@@ -576,22 +576,59 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
         task.evaluate(llm)

     @pytest.mark.skip_device_not_contain(["H100"])
-    def test_fp8_block_scales_cuda_graph_padding(self):
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         # OOM on H100 with default free_gpu_memory_fraction=0.9
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
         pytorch_config = PyTorchConfig(disable_overlap_scheduler=False,
                                        use_cuda_graph=True,
                                        cuda_graph_max_batch_size=512,
                                        cuda_graph_padding_enabled=True)
         llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                   kv_cache_config=kv_cache_config,
-                  pytorch_backend_config=pytorch_config)
+                  pytorch_backend_config=pytorch_config,
+                  speculative_config=mtp_config)
         assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
         with llm:
-            task = CnnDailymail(self.MODEL_NAME)
+            task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("attention_dp", [False, True])
+    def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
+                                                       attention_dp):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        pytorch_config = PyTorchConfig(
+            disable_overlap_scheduler=False,
+            use_cuda_graph=True,
+            cuda_graph_padding_enabled=True,
+        )
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+                  tensor_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  pytorch_backend_config=pytorch_config,
+                  quant_config=quant_config,
+                  enable_attention_dp=attention_dp,
+                  speculative_config=mtp_config)
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        with llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
     @pytest.mark.skip_device_not_contain(["H100", "H200"])
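
For reference, a minimal standalone sketch of the configuration the new 4-GPU test exercises: FP8 block scales, CUDA graph padding, optional attention DP, and MTP speculative decoding. The import paths and the `<models-root>` placeholder are assumptions for illustration; every class name and parameter below comes from the diff above.

```python
# Sketch only: import paths are assumptions, not confirmed by this diff.
from tensorrt_llm import LLM                                       # assumed
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig   # assumed
from tensorrt_llm.models.modeling_utils import QuantConfig         # assumed
from tensorrt_llm.quantization import QuantAlgo                    # assumed
from tensorrt_llm._torch import PyTorchConfig                      # assumed

# Reserve 90% of free GPU memory for the KV cache, as in the 4-GPU test.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)

# MTP speculative decoding with 2 next-token predict layers; pass None
# instead to disable speculation (the mtp_nextn == 0 case in the test).
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=2)

# Capture CUDA graphs and pad requests up to a captured batch size.
pytorch_config = PyTorchConfig(
    disable_overlap_scheduler=False,
    use_cuda_graph=True,
    cuda_graph_padding_enabled=True,
)

# Request FP8 block-scale quantization explicitly.
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES

llm = LLM("<models-root>/DeepSeek-V3-Lite/fp8",  # placeholder model path
          tensor_parallel_size=4,
          kv_cache_config=kv_cache_config,
          pytorch_backend_config=pytorch_config,
          quant_config=quant_config,
          enable_attention_dp=True,
          speculative_config=mtp_config)
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
```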