This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit d7878ea: enable a QATWrapper for non-parameterized matmuls in BERT self attention (#9)
1 parent: af6338e

src/transformers/models/bert/modeling_bert.py

Lines changed: 23 additions & 2 deletions
@@ -210,6 +210,22 @@ def forward(
         return embeddings
 
 
+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class BertSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
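Per the commit title and the comment in the diff, SparseML's QuantizationModifier picks up modules flagged with wrap_qat and wraps them in a QATWrapper that quantizes their inputs; qat_wrapper_kwargs presumably tells it to expect two inputs, with an asymmetric qconfig for the first and a symmetric one for the second. The sketch below is only a toy illustration of that convention under those assumptions, not SparseML's actual QATWrapper: ToyQATWrapper, _fake_quantize, and wrap_flagged_matmuls are hypothetical names, and the on-the-fly min/max scales stand in for the observed scales a real modifier would learn.

    # Toy illustration only; not SparseML's actual QATWrapper implementation.
    import torch
    import torch.nn as nn

    def _fake_quantize(x: torch.Tensor, symmetric: bool) -> torch.Tensor:
        # 8-bit fake quantization with per-tensor scales computed on the fly;
        # a real QAT wrapper would use observed/learned quantization parameters.
        if symmetric:
            scale = max(float(x.abs().max()) / 127.0, 1e-8)
            return torch.fake_quantize_per_tensor_affine(x, scale, 0, -128, 127)
        lo, hi = float(x.min()), float(x.max())
        scale = max((hi - lo) / 255.0, 1e-8)
        zero_point = int(min(max(round(-lo / scale), 0), 255))
        return torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)

    class ToyQATWrapper(nn.Module):
        """Fake-quantizes each input of a flagged module according to its
        qat_wrapper_kwargs, then delegates to the wrapped module."""

        def __init__(self, wrapped: nn.Module):
            super().__init__()
            self.wrapped = wrapped
            self.input_qconfigs = wrapped.qat_wrapper_kwargs["input_qconfigs"]

        def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
            quantized = [
                _fake_quantize(x, symmetric=(cfg == "symmetric"))
                for x, cfg in zip(inputs, self.input_qconfigs)
            ]
            return self.wrapped(*quantized)

    def wrap_flagged_matmuls(model: nn.Module) -> None:
        # collect targets first so freshly inserted wrappers are not re-visited
        targets = [
            (parent, name, child)
            for parent in model.modules()
            for name, child in parent.named_children()
            if getattr(child, "wrap_qat", False)
        ]
        for parent, name, child in targets:
            setattr(parent, name, ToyQATWrapper(child))

Running wrap_flagged_matmuls over a BERT model built from this fork would swap attention_scores_matmul and context_layer_matmul for wrapper instances while leaving every other module untouched, which mirrors the drop-in behaviour described in the comment above.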
@@ -227,6 +243,11 @@ def __init__(self, config):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
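Registering the matmuls as submodules, instead of calling torch.matmul functionally, is what makes them reachable by module-walking tooling such as a quantization modifier: they now have names and appear in named_modules(). A quick check of that, assuming this fork of transformers is installed (the config values below are arbitrary):

    from transformers import BertConfig
    from transformers.models.bert.modeling_bert import BertSelfAttention

    # small arbitrary config; hidden_size must be divisible by num_attention_heads
    config = BertConfig(hidden_size=64, num_attention_heads=4)
    attention = BertSelfAttention(config)

    # the wrapped matmuls show up alongside the query/key/value projections
    print([name for name, _ in attention.named_modules() if name])
    # expected (registration order): ['query', 'key', 'value',
    #   'attention_scores_matmul', 'context_layer_matmul', 'dropout']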
@@ -288,7 +309,7 @@ def forward(
         past_key_value = (key_layer, value_layer)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             seq_length = hidden_states.size()[1]
@@ -322,7 +343,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = self.context_layer_matmul(attention_probs, value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
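Until a quantization modifier wraps them, both new call sites are numerically identical to the torch.matmul calls they replace, since QATMatMul.forward is a plain torch.matmul. A minimal check, assuming this fork is installed; the shapes (batch 2, 4 heads, sequence length 10, head size 16) are arbitrary:

    import math
    import torch
    from transformers.models.bert.modeling_bert import QATMatMul

    q = torch.randn(2, 4, 10, 16)
    k = torch.randn(2, 4, 10, 16)
    v = torch.randn(2, 4, 10, 16)

    attention_scores_matmul = QATMatMul()
    context_layer_matmul = QATMatMul()

    # wrapped version, mirroring the two call sites changed above
    scores = attention_scores_matmul(q, k.transpose(-1, -2)) / math.sqrt(16)
    context = context_layer_matmul(scores.softmax(dim=-1), v)

    # functional version the diff replaces
    ref_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(16)
    ref_context = torch.matmul(ref_scores.softmax(dim=-1), v)

    assert torch.equal(context, ref_context)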
