From 8bc398735fed22c3ecf0cf4e99b0dbb18c19bc09 Mon Sep 17 00:00:00 2001 From: Lan Luo Date: Mon, 3 Jun 2024 20:21:07 -0700 Subject: [PATCH 1/2] add dynamic shape support for aten.ops.gt and aten.ops.ge --- tests/py/dynamo/conversion/test_ge_aten.py | 51 ++++++++++++++++++++++ tests/py/dynamo/conversion/test_gt_aten.py | 51 ++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py index bacfedafc8..a803c1c6b1 100644 --- a/tests/py/dynamo/conversion/test_ge_aten.py +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -61,6 +62,56 @@ def forward(self, lhs_val): inputs, ) + @parameterized.expand( + [ + ("2d_2d", (5, 3), (5, 1)), + ("3d_2d", (5, 3, 2), (3, 1)), + ("4d_3d", (5, 3, 4, 1), (3, 1, 1)), + ] + ) + def test_ge_tensor_broadcast(self, _, lshape, rshape): + class ge(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, lshape, dtype=torch.int32), + torch.randint(0, 3, rshape, dtype=torch.int32), + ] + self.run_test( + ge(), + inputs, + ) + + @parameterized.expand( + [ + ("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)), + ("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)), + ] + ) + def test_ge_dynamic_tensor(self, *args): + class ge(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, rhs_val) + + input_specs = [ + Input( + min_shape=args[1], + opt_shape=args[2], + max_shape=args[3], + ), + Input( + min_shape=args[4], + opt_shape=args[5], + max_shape=args[6], + ), + ] + + self.run_test_with_dynamic_shape( + ge(), + input_specs, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_gt_aten.py b/tests/py/dynamo/conversion/test_gt_aten.py index 0eab7c84ff..48e56f71d3 100644 --- a/tests/py/dynamo/conversion/test_gt_aten.py +++ b/tests/py/dynamo/conversion/test_gt_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -58,6 +59,56 @@ def forward(self, lhs_val): inputs, ) + @parameterized.expand( + [ + ("2d_2d", (5, 3), (5, 1)), + ("3d_2d", (5, 3, 2), (3, 1)), + ("4d_3d", (5, 3, 4, 1), (3, 1, 1)), + ] + ) + def test_gt_tensor_broadcast(self, _, lshape, rshape): + class gt(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.gt.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, lshape, dtype=torch.int32), + torch.randint(0, 3, rshape, dtype=torch.int32), + ] + self.run_test( + gt(), + inputs, + ) + + @parameterized.expand( + [ + ("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)), + ("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)), + ] + ) + def test_gt_dynamic_tensor(self, *args): + class gt(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.gt.Tensor(lhs_val, rhs_val) + + input_specs = [ + Input( + min_shape=args[1], + opt_shape=args[2], + max_shape=args[3], + ), + Input( + min_shape=args[4], + opt_shape=args[5], + max_shape=args[6], + ), + ] + + self.run_test_with_dynamic_shape( + gt(), + input_specs, + ) + if __name__ == "__main__": run_tests() From 83e4ef03c2c56aed7855d3c5f006d2b72a926170 Mon Sep 17 00:00:00 2001 From: Lan Luo Date: Tue, 4 Jun 2024 13:11:17 -0700 Subject: [PATCH 2/2] Add supports_dynamic_shapes=True annotation --- .../dynamo/conversion/aten_ops_converters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 09cc78f7db..b4b979bb7f 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2143,8 +2143,8 @@ def aten_ops_ne( ) -@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -2167,8 +2167,8 @@ def aten_ops_gt( ) -@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,),