From 19885ebe7ebd9f11cd820768e5e9c3cb4fa0046b Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 30 May 2024 10:51:48 -0700 Subject: [PATCH 1/5] Added testcases for dynamic shape support of split --- tests/py/dynamo/conversion/test_split_aten.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/conversion/test_split_aten.py b/tests/py/dynamo/conversion/test_split_aten.py index 142f9b337c..aa28d8f813 100644 --- a/tests/py/dynamo/conversion/test_split_aten.py +++ b/tests/py/dynamo/conversion/test_split_aten.py @@ -119,6 +119,7 @@ def forward(self, input): @parameterized.expand( [ ("select_split_size_or_sections_dim_dynamic_shape", 2, 1), + ("select_split_size_or_sections_non_divisible_dim_dynamic_shape", 3, 1), ] ) def test_split_dynamic(self, _, split_size_or_tensor, dim): @@ -132,9 +133,11 @@ def forward(self, input): input_specs = [ Input( - shape=(1, 10, -1), dtype=torch.float32, - shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], + min_shape=[1, 10, 1], + opt_shape=[1, 10, 10], + max_shape=[1, 10, 10], + name = "input", ), ] self.run_test_with_dynamic_shape( @@ -142,6 +145,35 @@ def forward(self, input): input_specs, ) + @parameterized.expand( + [ + ("select_split_size_or_sections_dim_dynamic_shape_on_first_axis", 2, 1), + ] + ) + def test_split_dynamic_first_axis_dynamic(self, _, split_size_or_tensor, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.split.Tensor(input, split_size_or_tensor, dim) + return out + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=[1, 10, 10], + opt_shape=[3, 10, 10], + max_shape=[5, 10, 10], + name = "input", + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + ) + + @parameterized.expand( [ ("select_chunk_dim", 6, 0), From cc81f6c2457ad88d9d21ed274775e552e0b2dbd9 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 30 May 2024 10:52:08 -0700 Subject: [PATCH 2/5] Add dynamic support to decorator --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8a82e52cb6..d698e66aee 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -629,14 +629,15 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]) + torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]),supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]) + torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]),supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( torch.ops.aten.split_with_sizes.default, capability_validator=has_static_shapes_in_args([1]), + supports_dynamic_shapes=True, ) def aten_ops_split( ctx: ConversionContext, From 69da712f5910a6ce3cf179794c80da2881c350bb Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 30 May 2024 11:01:30 -0700 Subject: [PATCH 3/5] Formated the code --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d698e66aee..c3553c0f65 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -629,10 +629,10 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]),supports_dynamic_shapes=True, + torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]), supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]),supports_dynamic_shapes=True, + torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]), supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( torch.ops.aten.split_with_sizes.default, From 69af5f8083a6e595e634b8c1d4ae4c4cf440d98c Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 30 May 2024 14:35:59 -0700 Subject: [PATCH 4/5] Formatted the lint --- .../dynamo/conversion/aten_ops_converters.py | 8 ++++++-- tests/py/dynamo/conversion/test_convolution_aten.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index c3553c0f65..9ca084f276 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -629,10 +629,14 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]), supports_dynamic_shapes=True, + torch.ops.aten.split.Tensor, + capability_validator=has_static_shapes_in_args([1]), + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]), supports_dynamic_shapes=True, + torch.ops.aten.split.sizes, + capability_validator=has_static_shapes_in_args([1]), + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( torch.ops.aten.split_with_sizes.default, diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py index 7d69c871a9..1f87400e43 100644 --- a/tests/py/dynamo/conversion/test_convolution_aten.py +++ b/tests/py/dynamo/conversion/test_convolution_aten.py @@ -1,7 +1,6 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests - from torch_tensorrt import Input from .harness import DispatchTestCase From d800e0ffd29a0f51f50438b433a0be49176fbd0e Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 30 May 2024 15:03:06 -0700 Subject: [PATCH 5/5] Refactor the code --- tests/py/dynamo/conversion/test_split_aten.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/py/dynamo/conversion/test_split_aten.py b/tests/py/dynamo/conversion/test_split_aten.py index aa28d8f813..aa26340452 100644 --- a/tests/py/dynamo/conversion/test_split_aten.py +++ b/tests/py/dynamo/conversion/test_split_aten.py @@ -137,7 +137,6 @@ def forward(self, input): min_shape=[1, 10, 1], opt_shape=[1, 10, 10], max_shape=[1, 10, 10], - name = "input", ), ] self.run_test_with_dynamic_shape( @@ -165,7 +164,6 @@ def forward(self, input): min_shape=[1, 10, 10], opt_shape=[3, 10, 10], max_shape=[5, 10, 10], - name = "input", ), ] self.run_test_with_dynamic_shape( @@ -173,7 +171,6 @@ def forward(self, input): input_specs, ) - @parameterized.expand( [ ("select_chunk_dim", 6, 0),