Skip to content

Commit 7577227

Browse files
author
Valery Chernov
committed
fix axis
1 parent 00a5f6c commit 7577227

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/tvm/topi/cuda/scatter_elements.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
6767
if not isinstance(axis, int):
6868
axis = get_const_int(axis)
6969

70-
def gen_ir(data, indices, updates, out):
70+
def gen_ir(data, indices, updates, out, axis):
7171
ib = tir.ir_builder.create()
7272

7373
data_ptr = ib.buffer_ptr(data)
@@ -92,7 +92,7 @@ def gen_ir(data, indices, updates, out):
9292
full_range = before_axis_range * before_axis_stride
9393

9494
ind_shape = indices.shape
95-
ind_axis_range = shape[axis]
95+
ind_axis_range = ind_shape[axis]
9696

9797
ind_before_axis_range = 1
9898
ind_after_axis_range = 1
@@ -173,7 +173,7 @@ def gen_ir(data, indices, updates, out):
173173
return te.extern(
174174
[data.shape],
175175
[data, indices, updates],
176-
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
176+
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis),
177177
dtype=data.dtype,
178178
out_buffers=[out_buf],
179179
name="scatter_elements_cuda",

0 commit comments

Comments
 (0)