Skip to content

Commit 2c7e88a

Browse files
committed
fix linter actually
1 parent f76e2fa commit 2c7e88a

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

python/tvm/relay/op/_tensor_grad.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
equal,
4444
shape_of,
4545
log,
46-
concatenate
46+
concatenate,
4747
)
4848
from .transform import (
4949
broadcast_to_like,
@@ -62,7 +62,7 @@
6262
split,
6363
squeeze,
6464
strided_set,
65-
arange
65+
arange,
6666
)
6767

6868

@@ -679,7 +679,7 @@ def cross_entropy_with_logits_grad(orig, grad):
679679
def take_grad(orig, grad):
680680
def make_scalar_tensor(v):
681681
if isinstance(v, int):
682-
v = const(v, dtype='int32')
682+
v = const(v, dtype="int32")
683683
return reshape(v, (1,))
684684

685685
# TODO(@altanh): we currently assume indices are in range
@@ -690,7 +690,7 @@ def make_scalar_tensor(v):
690690
try:
691691
data_shape = data.checked_type.concrete_shape
692692
except TypeError:
693-
raise OpError('currently take_grad only supports data with concrete shape')
693+
raise OpError("currently take_grad only supports data with concrete shape")
694694
if axis is None:
695695
axis = 0
696696
data_grad = reshape(data_grad, (-1,))
@@ -710,7 +710,7 @@ def make_scalar_tensor(v):
710710
elif len(indices.checked_type.shape) == 1:
711711
num_indices = take(shape_of(indices), zero, axis=0)
712712
else:
713-
raise OpError('take_grad only supports scalar or 1D indices')
713+
raise OpError("take_grad only supports scalar or 1D indices")
714714

715715
def loop_cond(data_grad, i):
716716
return squeeze(less(i, num_indices))
@@ -731,8 +731,8 @@ def loop_body(data_grad, i):
731731
return (next_data_grad, i + one)
732732

733733
loop_vars = [
734-
Var('data_grad', type_annotation=TensorType(data_shape, data.checked_type.dtype)),
735-
Var('i', type_annotation=TensorType((1,), 'int32')),
734+
Var("data_grad", type_annotation=TensorType(data_shape, data.checked_type.dtype)),
735+
Var("i", type_annotation=TensorType((1,), "int32")),
736736
]
737737

738738
loop = while_loop(loop_cond, loop_vars, loop_body)
@@ -777,11 +777,11 @@ def expand_dims_grad(orig, grad):
777777
@register_gradient("arange")
778778
def arange_grad(orig, grad):
779779
start, stop, step = orig.args
780-
length = take(shape_of(orig), const(0, dtype='int32'), axis=0)
780+
length = take(shape_of(orig), const(0, dtype="int32"), axis=0)
781781

782782
grad_start = cast_like(_sum(grad), start)
783783
grad_stop = zeros_like(stop)
784-
grad_step = cast_like(arange(length, dtype='int32'), grad) * grad
784+
grad_step = cast_like(arange(length, dtype="int32"), grad) * grad
785785
grad_step = cast_like(_sum(grad_step), step)
786786

787787
return [grad_start, grad_stop, grad_step]

0 commit comments

Comments
 (0)