22"""Tensor transformation ops"""
33from __future__ import absolute_import
44
5- import tvm
65import topi
76from .tensor import _fschedule_broadcast , _fschedule_injective
87from . import registry as reg
98from .registry import OpPattern
109
11- # Need add reshape
10+ # expand_dims
1211@reg .register_compute ("expand_dims" )
1312def compute_expand_dims (attrs , inputs , out_info ):
1413 """Compute definition of expand_dims"""
@@ -18,34 +17,46 @@ def compute_expand_dims(attrs, inputs, out_info):
1817reg .register_pattern ("expand_dims" , OpPattern .BROADCAST )
1918reg .register_schedule ("expand_dims" , _fschedule_broadcast )
2019
21-
20+ # transpose
2221@reg .register_compute ("transpose" )
2322def compute_transpose (attrs , inputs , out_info ):
24- """Compute definition of expand_dims """
23+ """Compute definition of transpose """
2524 axes = attrs .get_int_tuple ("axes" )
2625 axes = tuple (axes ) if axes else None
2726 return topi .transpose (inputs [0 ], axes )
2827reg .register_pattern ("transpose" , OpPattern .INJECTIVE )
2928reg .register_schedule ("transpose" , _fschedule_injective )
3029
31-
32- def _flatten_index (indices , shape ):
33- """flatten the index to 1D"""
34- idx = 0
35- for i , value in enumerate (shape ):
36- if i != 0 :
37- idx *= value
38- idx = idx + indices [i ]
39- return idx
40-
4130# reshape
4231@reg .register_compute ("reshape" )
4332def compute_reshape (attrs , inputs , out_info ):
44- """Compute definition of softmax"""
45- # TODO(sxj) add support for general reshape
46- assert len (inputs [0 ].shape ) == 1 , "Only support 1d input for now"
33+ """Compute definition of reshape"""
4734 oshape = out_info [0 ].shape
48- x = inputs [0 ]
49- return tvm .compute (oshape , lambda * i : x (_flatten_index (i , oshape )))
35+ return topi .reshape (inputs [0 ], oshape )
5036reg .register_pattern ("reshape" , OpPattern .INJECTIVE )
5137reg .register_schedule ("reshape" , _fschedule_injective )
38+
39+ # concatenate
40+ @reg .register_compute ("concatenate" )
41+ def compute_concatenate (attrs , inputs , out_info ):
42+ """Compute definition of concatenate"""
43+ axis = attrs .get_int ("axis" )
44+ return topi .concatenate ([x for x in inputs ], axis = axis )
45+
46+ reg .register_pattern ("concatenate" , OpPattern .INJECTIVE )
47+ reg .register_schedule ("concatenate" , _fschedule_injective )
48+
49+ # split
50+ @reg .register_compute ("split" )
51+ def compute_split (attrs , inputs , out_info ):
52+ """Compute definition of split"""
53+ x = attrs ["indices_or_sections" ]
54+ if x .startswith ("(" ) or x .startswith ("[" ):
55+ indices_or_sections = attrs .get_int_tuple ("indices_or_sections" )
56+ else :
57+ indices_or_sections = attrs .get_int ("indices_or_sections" )
58+ return topi .split (inputs [0 ], indices_or_sections , axis = attrs .get_int ("axis" ))
59+
60+
61+ reg .register_pattern ("split" , OpPattern .INJECTIVE )
62+ reg .register_schedule ("split" , _fschedule_injective )
0 commit comments