@@ -67,7 +67,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
6767 if not isinstance (axis , int ):
6868 axis = get_const_int (axis )
6969
70- def gen_ir (data , indices , updates , out ):
70+ def gen_ir (data , indices , updates , out , axis ):
7171 ib = tir .ir_builder .create ()
7272
7373 data_ptr = ib .buffer_ptr (data )
@@ -92,7 +92,7 @@ def gen_ir(data, indices, updates, out):
9292 full_range = before_axis_range * before_axis_stride
9393
9494 ind_shape = indices .shape
95- ind_axis_range = shape [axis ]
95+ ind_axis_range = ind_shape [axis ]
9696
9797 ind_before_axis_range = 1
9898 ind_after_axis_range = 1
@@ -173,7 +173,7 @@ def gen_ir(data, indices, updates, out):
173173 return te .extern (
174174 [data .shape ],
175175 [data , indices , updates ],
176- lambda ins , outs : gen_ir (ins [0 ], ins [1 ], ins [2 ], outs [0 ]),
176+ lambda ins , outs : gen_ir (ins [0 ], ins [1 ], ins [2 ], outs [0 ], axis ),
177177 dtype = data .dtype ,
178178 out_buffers = [out_buf ],
179179 name = "scatter_elements_cuda" ,
0 commit comments