Skip to content

Commit 8b43f4b

Browse files
authored
feat: dynamic shapes support for sqrt and copy (#2889)
1 parent 7182f50 commit 8b43f4b

File tree

3 files changed

+159
-2
lines changed

3 files changed

+159
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,7 @@ def aten_ops_log1p(
12891289
)
12901290

12911291

1292-
@dynamo_tensorrt_converter(torch.ops.aten.sqrt.default)
1292+
@dynamo_tensorrt_converter(torch.ops.aten.sqrt.default, supports_dynamic_shapes=True)
12931293
def aten_ops_sqrt(
12941294
ctx: ConversionContext,
12951295
target: Target,
@@ -2938,7 +2938,7 @@ def aten_ops_trunc(
29382938
)
29392939

29402940

2941-
@dynamo_tensorrt_converter(torch.ops.aten.copy.default)
2941+
@dynamo_tensorrt_converter(torch.ops.aten.copy.default, supports_dynamic_shapes=True)
29422942
@enforce_tensor_types(
29432943
{
29442944
1: (TRTTensor,),

tests/py/dynamo/conversion/test_copy_aten.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -26,6 +27,67 @@ def forward(self, input, src):
2627
inputs,
2728
)
2829

30+
@parameterized.expand(
31+
[
32+
(
33+
"1d_float32",
34+
torch.float32,
35+
(1,),
36+
(2,),
37+
(3,),
38+
False,
39+
),
40+
(
41+
"2d_float32",
42+
torch.float32,
43+
(1, 1),
44+
(2, 2),
45+
(3, 3),
46+
False,
47+
),
48+
(
49+
"3d_float32",
50+
torch.float32,
51+
(1, 1, 1),
52+
(2, 2, 2),
53+
(3, 3, 3),
54+
True,
55+
),
56+
(
57+
"4d_float32",
58+
torch.float32,
59+
(1, 1, 1, 1),
60+
(2, 2, 2, 2),
61+
(3, 3, 3, 3),
62+
True,
63+
),
64+
]
65+
)
66+
def test_dynamic_shape_copy_float(self, *args):
67+
class Copy(nn.Module):
68+
def forward(self, input, src):
69+
return torch.ops.aten.copy.default(input, src, args[5])
70+
71+
input_specs = [
72+
Input(
73+
min_shape=args[2],
74+
opt_shape=args[3],
75+
max_shape=args[4],
76+
dtype=args[1],
77+
),
78+
Input(
79+
min_shape=args[2],
80+
opt_shape=args[3],
81+
max_shape=args[4],
82+
dtype=args[1],
83+
),
84+
]
85+
86+
self.run_test_with_dynamic_shape(
87+
Copy(),
88+
input_specs,
89+
)
90+
2991

3092
if __name__ == "__main__":
3193
run_tests()

tests/py/dynamo/conversion/test_sqrt_aten.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -44,6 +45,100 @@ def forward(self, input):
4445
inputs,
4546
)
4647

48+
@parameterized.expand(
49+
[
50+
(
51+
"1d_float32",
52+
torch.float32,
53+
(1,),
54+
(2,),
55+
(3,),
56+
),
57+
(
58+
"2d_float32",
59+
torch.float32,
60+
(1, 1),
61+
(2, 2),
62+
(3, 3),
63+
),
64+
(
65+
"3d_float32",
66+
torch.float32,
67+
(1, 1, 1),
68+
(2, 2, 2),
69+
(3, 3, 3),
70+
),
71+
(
72+
"4d_float32",
73+
torch.float32,
74+
(1, 1, 1, 1),
75+
(2, 2, 2, 2),
76+
(3, 3, 3, 3),
77+
),
78+
]
79+
)
80+
def test_dynamic_shape_sqrt_float(self, *args):
81+
class sqrt(nn.Module):
82+
def forward(self, input):
83+
return torch.ops.aten.sqrt.default(input)
84+
85+
input_specs = [
86+
Input(
87+
min_shape=args[2],
88+
opt_shape=args[3],
89+
max_shape=args[4],
90+
dtype=args[1],
91+
),
92+
]
93+
self.run_test_with_dynamic_shape(sqrt(), input_specs)
94+
95+
@parameterized.expand(
96+
[
97+
(
98+
"1d_int32",
99+
torch.int,
100+
(1,),
101+
(2,),
102+
(3,),
103+
),
104+
(
105+
"2d_int32",
106+
torch.int32,
107+
(1, 1),
108+
(2, 2),
109+
(3, 3),
110+
),
111+
(
112+
"3d_int32",
113+
torch.int,
114+
(1, 1, 1),
115+
(2, 2, 2),
116+
(3, 3, 3),
117+
),
118+
(
119+
"4d_int32",
120+
torch.int32,
121+
(1, 1, 1, 1),
122+
(2, 2, 2, 2),
123+
(3, 3, 3, 3),
124+
),
125+
]
126+
)
127+
def test_dynamic_shape_sqrt_int(self, *args):
128+
class sqrt(nn.Module):
129+
def forward(self, input):
130+
return torch.ops.aten.sqrt.default(input)
131+
132+
input_specs = [
133+
Input(
134+
min_shape=args[2],
135+
opt_shape=args[3],
136+
max_shape=args[4],
137+
dtype=args[1],
138+
),
139+
]
140+
self.run_test_with_dynamic_shape(sqrt(), input_specs)
141+
47142

48143
if __name__ == "__main__":
49144
run_tests()

0 commit comments

Comments
 (0)