4343 equal ,
4444 shape_of ,
4545 log ,
46- concatenate
46+ concatenate ,
4747)
4848from .transform import (
4949 broadcast_to_like ,
6262 split ,
6363 squeeze ,
6464 strided_set ,
65- arange
65+ arange ,
6666)
6767
6868
@@ -679,7 +679,7 @@ def cross_entropy_with_logits_grad(orig, grad):
679679def take_grad (orig , grad ):
680680 def make_scalar_tensor (v ):
681681 if isinstance (v , int ):
682- v = const (v , dtype = ' int32' )
682+ v = const (v , dtype = " int32" )
683683 return reshape (v , (1 ,))
684684
685685 # TODO(@altanh): we currently assume indices are in range
@@ -690,7 +690,7 @@ def make_scalar_tensor(v):
690690 try :
691691 data_shape = data .checked_type .concrete_shape
692692 except TypeError :
693- raise OpError (' currently take_grad only supports data with concrete shape' )
693+ raise OpError (" currently take_grad only supports data with concrete shape" )
694694 if axis is None :
695695 axis = 0
696696 data_grad = reshape (data_grad , (- 1 ,))
@@ -710,7 +710,7 @@ def make_scalar_tensor(v):
710710 elif len (indices .checked_type .shape ) == 1 :
711711 num_indices = take (shape_of (indices ), zero , axis = 0 )
712712 else :
713- raise OpError (' take_grad only supports scalar or 1D indices' )
713+ raise OpError (" take_grad only supports scalar or 1D indices" )
714714
715715 def loop_cond (data_grad , i ):
716716 return squeeze (less (i , num_indices ))
@@ -731,8 +731,8 @@ def loop_body(data_grad, i):
731731 return (next_data_grad , i + one )
732732
733733 loop_vars = [
734- Var (' data_grad' , type_annotation = TensorType (data_shape , data .checked_type .dtype )),
735- Var ('i' , type_annotation = TensorType ((1 ,), ' int32' )),
734+ Var (" data_grad" , type_annotation = TensorType (data_shape , data .checked_type .dtype )),
735+ Var ("i" , type_annotation = TensorType ((1 ,), " int32" )),
736736 ]
737737
738738 loop = while_loop (loop_cond , loop_vars , loop_body )
@@ -777,11 +777,11 @@ def expand_dims_grad(orig, grad):
@register_gradient("arange")
def arange_grad(orig, grad):
    """Gradient of arange with respect to its start, stop, and step args.

    d(arange)/d(start) is 1 for every output element, so grad_start is the
    sum of the incoming gradient. stop only determines the output length,
    not any output value, so its gradient is zero. Each output element i
    equals start + i * step, so d/d(step) weights the incoming gradient by
    the element index before summing.
    """
    start, stop, step = orig.args
    # Number of elements the forward arange produced (first/only dim).
    length = take(shape_of(orig), const(0, dtype="int32"), axis=0)

    # Index weights 0..length-1, cast to the gradient's dtype.
    indices = cast_like(arange(length, dtype="int32"), grad)
    grad_step = cast_like(_sum(indices * grad), step)

    return [
        cast_like(_sum(grad), start),
        zeros_like(stop),
        grad_step,
    ]
0 commit comments