Skip to content

Commit 9bc1e0b

Browse files
committed
feat: support aten.as_strided converter
1 parent 4142d3f commit 9bc1e0b

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,26 @@ def aten_ops_linear(
21862186
)
21872187

21882188

2189+
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
2190+
def aten_ops_as_strided(
2191+
ctx: ConversionContext,
2192+
target: Target,
2193+
args: Tuple[Argument, ...],
2194+
kwargs: Dict[str, Argument],
2195+
name: str,
2196+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2197+
return impl.slice.as_strided(
2198+
ctx,
2199+
target,
2200+
source_ir=SourceIR.ATEN,
2201+
name=name,
2202+
input=args[0],
2203+
size=args[1],
2204+
stride=args[2],
2205+
storage_offset=args_bounds_check(args, 3, 0),
2206+
)
2207+
2208+
21892209
def avg_pool_param_validator(pool_node: Node) -> bool:
21902210
ceil_mode = args_bounds_check(pool_node.args, 4, False)
21912211
divisor_override = args_bounds_check(pool_node.args, 6)

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
import numpy as np
55
import tensorrt as trt
6+
import torch
67
from torch.fx.node import Target
78
from torch_tensorrt.dynamo._SourceIR import SourceIR
89
from torch_tensorrt.dynamo.conversion import impl
910
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1011
from torch_tensorrt.dynamo.conversion.converter_utils import (
12+
flatten_dims,
1113
get_positive_dim,
1214
get_trt_tensor,
1315
)
@@ -259,3 +261,59 @@ def flip(
259261
)
260262
set_layer_name(layer, target, name, source_ir)
261263
return layer.get_output(0)
264+
265+
266+
def as_strided(
267+
ctx: ConversionContext,
268+
target: Target,
269+
source_ir: Optional[SourceIR],
270+
name: str,
271+
input: TRTTensor,
272+
size: Sequence[int],
273+
stride: Sequence[int],
274+
storage_offset: int,
275+
) -> TRTTensor:
276+
assert len(size) == len(stride), "size and stride shapes must be the same"
277+
278+
flatten_shape = flatten_dims(input, 0, -1)
279+
flatten_output = impl.shuffle.reshape(
280+
ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
281+
)
282+
283+
indices = []
284+
285+
# Recursive function to compute indices for as_strided operation
286+
def nested(
287+
rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
288+
) -> None:
289+
if (
290+
dim == rank
291+
): # If the current dimension equals the rank, append the computed index
292+
indices.append(current)
293+
return
294+
for i in range(size[dim]): # Recursively compute indices across dimensions
295+
nested(
296+
rank, size, stride, current + stride[dim] * i, dim + 1
297+
) # Calculate the index for the current dimension and recursively explore further dimensions
298+
299+
nested(len(size), size, stride, storage_offset, 0)
300+
301+
indices = torch.tensor(indices, dtype=torch.int)
302+
303+
indices_tensor = get_trt_tensor(ctx, (indices), f"{name}_indices")
304+
305+
# Use gather to reorder elements based on computed indices
306+
gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
307+
gather_output = gather_layer.get_output(0)
308+
309+
# Reshape the gathered tensor to the desired size
310+
reshape_output = impl.shuffle.reshape(
311+
ctx,
312+
target,
313+
source_ir,
314+
f"{name}_reshape",
315+
gather_output,
316+
tuple(size),
317+
)
318+
319+
return reshape_output
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import torch
2+
from parameterized import parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestAsStridedConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
(
13+
(5, 5),
14+
(2, 3),
15+
(1, 2),
16+
0,
17+
),
18+
(
19+
(5, 5),
20+
(2, 3),
21+
(2, 2),
22+
1,
23+
),
24+
(
25+
(20, 20),
26+
(2, 3, 2),
27+
(2, 2, 2),
28+
0,
29+
),
30+
(
31+
(8, 8, 8),
32+
(2, 2, 3),
33+
(1, 2, 2),
34+
1,
35+
),
36+
(
37+
(200, 200, 200),
38+
(9, 9, 3, 2),
39+
(2, 2, 2, 3),
40+
1,
41+
),
42+
]
43+
)
44+
def test_as_strided(
45+
self,
46+
input_shape,
47+
output_size,
48+
stride,
49+
storage_offset=0,
50+
):
51+
class TestModule(torch.nn.Module):
52+
def forward(self, x):
53+
return torch.ops.aten.as_strided.default(
54+
x, output_size, stride, storage_offset
55+
)
56+
57+
inputs = [torch.randn(input_shape)]
58+
self.run_test(
59+
TestModule(),
60+
inputs,
61+
)
62+
63+
64+
if __name__ == "__main__":
65+
run_tests()

0 commit comments

Comments
 (0)