This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 5afbd46

Removed double quantization of output of context layer. (#45)
1 parent 86c51a9 commit 5afbd46

File tree

1 file changed: +19 −3 lines changed


src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 19 additions & 3 deletions
@@ -91,7 +91,7 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
     out.detach_()
 
 
-class QATMatMul(nn.Module):
+class QATAttentionScores(nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -106,6 +106,22 @@ def __init__(self):
     def forward(self, a: torch.Tensor, b: torch.Tensor):
         return torch.matmul(a, b)
 
+class QATContextLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
 
 class Embeddings(nn.Module):
     def __init__(self, config):
@@ -171,8 +187,8 @@ def __init__(self, config):
 
         # non-parameterized matmuls will behave as normal torch.matmul ops unless
         # Quantization-Aware-Training is invoked
-        self.attention_scores_matmul = QATMatMul()
-        self.context_layer_matmul = QATMatMul()
+        self.attention_scores_matmul = QATAttentionScores()
+        self.context_layer_matmul = QATContextLayer()
 
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
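
For readers following the quantization flow, here is a minimal, self-contained sketch of how the two wrapper modules are exercised in DistilBERT-style self-attention. Only QATContextLayer's body is taken from the hunk above; the qat_wrapper_kwargs shown for QATAttentionScores are an assumption (same wrap_qat convention, but without "num_outputs": 0, so its output is still quantized), since that class is only renamed, not rewritten, in this diff. The toy shapes and scaling are illustrative, not copied from the model.

import torch
import torch.nn as nn


class QATAttentionScores(nn.Module):
    # Assumed body: same wrap_qat convention, with the output left to be quantized.
    def __init__(self):
        super().__init__()
        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
        # is initialized
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


class QATContextLayer(nn.Module):
    # Body copied from the hunk above; "num_outputs": 0 skips re-quantizing the output.
    def __init__(self):
        super().__init__()
        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
        # is initialized
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "num_outputs": 0,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


# Toy tensors shaped (batch, heads, seq_len, head_dim).
q = torch.randn(2, 12, 16, 64)
k = torch.randn(2, 12, 16, 64)
v = torch.randn(2, 12, 16, 64)

attention_scores_matmul = QATAttentionScores()
context_layer_matmul = QATContextLayer()

scores = attention_scores_matmul(q, k.transpose(2, 3)) / 64 ** 0.5  # (2, 12, 16, 16)
weights = nn.functional.softmax(scores, dim=-1)
context = context_layer_matmul(weights, v)  # (2, 12, 16, 64)
print(context.shape)

Without a SparseML QuantizationModifier, both wrappers behave exactly like torch.matmul. When quantization is active, the context-layer matmul now only fake-quantizes its inputs; its output is presumably already covered by quantization applied downstream (e.g., at the following projection), which is the double quantization this commit removes.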

0 commit comments
