From 92523e3effba11a3476832ffe04f2633855498fd Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Thu, 20 Jun 2024 01:22:26 +0900
Subject: [PATCH] feat: dynamic shape support for tan, sinh, cosh, asin and
 acos

---
 .../dynamo/conversion/aten_ops_converters.py | 10 ++--
 tests/py/dynamo/conversion/test_acos_aten.py | 51 +++++++++++++++++++
 tests/py/dynamo/conversion/test_asin_aten.py | 50 ++++++++++++++++++
 tests/py/dynamo/conversion/test_cosh_aten.py | 50 ++++++++++++++++++
 tests/py/dynamo/conversion/test_sinh_aten.py | 50 ++++++++++++++++++
 tests/py/dynamo/conversion/test_tan_aten.py  | 51 +++++++++++++++++++
 6 files changed, 257 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index ba55d30acc..7d8dcc60ad 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1501,7 +1501,7 @@ def aten_ops_cos(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.tan.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tan.default, supports_dynamic_shapes=True)
 def aten_ops_tan(
     ctx: ConversionContext,
     target: Target,
@@ -1518,7 +1518,7 @@ def aten_ops_tan(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.sinh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sinh.default, supports_dynamic_shapes=True)
 def aten_ops_sinh(
     ctx: ConversionContext,
     target: Target,
@@ -1535,7 +1535,7 @@ def aten_ops_sinh(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cosh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.cosh.default, supports_dynamic_shapes=True)
 def aten_ops_cosh(
     ctx: ConversionContext,
     target: Target,
@@ -1552,7 +1552,7 @@ def aten_ops_cosh(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.asin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.asin.default, supports_dynamic_shapes=True)
 def aten_ops_asin(
     ctx: ConversionContext,
     target: Target,
@@ -1569,7 +1569,7 @@ def aten_ops_asin(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.acos.default)
+@dynamo_tensorrt_converter(torch.ops.aten.acos.default, supports_dynamic_shapes=True)
 def aten_ops_acos(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_acos_aten.py b/tests/py/dynamo/conversion/test_acos_aten.py
index 503cc54f39..81b83bcc4a 100644
--- a/tests/py/dynamo/conversion/test_acos_aten.py
+++ b/tests/py/dynamo/conversion/test_acos_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,56 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_acos_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class acos(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.acos.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            acos(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_asin_aten.py b/tests/py/dynamo/conversion/test_asin_aten.py
index c77452b370..2b8eb84144 100644
--- a/tests/py/dynamo/conversion/test_asin_aten.py
+++ b/tests/py/dynamo/conversion/test_asin_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,55 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_asin_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class asin(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.asin.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            asin(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_cosh_aten.py b/tests/py/dynamo/conversion/test_cosh_aten.py
index 1175613796..a647ec71be 100644
--- a/tests/py/dynamo/conversion/test_cosh_aten.py
+++ b/tests/py/dynamo/conversion/test_cosh_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,55 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_cosh_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class cosh(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.cosh.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            cosh(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_sinh_aten.py b/tests/py/dynamo/conversion/test_sinh_aten.py
index d17ab3b467..9238f092c8 100644
--- a/tests/py/dynamo/conversion/test_sinh_aten.py
+++ b/tests/py/dynamo/conversion/test_sinh_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,55 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_sinh_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class sinh(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.sinh.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            sinh(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_tan_aten.py b/tests/py/dynamo/conversion/test_tan_aten.py
index 137025dbc6..121364b0b2 100644
--- a/tests/py/dynamo/conversion/test_tan_aten.py
+++ b/tests/py/dynamo/conversion/test_tan_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -44,6 +45,56 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "3d_dim_dtype_int32",
+                (3, 2, 1),
+                (3, 2, 3),
+                (3, 3, 4),
+                torch.int32,
+                torch.float32,
+            ),
+            (
+                "2d_dim_dtype_float16",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float16,
+                torch.float16,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_tan_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class tan(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.tan.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            tan(),
+            input_specs,
+            output_dtypes=[output_type],
+        )
+
 
 if __name__ == "__main__":
     run_tests()
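
Illustrative usage sketch (not part of the patch): with supports_dynamic_shapes=True
on these converters, a graph containing any of the five ops can be compiled once
against a min/opt/max shape range and then run with varying input shapes, which is
what the new tests exercise through run_test_with_dynamic_shape. The toy module and
the shape range below are assumptions chosen for demonstration; torch_tensorrt.Input
mirrors the Input spec used in the tests above.

# Sketch: compile a module hitting the aten.tan.default converter with a
# dynamic batch dimension. TanModule and the shape triple are illustrative.
import torch
import torch_tensorrt

class TanModule(torch.nn.Module):
    def forward(self, x):
        # Lowered through the aten.tan.default converter touched by this patch.
        return torch.tan(x)

model = TanModule().eval().cuda()
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 32),   # smallest shape the engine must accept
        opt_shape=(4, 3, 32),   # shape TensorRT optimizes for
        max_shape=(8, 3, 32),   # largest shape the engine must accept
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
# Any batch size within [1, 8] should now run without recompilation.
print(trt_model(torch.randn(6, 3, 32).cuda()).shape)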