Skip to content

Commit 973e4ae

Browse files
committed
Review comments- adding cases for stride, correcting validator and changing call to torch.ops.aten.empty.memory_format
1 parent fd330cf commit 973e4ae

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def aten_ops_randperm(
123123

124124
def empty_validator(empty_node: Node) -> bool:
125125
layout = empty_node.kwargs.get("layout", None)
126-
pin_memory = empty_node.kwargs.get("pin_memory", None)
127-
memory_format = empty_node.kwargs.get("memory_format", None)
128126
if layout is not None:
129127
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
130128
return False

tests/py/dynamo/conversion/test_empty_aten.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import torch.nn as nn
44
import torch_tensorrt
5-
from harness import DispatchTestCase
5+
from .harness import DispatchTestCase
66
from parameterized import parameterized
77
from torch.testing._internal.common_utils import run_tests
88

@@ -75,16 +75,16 @@
7575

7676
class TestRandConverter(DispatchTestCase):
7777
@parameterized.expand(
78-
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
78+
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3], empty_op[4]) for empty_op in empty_ops]
7979
)
80-
def test_empty(self, name, shape_or_input, data_type, device):
80+
def test_empty(self, name, shape_or_input, data_type, device, memory_format):
8181
class TestModule(nn.Module):
8282
def __init__(self):
8383
super().__init__()
8484

8585
def forward(self, x):
8686
shape_or_input[0] = x.shape[0]
87-
return torch.empty(shape_or_input)
87+
return torch.ops.aten.empty.memory_format(shape_or_input, dtype = data_type, memory_format = layout, device = device)
8888

8989
empty_model = TestModule()
9090

0 commit comments

Comments
 (0)