Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,8 +1760,8 @@ def aten_ops_minimum(
)


@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
def aten_ops_sub(
ctx: ConversionContext,
target: Target,
Expand All @@ -1777,7 +1777,7 @@ def aten_ops_sub(
ctx,
target,
SourceIR.ATEN,
name,
name + "_alpha",
other,
alpha,
)
Expand Down
95 changes: 95 additions & 0 deletions tests/py/dynamo/conversion/test_sub_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,101 @@ def forward(self, lhs_val):
inputs,
)

@parameterized.expand(
[
(
"3d_2d_alpha_float32",
torch.float32,
(1, 1, 1),
(3, 2, 2),
(3, 2, 4),
(1, 1),
(2, 2),
(2, 4),
1.5,
),
(
"2d_2d_alpha_int32",
torch.int32,
(3, 2),
(3, 2),
(3, 3),
(3, 2),
(3, 2),
(3, 3),
2,
),
]
)
def test_dynamic_shape_sub(self, *args):
class sub(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.sub.Tensor(lhs_val, rhs_val, alpha=args[8])

input_specs = [
Input(
min_shape=args[2],
opt_shape=args[3],
max_shape=args[4],
dtype=args[1],
),
Input(
min_shape=args[5],
opt_shape=args[6],
max_shape=args[7],
dtype=args[1],
),
]

self.run_test_with_dynamic_shape(sub(), input_specs)

@parameterized.expand(
[
(
"3d_scalar_float32",
torch.float32,
(1, 1, 1),
(3, 2, 2),
(3, 2, 4),
0.3,
)
]
)
def test_dynamic_shape_sub_scalar(self, *args):
class sub(nn.Module):
def forward(self, lhs_val):
return torch.ops.aten.sub.Tensor(lhs_val, args[5])

input_specs = [
Input(
min_shape=args[2],
opt_shape=args[3],
max_shape=args[4],
dtype=args[1],
),
]

self.run_test_with_dynamic_shape(sub(), input_specs)

@parameterized.expand(
[("scalar_2d_alpha_float32", torch.float32, (1, 1), (2, 2), (3, 4), 0.3, 1.5)]
)
def test_dynamic_shape_sub_scalar_alpha(self, *args):
class sub(nn.Module):
def forward(self, rhs_val):
return torch.ops.aten.sub.Tensor(args[5], rhs_val, alpha=args[6])

input_specs = [
Input(
min_shape=args[2],
opt_shape=args[3],
max_shape=args[4],
dtype=args[1],
),
]

self.run_test_with_dynamic_shape(sub(), input_specs)


if __name__ == "__main__":
run_tests()