@@ -91,7 +91,7 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
     out.detach_()


-class QATMatMul(nn.Module):
+class QATAttentionScores(nn.Module):
     def __init__(self):
         super().__init__()

@@ -106,6 +106,22 @@ def __init__(self):
     def forward(self, a: torch.Tensor, b: torch.Tensor):
         return torch.matmul(a, b)

+class QATContextLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+

 class Embeddings(nn.Module):
     def __init__(self, config):
@@ -171,8 +187,8 @@ def __init__(self, config):

         # non-parameterized matmuls will behave as normal torch.matmul ops unless
         # Quantization-Aware-Training is invoked
-        self.attention_scores_matmul = QATMatMul()
-        self.context_layer_matmul = QATMatMul()
+        self.attention_scores_matmul = QATAttentionScores()
+        self.context_layer_matmul = QATContextLayer()

     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
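For context: as the in-diff comments note, these wrapper modules behave exactly like torch.matmul until a SparseML QuantizationModifier is applied; the wrap_qat flag and qat_wrapper_kwargs appear to tell SparseML how to attach quantization observers around each matmul's inputs (here, an asymmetric config for the first input and a symmetric one for the second). Below is a minimal sketch, not part of this commit, of how such module-level matmul wrappers slot into a DistilBERT-style attention forward pass; the function name attention_sketch, the PlainMatMul stand-in, and the tensor shapes are illustrative assumptions rather than the repository's actual code.

import math

import torch
import torch.nn as nn


class PlainMatMul(nn.Module):
    # Stand-in for QATAttentionScores / QATContextLayer: with no SparseML
    # modifier active, all of them behave identically to torch.matmul.
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(a, b)


def attention_sketch(q, k, v, attention_scores_matmul, context_layer_matmul):
    # q, k, v: (batch, n_heads, seq_len, dim_per_head)
    dim_per_head = q.size(-1)
    # scores = Q @ K^T / sqrt(d), routed through a module so that quantization
    # observers can be inserted around the matmul during QAT
    scores = attention_scores_matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)
    weights = nn.functional.softmax(scores, dim=-1)
    # context = softmax(scores) @ V, routed through the second wrapper module
    return context_layer_matmul(weights, v)


q = torch.randn(2, 12, 16, 64)
k = torch.randn(2, 12, 16, 64)
v = torch.randn(2, 12, 16, 64)
out = attention_sketch(q, k, v, PlainMatMul(), PlainMatMul())
print(out.shape)  # torch.Size([2, 12, 16, 64])

Routing the two matmuls through separate module instances (rather than calling torch.matmul directly) is what lets a modifier give the attention-scores matmul and the context-layer matmul their own quantization configurations.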