
Commit 0858256

Using QATMatMul in DistilBERT model class (#41)
1 parent b90eb20 commit 0858256

2 files changed: +29 -3 lines changed

src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 23 additions & 3 deletions
@@ -91,6 +91,22 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
     out.detach_()
 
 
+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class Embeddings(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -153,6 +169,11 @@ def __init__(self, config):
 
         self.pruned_heads = set()
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
@@ -202,7 +223,7 @@ def unshape(x):
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
 
         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
-        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
+        scores = self.attention_scores_matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
         scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
 
@@ -213,7 +234,7 @@ def unshape(x):
         if head_mask is not None:
             weights = weights * head_mask
 
-        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
+        context = self.context_layer_matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
         context = unshape(context)  # (bs, q_length, dim)
         context = self.out_lin(context)  # (bs, q_length, dim)
 
@@ -625,7 +646,6 @@ def forward(
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         dlbrt_output = self.distilbert(
             input_ids=input_ids,
             attention_mask=attention_mask,
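
For reference, the new wrapper is self-contained: until a SparseML QuantizationModifier wraps it via the wrap_qat / qat_wrapper_kwargs attributes, it is numerically identical to torch.matmul. A minimal standalone sketch of the class added above, with a sanity check (the tensor shapes below are illustrative, not taken from the model):

import torch
from torch import nn


class QATMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        # attributes read by SparseML when a QuantizationModifier is applied;
        # without one active, this module is a plain matmul
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


# sanity check: with no quantization recipe active, the wrapper is a no-op
matmul = QATMatMul()
q = torch.randn(1, 12, 16, 64)  # (bs, n_heads, q_length, dim_per_head)
k = torch.randn(1, 12, 16, 64)  # (bs, n_heads, k_length, dim_per_head)
scores = matmul(q, k.transpose(2, 3))
assert torch.equal(scores, torch.matmul(q, k.transpose(2, 3)))

Routing the two non-parameterized matmuls in attention through nn.Module instances is what lets quantization observers be attached to them later; the eager-mode behavior of the model is unchanged.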

src/transformers/trainer.py

Lines changed: 6 additions & 0 deletions
@@ -2436,6 +2436,12 @@ def evaluation_loop(
         observed_num_examples = 0
         # Main evaluation loop
         for step, inputs in enumerate(dataloader):
+            inputs = {
+                k: inputs[k]
+                for k in inputs
+                if k in list(inspect.signature(model.forward).parameters.keys())
+            }
+
             # Update the observed num examples
             observed_batch_size = find_batch_size(inputs)
             if observed_batch_size is not None:
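
The trainer change keeps only the batch keys that the model's forward signature accepts, so extra dataset columns in the evaluation dataloader do not reach forward() as unexpected keyword arguments. A minimal sketch of the same filtering pattern, using a hypothetical TinyModel as a stand-in for the real Transformers model:

import inspect

import torch
from torch import nn


class TinyModel(nn.Module):
    # hypothetical stand-in; only input_ids and attention_mask are accepted
    def forward(self, input_ids, attention_mask=None):
        return input_ids


model = TinyModel()
inputs = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "special_tokens_mask": torch.tensor([[0, 0, 0]]),  # not a forward() argument
}

# same filtering as in the diff above: keep only keys forward() can take
allowed = inspect.signature(model.forward).parameters.keys()
inputs = {k: inputs[k] for k in inputs if k in list(allowed)}
assert set(inputs) == {"input_ids", "attention_mask"}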
