Skip to content

Commit 77c1d8f

Browse files
committed
feat: support aten.index_put converter except accumulate True
1 parent 852b211 commit 77c1d8f

File tree

3 files changed

+164
-44
lines changed

3 files changed

+164
-44
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,22 @@ def aten_ops_select(
769769
)
770770

771771

772+
def index_put_accumulate_validator(node: Node) -> bool:
773+
if args_bounds_check(node.args, 3, False):
774+
_LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
775+
return False
776+
else:
777+
return True
778+
779+
780+
@dynamo_tensorrt_converter(
781+
torch.ops.aten.index_put_.default,
782+
capability_validator=index_put_accumulate_validator,
783+
)
784+
@dynamo_tensorrt_converter(
785+
torch.ops.aten.index_put.default,
786+
capability_validator=index_put_accumulate_validator,
787+
)
772788
@dynamo_tensorrt_converter(torch.ops.aten.index_put.default)
773789
@dynamo_tensorrt_converter(torch.ops.aten.index_put_.default)
774790
@enforce_tensor_types(
@@ -777,7 +793,7 @@ def aten_ops_select(
777793
2: (TRTTensor,),
778794
}
779795
)
780-
def aten_ops_index_put_(
796+
def aten_ops_index_put(
781797
ctx: ConversionContext,
782798
target: Target,
783799
args: Tuple[Argument, ...],
@@ -792,7 +808,7 @@ def aten_ops_index_put_(
792808
args[0],
793809
args[1],
794810
args[2],
795-
args_bounds_check(args, 3, []),
811+
args_bounds_check(args, 3, False),
796812
)
797813

798814

@@ -3208,27 +3224,6 @@ def aten_ops_roll(
32083224
)
32093225

32103226

3211-
@enforce_tensor_types(
3212-
{
3213-
0: (TRTTensor,),
3214-
}
3215-
)
3216-
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
3217-
# @dynamo_tensorrt_converter(torch.ops.aten.scatter.src.default)
3218-
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
3219-
# @dynamo_tensorrt_converter(torch.ops.aten.scatter.value.default)
3220-
def aten_ops_scatter(
3221-
ctx: ConversionContext,
3222-
target: Target,
3223-
args: Tuple[Argument, ...],
3224-
kwargs: Dict[str, Argument],
3225-
name: str,
3226-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
3227-
return impl.select.scatter(
3228-
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
3229-
)
3230-
3231-
32323227
@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
32333228
@enforce_tensor_types(
32343229
{

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from torch.fx.node import Target
88
from torch_tensorrt.dynamo._SourceIR import SourceIR
9+
from torch_tensorrt.dynamo.conversion import impl
910
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1011
from torch_tensorrt.dynamo.conversion.converter_utils import (
1112
broadcastable,
@@ -410,7 +411,7 @@ def scatter(
410411
dim = get_positive_dim(dim, len(input_shape))
411412
src_tensor = src
412413
# scatter.value
413-
if isinstance(src, int) or isinstance(src, float):
414+
if isinstance(src, (int, float)):
414415
src_tensor = get_trt_tensor(
415416
ctx, src * np.ones(index_shape_list), name + "_value_tensor"
416417
)
@@ -440,33 +441,28 @@ def index_put_converter(
440441
values: TRTTensor,
441442
accumulate: bool = False,
442443
) -> TRTTensor:
443-
from torch_tensorrt.dynamo.conversion import impl
444-
445-
trt_inputs = []
444+
# Reshape indices to add an extra dimension if necessary (indices is a Tuple of ITensors)
445+
reshaped_indices = []
446446
for i, each_input in enumerate(indices):
447447
if not isinstance(each_input, TRTTensor):
448448
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
449449
each_input = impl.shuffle.reshape(
450450
ctx,
451451
target,
452452
source_ir,
453-
f"{name}_broadcast_{i}",
453+
f"{name}_reshape_{i}",
454454
each_input,
455-
(each_input.shape[0],),
455+
(-1, 1), # Reshape to (N, 1)
456456
)
457-
trt_inputs.append(each_input)
458-
concat_layer = ctx.net.add_concatenation(trt_inputs)
459-
dim = get_positive_dim(0, len(indices[0].shape))
460-
concat_layer.axis = dim
461-
set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
462-
indices = concat_layer.get_output(0)
463-
464-
values = impl.shuffle.reshape(
465-
ctx, target, source_ir, f"{name}_broadcast", values, (values.shape[0],)
457+
reshaped_indices.append(each_input)
458+
459+
# Concatenate along the second dimension (columns)
460+
indices_cat = impl.cat.cat(
461+
ctx, target, source_ir, f"{name}_cat", reshaped_indices, dim=1
466462
)
467463

468464
scatter_layer = ctx.net.add_scatter(
469-
input_tensor, indices, values, trt.ScatterMode.ELEMENT # trt.ScatterMode.ND
465+
input_tensor, indices_cat, values, trt.ScatterMode.ND
470466
)
471467
scatter_layer.axis = 0
472468
set_layer_name(scatter_layer, target, f"{name}_scatter_layer", source_ir)

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 134 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,140 @@ class TestIndexPutConverter(DispatchTestCase):
2222
indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
2323
value_tensor=torch.tensor([1, 3], dtype=torch.int32),
2424
),
25+
param(
26+
test_name="2d_indices_single",
27+
source_tensor=torch.zeros([5, 5], dtype=torch.int32),
28+
indices_tensor=(
29+
torch.tensor([2], dtype=torch.int32),
30+
torch.tensor([0], dtype=torch.int32),
31+
),
32+
value_tensor=torch.tensor([3], dtype=torch.int32),
33+
),
34+
param(
35+
test_name="2d_indices_multiple",
36+
source_tensor=torch.zeros([5, 5], dtype=torch.int32),
37+
indices_tensor=(
38+
torch.tensor([0, 2, 2], dtype=torch.int32),
39+
torch.tensor([2, 0, 2], dtype=torch.int32),
40+
),
41+
value_tensor=torch.tensor([1, 3, 4], dtype=torch.int32),
42+
),
43+
param(
44+
test_name="3d_indices_single",
45+
source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
46+
indices_tensor=(
47+
torch.tensor([1], dtype=torch.int32),
48+
torch.tensor([2], dtype=torch.int32),
49+
torch.tensor([2], dtype=torch.int32),
50+
),
51+
value_tensor=torch.tensor([7], dtype=torch.int32),
52+
),
53+
param(
54+
test_name="3d_indices_multiple",
55+
source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
56+
indices_tensor=(
57+
torch.tensor([0, 1, 1], dtype=torch.int32),
58+
torch.tensor([1, 2, 1], dtype=torch.int32),
59+
torch.tensor([2, 0, 2], dtype=torch.int32),
60+
),
61+
value_tensor=torch.tensor([5, 7, 2], dtype=torch.int32),
62+
),
63+
param(
64+
test_name="4d_indices_single",
65+
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
66+
indices_tensor=(
67+
torch.tensor([1], dtype=torch.int32),
68+
torch.tensor([1], dtype=torch.int32),
69+
torch.tensor([0], dtype=torch.int32),
70+
torch.tensor([1], dtype=torch.int32),
71+
),
72+
value_tensor=torch.tensor([5], dtype=torch.int32),
73+
),
74+
param(
75+
test_name="4d_indices_multiple",
76+
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
77+
indices_tensor=(
78+
torch.tensor([0, 1], dtype=torch.int32),
79+
torch.tensor([1, 1], dtype=torch.int32),
80+
torch.tensor([1, 0], dtype=torch.int32),
81+
torch.tensor([1, 0], dtype=torch.int32),
82+
),
83+
value_tensor=torch.tensor([5, 7], dtype=torch.int32),
84+
),
85+
param(
86+
test_name="negative_indices",
87+
source_tensor=torch.zeros([5, 5], dtype=torch.int32),
88+
indices_tensor=(
89+
torch.tensor([-1, -2], dtype=torch.int32),
90+
torch.tensor([2, 0], dtype=torch.int32),
91+
),
92+
value_tensor=torch.tensor([1, 3], dtype=torch.int32),
93+
),
94+
param(
95+
test_name="mixed_indices",
96+
source_tensor=torch.zeros([4, 4], dtype=torch.int32),
97+
indices_tensor=(
98+
torch.tensor([0, 1, -1, -2], dtype=torch.int32),
99+
torch.tensor([0, -1, 2, 1], dtype=torch.int32),
100+
),
101+
value_tensor=torch.tensor([2, 4, 6, 8], dtype=torch.int32),
102+
),
103+
param(
104+
test_name="1d_indices_float",
105+
source_tensor=torch.zeros([5], dtype=torch.float32),
106+
indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
107+
value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
108+
),
109+
param(
110+
test_name="2d_indices_float",
111+
source_tensor=torch.zeros([5, 5], dtype=torch.float32),
112+
indices_tensor=(
113+
torch.tensor([0, 2], dtype=torch.int32),
114+
torch.tensor([2, 0], dtype=torch.int32),
115+
),
116+
value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
117+
),
118+
param(
119+
test_name="3d_indices_float",
120+
source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
121+
indices_tensor=(
122+
torch.tensor([0, 1], dtype=torch.int32),
123+
torch.tensor([1, 2], dtype=torch.int32),
124+
torch.tensor([2, 0], dtype=torch.int32),
125+
),
126+
value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
127+
),
128+
param(
129+
test_name="4d_indices_float",
130+
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
131+
indices_tensor=(
132+
torch.tensor([0, 1], dtype=torch.int32),
133+
torch.tensor([1, 0], dtype=torch.int32),
134+
torch.tensor([0, 1], dtype=torch.int32),
135+
torch.tensor([1, 0], dtype=torch.int32),
136+
),
137+
value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
138+
),
139+
# param(
140+
# test_name="2d_indices_accumulate_True",
141+
# source_tensor=torch.zeros([5, 5], dtype=torch.int32),
142+
# indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
143+
# value_tensor=torch.tensor([1, 2], dtype=torch.int32),
144+
# accumulate=True,
145+
# ),
146+
# param(
147+
# test_name="3d_indices_accumulate_True",
148+
# source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
149+
# indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
150+
# value_tensor=torch.tensor([1, 2], dtype=torch.int32),
151+
# accumulate=True,
152+
# ),
25153
# param(
26-
# test_name="2d_indices",
27-
# source_tensor=torch.zeros([5,5], dtype=torch.int32),
28-
# indices_tensor=(torch.tensor([0,2], dtype=torch.int32),torch.tensor([2,0], dtype=torch.int32),),
29-
# value_tensor=torch.tensor([1,3], dtype=torch.int32),
154+
# test_name="4d_indices_accumulate_True",
155+
# source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
156+
# indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
157+
# value_tensor=torch.tensor([1, 2], dtype=torch.int32),
158+
# accumulate=True,
30159
# ),
31160
]
32161
)
@@ -36,7 +165,7 @@ def test_index_put(
36165
class TestIndexPut(torch.nn.Module):
37166
def forward(self, source_tensor, value_tensor):
38167
return torch.ops.aten.index_put_.default(
39-
source_tensor, indices_tensor, value_tensor
168+
source_tensor, indices_tensor, value_tensor, accumulate
40169
)
41170

42171
self.run_test(

0 commit comments

Comments
 (0)