#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
-import copy
import logging
import unittest

import torch
from torch import nn
-from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import TestCase

-from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
from torchao.quantization import (
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.quant_api import (
-    Int4WeightOnlyConfig,
-    Int8DynamicActivationInt8WeightConfig,
+    ParamFqnToConfig,
    PerRow,
-    PerTensor,
    quantize_,
)
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
-from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
-from torchao.utils import is_sm_at_least_90
-import torch.nn.functional as F
-
-import re
-import unittest
-import warnings
-import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import is_fbcode, is_sm_at_least_90
-from torchao.quantization.quant_api import ParamFqnToConfig

if not is_fbcode():
-    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
+    pass

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+")
@unittest.skipIf(
    is_fbcode(),
    "Skipping the test in fbcode for now, not sure how to download from transformers",
)
-class TestQuantizeFQNParam (TestCase):
-
+class TestQuantizeFQNParam(TestCase):
    def test_quantize_param_fqn_exact(self):
-        from transformers import AutoConfig, AutoModel
+        from transformers import AutoConfig
        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

-        config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct").text_config
+        config = AutoConfig.from_pretrained(
+            "unsloth/Llama-4-Scout-17B-16E-Instruct"
+        ).text_config
        model = Llama4TextMoe(config).to(torch.bfloat16).cuda()
-        input_tensor = torch.randn(16, 128, config.hidden_size).cuda().bfloat16()
-
-
-        quant_config = ParamFqnToConfig({
-            "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig(
-                granularity=PerRow(),
-            ),
-        })
-
+        torch.randn(16, 128, config.hidden_size).cuda().bfloat16()
+
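+        # ParamFqnToConfig maps a parameter's fully qualified name (FQN) to the
+        # quantization config that quantize_ applies to that parameter.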
+        quant_config = ParamFqnToConfig(
+            {
+                "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(),
+                ),
+            }
+        )

        quantize_(
            model,
@@ -75,24 +62,27 @@ def test_quantize_param_fqn_exact(self):
        assert isinstance(model.experts.gate_up_proj, Float8Tensor)

    def test_quantize_param_fqn_regex(self):
-        from transformers import AutoConfig, AutoModel
+        from transformers import AutoConfig
        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

-        config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct").text_config
+        config = AutoConfig.from_pretrained(
+            "unsloth/Llama-4-Scout-17B-16E-Instruct"
+        ).text_config
        model = Llama4TextMoe(config).to(torch.bfloat16).cuda()
-        input_tensor = torch.randn(16, 128, config.hidden_size).cuda().bfloat16()
+        torch.randn(16, 128, config.hidden_size).cuda().bfloat16()
        # print(model.experts)
        for name, param in model.named_parameters():
            print(name)

-        from torchao.quantization.quant_api import ParamFqnToConfig
-
-        quant_config = ParamFqnToConfig({
-            ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig(
-                granularity=PerRow(),
-            ),
-        })
+        from torchao.quantization.quant_api import ParamFqnToConfig

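+        # FQN keys may also be regular expressions; ".*gate_up_proj" matches any
+        # parameter whose name ends in gate_up_proj.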
+        quant_config = ParamFqnToConfig(
+            {
+                ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(),
+                ),
+            }
+        )

        quantize_(
            model,
@@ -103,5 +93,7 @@ def test_quantize_param_fqn_regex(self):

    def test_quantize_param_root(self):
        param = nn.Parameter(torch.randn(1024, 1024).cuda().to(torch.bfloat16))
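+        # quantize_ can also be called on a bare nn.Parameter; it returns the
+        # quantized tensor (here a Float8Tensor) directly.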
-        new_param = quantize_(param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+        new_param = quantize_(
+            param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
        assert isinstance(new_param, Float8Tensor)