Skip to content

Commit 9e1a76f

Browse files
replace itemset due to numpy version 2.0 removed itemset api (#2879)
1 parent 9163720 commit 9e1a76f

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,13 @@ def embedding_bag_with_traversable_offsets(
9494
# however, pytorch doc says if `include_last_offset` is True, the size of offsets
9595
# is equal to the number of bags + 1. The last element is the size of the input,
9696
# or the ending index position of the last bag (sequence).
97-
offsets.itemset(-1, len_embed)
97+
98+
# Notes: here offsets should always be 1d array
99+
if len(offsets.shape) != 1:
100+
raise TypeError(
101+
f"The offsets should be in 1d array, here offset shape is {offsets.shape}."
102+
)
103+
offsets[-1] = len_embed
98104
else:
99105
# add the end index to offsets
100106
offsets = np.append(offsets, len_embed)

0 commit comments

Comments
 (0)