Skip to content

Commit 27a9f6d

Browse files
authored
fix: type error in embedding_bag (#2418)
1 parent 3e612c1 commit 27a9f6d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def embedding_bag(
7575

7676
# TODO: support 2D inputs
7777
# indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))
78-
78+
reduce_name = ""
7979
if mode == 0: # sum
8080
reduce_op = functools.partial(
8181
impl.reduce.sum, ctx=ctx, target=target, source_ir=source_ir
@@ -143,7 +143,6 @@ def embedding_bag(
143143
# however, pytorch doc says if `include_last_offset` is True, the size of offsets
144144
# is equal to the number of bags + 1. The last element is the size of the input,
145145
# or the ending index position of the last bag (sequence).
146-
147146
offsets[-1] = indices.shape[0]
148147

149148
# separately reduce embeddings for different bags
@@ -158,8 +157,8 @@ def embedding_bag(
158157
f"{name}_slice_embed_{i}",
159158
embed,
160159
0,
161-
offsets[i],
162-
offsets[i + 1],
160+
int(offsets[i]),
161+
int(offsets[i + 1]),
163162
1,
164163
)
165164
reduced_sliced_embed = reduce_op(

0 commit comments

Comments
 (0)