@@ -432,6 +432,38 @@ def _assign_ones_like_indice(self, node: Node, node_idx: int):
432432 """
433433 self ._assign_all_indice (node , node_idx )
434434
435+ def _assign_cat_indice (self , node : Node , node_idx : int ):
436+ """
437+ Assign indice for cat op.
438+
439+ Args:
440+ node (node)
441+ node_idx (int)
442+ """
443+ nodes_in = flat_list (node .args [0 ])
444+ self ._assign_indice_as_input (node , node_idx , input_node = nodes_in [0 ])
445+ for n in nodes_in [1 :]:
446+ self ._mark_computation_from_node (n , node )
447+ cat_dim = node .kwargs ["dim" ]
448+ self ._del_dim (node_idx , cat_dim )
449+ self ._add_dim (node_idx , cat_dim )
450+
451+ def _assign_sum_indice (self , node : Node , node_idx : int ):
452+ """
453+ Assign indice for sum op.
454+
455+ Args:
456+ node (node)
457+ node_idx (int)
458+ """
459+ nodes_in = flat_list (node .args [0 ])
460+ self ._add_dim (node_idx , 0 )
461+ self ._assign_indice_as_input (node , node_idx , input_node = nodes_in [0 ])
462+ for n in nodes_in [1 :]:
463+ self ._mark_computation_from_node (n , node )
464+ cat_dim = node .kwargs ["dim" ]
465+ self ._del_dim (node_idx , cat_dim )
466+
435467 def _assign_getitem_indice (self , node : Node , node_idx : int ):
436468 """
437469 Assign indice for getitem.
@@ -442,7 +474,16 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
442474 node_idx (int)
443475 """
444476 node_args = flat_list (node .args [1 :])
445- if not any (i == str (node_arg ) for i in ["None" , "Ellipsis" ] for node_arg in node_args ):
477+ flag = False
478+ for node_arg in node_args :
479+ node_arg_str = str (node_arg )
480+ if any (i == node_arg_str for i in ["None" , "Ellipsis" ]):
481+ flag = True
482+ break
483+ if "slice" in node_arg_str :
484+ flag = True
485+ break
486+ if flag == False :
446487 return
447488
448489 # node args should be like [Ellipsis, slice(start, step, end), None]
@@ -461,8 +502,11 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
461502 shape_gap = len (node_shape ) - len (node_args ) + 1
462503 origin_idx_count += shape_gap
463504 new_idx_count += shape_gap
464- # slice(None, None, None) means all indexes, doesn't support other slice
465- elif "slice(None, None, None)" == node_arg_str :
505+ # slice(None, None, None) means all indexes
506+ elif "slice" in node_arg_str :
507+ if "slice(None, None, None)" != node_arg_str :
508+ self ._del_dim (node_idx , new_idx_count )
509+ self ._add_dim (node_idx , new_idx_count )
466510 origin_idx_count += 1
467511 new_idx_count += 1
468512 # None means a new dim
@@ -565,7 +609,7 @@ def trace_indice(self):
565609 self ._assign_view_reshape_indice (node , idx )
566610 elif "unsqueeze" in node .name :
567611 self ._assign_unsqueeze_indice (node , idx )
568- elif any (i in node .name for i in ["to" , "contiguous" ]):
612+ elif any (i in node .name for i in ["to" , "contiguous" , "clone" ]):
569613 self ._assgin_no_change_indice (node , idx )
570614 elif "new_ones" in node .name :
571615 self ._assign_ones_like_indice (node , idx )
@@ -574,6 +618,8 @@ def trace_indice(self):
574618 elif node .op == "call_function" :
575619 if "linear" in node .name :
576620 self ._assign_linear_indice (node , idx )
621+ elif "cat" in node .name :
622+ self ._assign_cat_indice (node , idx )
577623 elif "matmul" in node .name :
578624 self ._assign_matmul_indice (node , idx )
579625 elif "softmax" in node .name :
@@ -586,6 +632,8 @@ def trace_indice(self):
586632 self ._assign_dropout_indice (node , idx )
587633 elif "einsum" in node .name :
588634 self ._assign_einsum_indice (node , idx )
635+ elif "sum" in node .name :
636+ self ._assign_sum_indice (node , idx )
589637 elif "layer_norm" in node .name :
590638 self ._assign_layernorm_indice (node , idx )
591639 elif "getitem" in node .name :
0 commit comments