@@ -210,6 +210,22 @@ def forward(
         return embeddings
 
 
+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class BertSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -227,6 +243,11 @@ def __init__(self, config):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
@@ -288,7 +309,7 @@ def forward(
             past_key_value = (key_layer, value_layer)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             seq_length = hidden_states.size()[1]
@@ -322,7 +343,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = self.context_layer_matmul(attention_probs, value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
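
Note (not part of the commit): the two wrapped matmuls above keep behaving as plain torch.matmul calls until a SparseML QuantizationModifier recognizes wrap_qat / qat_wrapper_kwargs and inserts fake-quantization on their inputs. The snippet below is a minimal torch-only sketch of roughly what a wrapped op looks like during QAT, with the first input observed asymmetrically and the second symmetrically, mirroring input_qconfigs; the FakeQuantMatMul name and the exact observer settings are illustrative assumptions, not SparseML's actual wrapper.

import torch
from torch import nn
from torch.quantization import FakeQuantize, MovingAverageMinMaxObserver

class FakeQuantMatMul(nn.Module):
    """Illustrative stand-in for a QAT-wrapped QATMatMul: fake-quantize each
    input, then run the ordinary float matmul."""

    def __init__(self):
        super().__init__()
        # first input: asymmetric (affine) uint8 observer
        self.quant_a = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
            dtype=torch.quint8, qscheme=torch.per_tensor_affine,
        )
        # second input: symmetric int8 observer
        self.quant_b = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
            dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.quant_a(a), self.quant_b(b))

# shapes match the attention-scores matmul: (batch, heads, seq, head_dim) x (batch, heads, head_dim, seq)
scores = FakeQuantMatMul()(torch.randn(2, 12, 128, 64), torch.randn(2, 12, 64, 128))
print(scores.shape)  # torch.Size([2, 12, 128, 128])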