@@ -89,26 +89,22 @@ def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, bias):
     return y.to(orig_dtype)


-def _linear_bf16_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
     return (
-        # input is native bfloat16 tensor
         not is_traceable_wrapper_subclass(input_tensor)
-        and input_tensor.dtype == torch.bfloat16
         and
         # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor)
         and _aqt_is_xpu_layout_uint4(weight_tensor)
-        and weight_tensor.dtype == torch.bfloat16
         and len(weight_tensor.shape) == 2
         and weight_tensor.zero_point_domain == ZeroPointDomain.INT
         and weight_tensor.tensor_impl.scale_and_zero is None
-        and weight_tensor.tensor_impl.scale.dtype == torch.bfloat16
         and weight_tensor.tensor_impl.zero.dtype == torch.int8
         and isinstance(weight_tensor._layout, Int4XPULayout)
     )


-def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
     assert weight_tensor.block_size[0] == 1, (
         f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     )
@@ -129,7 +125,7 @@ def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
     orig_act_size = act_mat.size()
     orig_dtype = act_mat.dtype

-    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

     # groupwise int4 quantization
     groupsize = weight_tensor.block_size[1]
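
For context: each `*_check` predicate above is paired with a `*_impl` function in a quantized-linear dispatch table, and the rename reflects that the pair now accepts any floating-point activation dtype (bfloat16 or float16) rather than bfloat16 only, with the result cast back to `orig_dtype`. The sketch below is a minimal, self-contained illustration of that check/impl dispatch pattern; the table, `register_qlinear`, and `dispatch_linear` names are assumptions for illustration, not torchao's actual registration API.

```python
import torch

# Hypothetical dispatch table of (check, impl) pairs, tried in order.
_QLINEAR_DISPATCH = []


def register_qlinear(check, impl):
    """Register a predicate and the implementation that handles it."""
    _QLINEAR_DISPATCH.append((check, impl))


def dispatch_linear(input_tensor, weight_tensor, bias):
    # The first pair whose check passes handles the linear op.
    for check, impl in _QLINEAR_DISPATCH:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    # Fall back to an unquantized linear when no quantized path matches.
    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


# With this patch, the renamed pair would be registered along these lines and
# would accept both bfloat16 and float16 activations:
# register_qlinear(
#     _linear_fp_act_uint4_weight_int8_zero_check,
#     _linear_fp_act_uint4_weight_int8_zero_impl,
# )
```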