From 3ac0a8df465b7ef594ff4e06fe4e19f0ae6b007a Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Wed, 5 Jun 2024 11:24:14 +0000 Subject: [PATCH] feat: Add dynamic shape support for sub --- .../dynamo/conversion/aten_ops_converters.py | 6 +- tests/py/dynamo/conversion/test_sub_aten.py | 95 +++++++++++++++++++ 2 files changed, 98 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 09cc78f7db..d7ba5c4807 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1760,8 +1760,8 @@ def aten_ops_minimum( ) -@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True) def aten_ops_sub( ctx: ConversionContext, target: Target, @@ -1777,7 +1777,7 @@ def aten_ops_sub( ctx, target, SourceIR.ATEN, - name, + name + "_alpha", other, alpha, ) diff --git a/tests/py/dynamo/conversion/test_sub_aten.py b/tests/py/dynamo/conversion/test_sub_aten.py index fa4d8b5b80..cef6fac1d9 100644 --- a/tests/py/dynamo/conversion/test_sub_aten.py +++ b/tests/py/dynamo/conversion/test_sub_aten.py @@ -76,6 +76,101 @@ def forward(self, lhs_val): inputs, ) + @parameterized.expand( + [ + ( + "3d_2d_alpha_float32", + torch.float32, + (1, 1, 1), + (3, 2, 2), + (3, 2, 4), + (1, 1), + (2, 2), + (2, 4), + 1.5, + ), + ( + "2d_2d_alpha_int32", + torch.int32, + (3, 2), + (3, 2), + (3, 3), + (3, 2), + (3, 2), + (3, 3), + 2, + ), + ] + ) + def test_dynamic_shape_sub(self, *args): + class sub(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.sub.Tensor(lhs_val, rhs_val, alpha=args[8]) + + input_specs = [ + Input( + min_shape=args[2], + opt_shape=args[3], + max_shape=args[4], + dtype=args[1], + ), + Input( + min_shape=args[5], + opt_shape=args[6], + max_shape=args[7], + dtype=args[1], + ), + ] + + self.run_test_with_dynamic_shape(sub(), input_specs) + + @parameterized.expand( + [ + ( + "3d_scalar_float32", + torch.float32, + (1, 1, 1), + (3, 2, 2), + (3, 2, 4), + 0.3, + ) + ] + ) + def test_dynamic_shape_sub_scalar(self, *args): + class sub(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.sub.Tensor(lhs_val, args[5]) + + input_specs = [ + Input( + min_shape=args[2], + opt_shape=args[3], + max_shape=args[4], + dtype=args[1], + ), + ] + + self.run_test_with_dynamic_shape(sub(), input_specs) + + @parameterized.expand( + [("scalar_2d_alpha_float32", torch.float32, (1, 1), (2, 2), (3, 4), 0.3, 1.5)] + ) + def test_dynamic_shape_sub_scalar_alpha(self, *args): + class sub(nn.Module): + def forward(self, rhs_val): + return torch.ops.aten.sub.Tensor(args[5], rhs_val, alpha=args[6]) + + input_specs = [ + Input( + min_shape=args[2], + opt_shape=args[3], + max_shape=args[4], + dtype=args[1], + ), + ] + + self.run_test_with_dynamic_shape(sub(), input_specs) + if __name__ == "__main__": run_tests()