Commit 8f6805a

chore: minor naming issues
1 parent 4604caf commit 8f6805a

File tree

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
tests/py/dynamo/conversion/test_resize_aten.py

3 files changed: +13 -25 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 2 deletions
@@ -2671,7 +2671,6 @@ def aten_ops_pixel_unshuffle(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.resize.default)
 @dynamo_tensorrt_converter(torch.ops.aten.resize_.default)
 @enforce_tensor_types(
     {
@@ -2685,7 +2684,7 @@ def aten_ops_resize(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.shuffle.resize_(
+    return impl.shuffle.resize(
         ctx,
         target,
         SourceIR.ATEN,
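After this change the converter is registered only for torch.ops.aten.resize_.default. For orientation, the decorator pattern behind @dynamo_tensorrt_converter is a lookup table populated at import time; the sketch below is a minimal, self-contained illustration of that pattern, not Torch-TensorRT's actual internals (CONVERTERS and converter_for are hypothetical names):

from typing import Any, Callable, Dict

# Hypothetical registry illustrating the stacked-decorator pattern seen above;
# not Torch-TensorRT's actual implementation.
CONVERTERS: Dict[str, Callable[..., Any]] = {}

def converter_for(op_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def register(fn: Callable[..., Any]) -> Callable[..., Any]:
        CONVERTERS[op_name] = fn
        return fn  # returned unchanged, so decorators stack cleanly
    return register

@converter_for("torch.ops.aten.resize_.default")
def aten_ops_resize(*args: Any, **kwargs: Any) -> Any:
    ...

assert "torch.ops.aten.resize_.default" in CONVERTERS
assert "torch.ops.aten.resize.default" not in CONVERTERS  # no longer registered

Because register returns the function unchanged, stacking several such decorators, as this converter previously did for both resize.default and resize_.default, registers one implementation under multiple ops.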

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 12 additions & 11 deletions
@@ -140,17 +140,15 @@ def pixel_unshuffle(
     )
 
 
-def resize_(
+def resize(
     ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
     sizes: Sequence[int],
 ) -> TRTTensor:
-
     input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY)
-
     input_val = get_trt_tensor(ctx, input, f"{name}_input")
 
     # Calculate the total number of elements for new and current shape
@@ -166,31 +164,34 @@ def resize_(
 
         # Flatten input tensor to 1D for concatenation
         flatten_shape = flatten_dims(input_val, 0, -1)
-        flattened_input = impl.shuffle.reshape(
+        flattened_input = reshape(
             ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
         )
 
         # Concatenate the flattened input tensor and padding tensor
-        concat_layer = ctx.net.add_concatenation([flattened_input, padding_tensor])
-        concat_layer.axis = 0
-        reshaped_tensor = concat_layer.get_output(0)
-
+        reshaped_tensor = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_cat",
+            [flattened_input, padding_tensor],
+            dim=0,
+        )
     elif new_num_elements < current_num_elements:
         # Flatten input tensor to 1D for slicing
         flatten_shape = flatten_dims(input_val, 0, -1)
-        flattened_input = impl.shuffle.reshape(
+        flattened_input = reshape(
             ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
         )
 
         # Slice the flattened input tensor to the desired number of elements
         slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1])
         reshaped_tensor = slice_layer.get_output(0)
-
     else:
         reshaped_tensor = input_val
 
     # Reshape the final output tensor to the target sizes
-    resized_output = impl.shuffle.reshape(
+    resized_output = reshape(
         ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes
     )
 
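Behavior is unchanged by the rename and the cat refactor: resize still flattens the input, grows it by concatenating a padding tensor along dim 0 (now through impl.cat.cat rather than a hand-built ctx.net.add_concatenation layer), shrinks it with a slice layer, and reshapes the result to the target sizes. As a reading aid, here is a rough eager-mode PyTorch analogue of that control flow; resize_like is a hypothetical helper, and zero padding is an assumption based on the padding_tensor name:

import math
from typing import Sequence

import torch

def resize_like(input: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
    # Flatten to 1D, mirroring flatten_dims(input_val, 0, -1)
    flat = input.reshape(-1)
    new_num_elements = math.prod(sizes)
    current_num_elements = flat.numel()
    if new_num_elements > current_num_elements:
        # Grow: concatenate padding along dim 0 (the impl.cat.cat branch)
        padding = torch.zeros(new_num_elements - current_num_elements, dtype=input.dtype)
        flat = torch.cat([flat, padding], dim=0)
    elif new_num_elements < current_num_elements:
        # Shrink: keep the leading elements (the ctx.net.add_slice branch)
        flat = flat[:new_num_elements]
    # Equal element counts pass through unchanged, then the final reshape
    return flat.reshape(tuple(sizes))

# e.g. resize_like(torch.arange(6.0), (2, 4)) pads two zeros, then reshapes to 2x4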

tests/py/dynamo/conversion/test_resize_aten.py

Lines changed: 0 additions & 12 deletions
@@ -20,9 +20,6 @@ class TestResizeConverter(DispatchTestCase):
     )
     def test_resize_1d_input_float(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)
 
@@ -46,9 +43,6 @@ def forward(self, x):
     )
     def test_resize_1d_input_int(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)
 
@@ -73,9 +67,6 @@ def forward(self, x):
     )
     def test_resize_2d_input_float(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)
 
@@ -100,9 +91,6 @@ def forward(self, x):
     )
     def test_resize_2d_input_int(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)
 
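The constructors deleted from the test modules were pure boilerplate: a torch.nn.Module subclass that registers no parameters or submodules can rely on the inherited nn.Module.__init__, so each Resize module reduces to its forward method. A minimal runnable sketch of the pattern (the fixed [4, 4] target shape is an arbitrary example, not one of the parameterized values from the tests):

import torch

class Resize(torch.nn.Module):
    # No explicit __init__ needed: nn.Module.__init__ runs through normal
    # inheritance, which suffices because this module holds no state.
    def forward(self, x):
        return torch.ops.aten.resize_.default(x, [4, 4])

print(Resize()(torch.zeros(2, 2)).shape)  # torch.Size([4, 4])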
