diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 75aadf2d90a9..a0a37653f9c0 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -210,6 +210,22 @@ def forward(
         return embeddings
 
 
+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class BertSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -227,6 +243,11 @@ def __init__(self, config):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
@@ -288,7 +309,7 @@ def forward(
         past_key_value = (key_layer, value_layer)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             seq_length = hidden_states.size()[1]
@@ -322,7 +343,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = self.context_layer_matmul(attention_probs, value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
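
For reference, a minimal standalone sketch (not part of the patch) of what the wrapper does when no SparseML QuantizationModifier has been applied: it is a plain pass-through to torch.matmul, and the wrap_qat / qat_wrapper_kwargs attributes stay inert until SparseML's QAT wrapping consumes them.

# Standalone sketch, assuming only torch is available; the wrap_qat flag and
# qat_wrapper_kwargs are ignored unless a SparseML QuantizationModifier
# wraps this module during Quantization-Aware-Training.
import torch
from torch import nn


class QATMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        # picked up by SparseML's QAT wrapping; otherwise unused
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


if __name__ == "__main__":
    matmul = QATMatMul()
    # shapes mirror the attention use case: (batch, heads, seq_len, head_dim)
    q = torch.randn(2, 12, 8, 64)
    k = torch.randn(2, 12, 8, 64)
    scores = matmul(q, k.transpose(-1, -2))
    # without QAT the wrapper is numerically identical to torch.matmul
    assert torch.equal(scores, torch.matmul(q, k.transpose(-1, -2)))
    print(scores.shape)  # torch.Size([2, 12, 8, 8])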