Skip to content

Commit f97b696

Browse files
committed
Review comments- adding cases for stride, correcting validator and changing call to torch.ops.aten.empty.memory_format
1 parent 973e4ae commit f97b696

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tests/py/dynamo/conversion/test_empty_aten.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import torch
33
import torch.nn as nn
44
import torch_tensorrt
5-
from .harness import DispatchTestCase
65
from parameterized import parameterized
76
from torch.testing._internal.common_utils import run_tests
87

8+
from .harness import DispatchTestCase
9+
910
empty_ops = [
1011
(
1112
"empty_one_dimension",
@@ -75,7 +76,10 @@
7576

7677
class TestRandConverter(DispatchTestCase):
7778
@parameterized.expand(
78-
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3], empty_op[4]) for empty_op in empty_ops]
79+
[
80+
(empty_op[0], empty_op[1], empty_op[2], empty_op[3], empty_op[4])
81+
for empty_op in empty_ops
82+
]
7983
)
8084
def test_empty(self, name, shape_or_input, data_type, device, memory_format):
8185
class TestModule(nn.Module):
@@ -84,7 +88,9 @@ def __init__(self):
8488

8589
def forward(self, x):
8690
shape_or_input[0] = x.shape[0]
87-
return torch.ops.aten.empty.memory_format(shape_or_input, dtype = data_type, memory_format = layout, device = device)
91+
return torch.ops.aten.empty.memory_format(
92+
shape_or_input, dtype=data_type, memory_format=layout, device=device
93+
)
8894

8995
empty_model = TestModule()
9096

0 commit comments

Comments
 (0)