@@ -91,6 +91,22 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
     out.detach_()


+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class Embeddings(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -153,6 +169,11 @@ def __init__(self, config):

         self.pruned_heads = set()

+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
@@ -202,7 +223,7 @@ def unshape(x):
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
-        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
+        scores = self.attention_scores_matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
         scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
@@ -213,7 +234,7 @@ def unshape(x):
         if head_mask is not None:
             weights = weights * head_mask

-        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
+        context = self.context_layer_matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
         context = unshape(context)  # (bs, q_length, dim)
         context = self.out_lin(context)  # (bs, q_length, dim)

@@ -625,7 +646,6 @@ def forward(
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         dlbrt_output = self.distilbert(
             input_ids=input_ids,
             attention_mask=attention_mask,
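
Note: outside of a SparseML run, QATMatMul is a plain pass-through around torch.matmul; the wrap_qat / qat_wrapper_kwargs attributes are only consumed when a SparseML QuantizationModifier wraps the module for quantization-aware training. The standalone sketch below is not part of the diff and assumes only stock PyTorch; it checks that the wrapper is numerically identical to torch.matmul before any quantization recipe is applied.

import torch
import torch.nn as nn


class QATMatMul(nn.Module):
    # pass-through matmul that SparseML can later wrap with fake-quant observers
    def __init__(self):
        super().__init__()
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


# shapes mirror the attention use case: (bs, n_heads, q_length, dim_per_head)
q = torch.randn(2, 12, 128, 64)
k = torch.randn(2, 12, 128, 64)
attention_scores_matmul = QATMatMul()
assert torch.equal(
    attention_scores_matmul(q, k.transpose(2, 3)),
    torch.matmul(q, k.transpose(2, 3)),
)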