Commit 11b8cd3

altanh and anwang2009 authored
[ONNX] Add imports for BERT contrib operators (#10949)
* EmbedLayerNormalization, Attention
* fix Attention
* SkipLayerNormalization
* fix dtype bug in Gelu

  Co-authored-by: An Wang <[email protected]>

* missing parameterize_targets
* lint
* lint
* comments
* fix small thing
* factor out layer norm computation
* layernorm func
* add optional args to test
* upgrade onnxrt version
* no upgrade onnx
* fix tests
* int32
* fix tests

Co-authored-by: An Wang <[email protected]>
1 parent 814e856 commit 11b8cd3

File tree

2 files changed: +440, -3 lines

python/tvm/relay/frontend/onnx.py

Lines changed: 221 additions & 3 deletions
@@ -329,6 +329,22 @@ def flatten_to_nd(x, x_shape, nd=3):
         return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)
 
 
+def layer_norm(x, eps, gamma, beta):
+    """Common function to handle layer norm"""
+    eps_dtype = infer_type(x).checked_type.dtype
+
+    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+    output = _op.divide(
+        _op.subtract(x, u),
+        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+    )
+    output = _op.multiply(output, gamma)
+    if beta is not None:
+        output = _op.add(output, beta)
+
+    return output
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""
 
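For reference, the new layer_norm helper normalizes over the last axis using the mean and variance from _op.mean_variance. A NumPy sketch of the same computation (illustrative only, not part of this commit):

    import numpy as np

    def layer_norm_ref(x, eps, gamma, beta=None):
        """NumPy mirror of the Relay layer_norm helper above."""
        u = x.mean(axis=-1, keepdims=True)   # per-row mean
        s = x.var(axis=-1, keepdims=True)    # per-row (biased) variance
        out = (x - u) / np.sqrt(s + eps)     # normalize over the last axis
        out = out * gamma                    # scale
        if beta is not None:
            out = out + beta                 # optional shift
        return out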
@@ -807,9 +823,10 @@ def _impl_v1(cls, inputs, attr, params):
         x = inputs[0]
 
         # Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)
 
         # Compute gelu
         term1 = _op.multiply(half, x)
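The dtype change above (the "fix dtype bug in Gelu" item in the commit message) matters because a Relay constant built from a bare Python float defaults to float32, so the old constants failed to type-check against float16 inputs. A minimal sketch of the pattern, assuming the public tvm.relay API:

    from tvm import relay

    x = relay.var("x", shape=(2, 8), dtype="float16")
    # relay.const(0.5) would create a float32 constant and clash with the
    # float16 input; passing the inferred dtype keeps the graph consistent.
    half = relay.const(0.5, dtype="float16")
    y = relay.multiply(half, x)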
@@ -836,6 +853,201 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)
 
 
+class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer embeds the input tokens, sums them, and applies layer normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr.get("epsilon", 1e-12)
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int32")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        ln = layer_norm(vec_sum, eps, gamma, beta)
+
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int32"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
+
+        # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer sums the two input tensors (along with optional bias), and applies layer
+    normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        assert (
+            beta is not None and bias is not None
+        ), "SkipLayerNormalization import currently only supports required beta and bias"
+
+        eps = attr.get("epsilon", 1e-12)
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        output = layer_norm(x, eps, gamma, beta)
+
+        # onnxruntime doesn't compute the other outputs, despite the documentation
+        placeholder = _op.const(0, dtype="float32")
+
+        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
+
+
+class Attention(OnnxOpConverter):
+    """Operator converter for Attention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_heads = attr["num_heads"]
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
+        assert "unidirectional" not in attr, "unidirectional attention not currently supported"
+
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # 1. (    batch,              1,   max_seq, max_seq)
+        # 2. (    batch, past_seq + seq,)
+        # 3. (    batch,            seq,   past_seq + seq,)
+        # 4. (    batch,)
+        # 5. (2 * batch,)
+        # For now, we only support case 2.
+        mask_index = inputs[3]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[4]
+
+        # (batch, num_heads, seq, seq)
+        extra_add = inputs[5]
+
+        (batch_size, seq_len, _) = infer_shape(input_emb)
+        (out_hidden_x3,) = infer_shape(bias)
+        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // num_heads
+
+        assert (
+            mask_index is not None
+        ), "Attention import currently only supports required mask_index"
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] == seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        assert past is None, "past K, V state is not currently supported"
+        assert extra_add is None, "extra add to QxK not currently supported"
+
+        # split weight and biases and do the matmuls
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
+        # need to merge batch dimensions since TVM matmul is 2D
+        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
+        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
+        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
+        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
+
+        # massage tensors in preparation for batched matmul
+        def massage(tensor):
+            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+
+            # (batch_size, num_heads, seq_len, head_size)
+            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+            # (batch_size * num_heads, seq_len, head_size)
+            return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+        Q = massage(Q)
+        K = massage(K)
+        V = massage(V)
+
+        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
+        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
+        present = _op.stack([K_present, V_present], axis=0)
+
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype),
+        )
+        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
+
+        # build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+
+        # apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
+
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, out_hidden))
+
+        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
+
+
 class Gemm(OnnxOpConverter):
     """Operator converter for Gemm."""
 
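These converters are reached through the normal ONNX import path; contrib nodes simply carry the com.microsoft domain. A rough usage sketch with a hand-built SkipLayerNormalization graph (shapes and names are illustrative, not the commit's actual test code):

    import onnx
    from onnx import TensorProto, helper
    from tvm import relay

    batch, seq, hidden = 1, 3, 4
    node = helper.make_node(
        "SkipLayerNormalization",
        inputs=["input", "skip", "gamma", "beta", "bias"],
        outputs=["output"],
        domain="com.microsoft",
        epsilon=1e-12,
    )
    graph = helper.make_graph(
        [node],
        "skip_ln_example",
        inputs=[
            helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
            helper.make_tensor_value_info("skip", TensorProto.FLOAT, [batch, seq, hidden]),
            helper.make_tensor_value_info("gamma", TensorProto.FLOAT, [hidden]),
            helper.make_tensor_value_info("beta", TensorProto.FLOAT, [hidden]),
            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [hidden]),
        ],
        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden])],
    )
    model = helper.make_model(
        graph,
        opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)],
    )
    # from_onnx dispatches on op_type, so the contrib node resolves to the converter above.
    mod, params = relay.frontend.from_onnx(model)

The converter asserts that beta and bias are present, which is why both are declared as graph inputs in this sketch.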

@@ -4808,6 +5020,12 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        # TODO: We need a better way to handle different domains, in case
+        # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
+        # are in the `com.microsoft` domain.
+        "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
+        "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
+        "Attention": Attention.get_converter(opset),
         "Exp": Renamer("exp"),
         "Greater": Renamer("greater"),
         "GreaterOrEqual": Renamer("greater_equal"),

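Because these registrations share the default convert map, a com.microsoft op whose name collides with a standard ONNX op would be ambiguous, which is what the TODO above refers to. An illustrative way to list the contrib ops a model actually uses ("bert.onnx" is a placeholder path):

    import onnx

    model = onnx.load("bert.onnx")  # hypothetical model path
    contrib_ops = sorted({n.op_type for n in model.graph.node if n.domain == "com.microsoft"})
    print(contrib_ops)  # e.g. ['Attention', 'EmbedLayerNormalization', 'SkipLayerNormalization']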