@@ -329,6 +329,22 @@ def flatten_to_nd(x, x_shape, nd=3):
         return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)


+def layer_norm(x, eps, gamma, beta):
+    """Common function to handle layer norm"""
+    eps_dtype = infer_type(x).checked_type.dtype
+
+    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+    output = _op.divide(
+        _op.subtract(x, u),
+        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+    )
+    output = _op.multiply(output, gamma)
+    if beta is not None:
+        output = _op.add(output, beta)
+
+    return output
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""

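For reference, the layer_norm helper above builds the standard last-axis layer normalization, y = (x - mean) / sqrt(var + eps) * gamma + beta, as a Relay expression. A minimal NumPy sketch of the same math (an illustration for readers, not part of the patch):

import numpy as np

def layer_norm_reference(x, eps, gamma, beta=None):
    # Normalize over the last axis only, mirroring the Relay graph built above.
    u = x.mean(axis=-1, keepdims=True)
    s = x.var(axis=-1, keepdims=True)
    out = (x - u) / np.sqrt(s + eps)
    out = out * gamma
    if beta is not None:
        out = out + beta
    return out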
@@ -807,9 +823,10 @@ def _impl_v1(cls, inputs, attr, params):
         x = inputs[0]

         # Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)

         # Compute gelu
         term1 = _op.multiply(half, x)
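The Gelu converter feeds these constants into gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))); pinning them to the input's dtype keeps, for example, float16 models from pulling float32 constants into the graph. A NumPy sketch of the same formula for comparison (illustration only, not converter code):

import math
import numpy as np

def gelu_reference(x):
    # 0.5 * x * (1 + erf(x / sqrt(2))), evaluated elementwise and cast back to x's dtype.
    erf = np.vectorize(math.erf)
    return (0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))).astype(x.dtype)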
@@ -836,6 +853,201 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)


+class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer embeds the input tokens, sums them, and applies layer normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr.get("epsilon", 1e-12)
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int32")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        ln = layer_norm(vec_sum, eps, gamma, beta)
+
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int32"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
+
+        # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer sums the two input tensors (along with optional bias), and applies layer
+    normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        assert (
+            beta is not None and bias is not None
+        ), "SkipLayerNormalization import currently only supports required beta and bias"
+
+        eps = attr.get("epsilon", 1e-12)
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        output = layer_norm(x, eps, gamma, beta)
+
+        # onnxruntime doesn't compute the other outputs, despite the documentation
+        placeholder = _op.const(0, dtype="float32")
+
+        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
+
+
+class Attention(OnnxOpConverter):
+    """Operator converter for Attention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_heads = attr["num_heads"]
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
949+ assert "unidirectional" not in attr , "unidirectional attention not current supported"
+
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # 1. (    batch,             1,        max_seq, max_seq)
+        # 2. (    batch, past_seq + seq,)
+        # 3. (    batch,           seq, past_seq + seq,)
+        # 4. (    batch,)
+        # 5. (2 * batch,)
+        # For now, we only support case 2.
+        mask_index = inputs[3]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[4]
+
+        # (batch, num_heads, seq, seq)
+        extra_add = inputs[5]
+
+        (batch_size, seq_len, _) = infer_shape(input_emb)
+        (out_hidden_x3,) = infer_shape(bias)
+        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // num_heads
+
+        assert (
+            mask_index is not None
+        ), "Attention import currently only supports required mask_index"
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] == seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        assert past is None, "past K, V state is not currently supported"
+        assert extra_add is None, "extra add to QxK not currently supported"
+
+        # split weight and biases and do the matmuls
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
+        # need to merge batch dimensions since TVM matmul is 2D
+        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
+        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
+        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
+        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
+
+        # massage tensors in preparation for batched matmul
+        def massage(tensor):
+            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+
+            # (batch_size, num_heads, seq_len, head_size)
+            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+            # (batch_size * num_heads, seq_len, head_size)
+            return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+        Q = massage(Q)
+        K = massage(K)
+        V = massage(V)
+
+        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
+        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
+        present = _op.stack([K_present, V_present], axis=0)
+
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(np.sqrt(head_size), dtype=score_dtype),
+        )
+        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
+
+        # build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+
+        # apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
+
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, out_hidden))
+
+        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
+
+
 class Gemm(OnnxOpConverter):
     """Operator converter for Gemm."""

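As a sanity check on the Attention converter's shape bookkeeping and its additive mask (a 0/1 token mask becomes a 0 / -10000 offset added to the scores before softmax), here is a rough NumPy sketch of the equivalent computation. Shapes and names are assumptions for illustration; it omits the `present` K/V output and everything the converter asserts away (past state, extra_add, separate Q/K/V hidden sizes):

import numpy as np

def attention_reference(x, w_qkv, b_qkv, mask, num_heads):
    # x: (batch, seq, hidden); w_qkv: (hidden, 3 * hidden); b_qkv: (3 * hidden,)
    # mask: (batch, seq) with 1 for real tokens and 0 for padding.
    batch, seq, hidden = x.shape
    head_size = hidden // num_heads
    q, k, v = np.split(x @ w_qkv + b_qkv, 3, axis=-1)

    def split_heads(t):
        # (batch, seq, hidden) -> (batch, num_heads, seq, head_size)
        return t.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

    q, k, v = (split_heads(t) for t in (q, k, v))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    # Additive mask: padded positions receive a large negative score before softmax.
    scores = scores + (1.0 - mask[:, None, None, :]) * -10000.0
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)
    out = probs @ v  # (batch, num_heads, seq, head_size)
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, hidden)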
@@ -4808,6 +5020,12 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        # TODO: We need a better way to handle different domains, in case
+        # of name collisions. EmbedLayerNormalization, SkipLayerNormalization,
+        # and Attention are in the `com.microsoft` domain.
+        "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
+        "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
+        "Attention": Attention.get_converter(opset),
         "Exp": Renamer("exp"),
         "Greater": Renamer("greater"),
         "GreaterOrEqual": Renamer("greater_equal"),