
Commit ed97b61

empty_memory_format evaluator
1 parent 7d30714 commit ed97b61

File tree

2 files changed: +211 -0 lines changed
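For context, torch.empty lowers to the aten.empty.memory_format overload in the traced ATen graph, which is the target this commit registers an evaluator for. A minimal, hedged sketch of how to see that overload appear (the module and input here are illustrative, not taken from this commit):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            # torch.empty surfaces as torch.ops.aten.empty.memory_format in ATen
            return torch.empty(2, 3, dtype=torch.float32).zero_() + x

    ep = torch.export.export(M(), (torch.ones(2, 3),))
    print(ep.graph)  # the printed graph should contain aten.empty.memory_format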

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 85 additions & 0 deletions
@@ -11,6 +11,7 @@
     dynamo_tensorrt_converter,
 )
 from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -47,3 +48,87 @@ def aten_ops_arange_start_step(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return np.arange(*args)
+
+
+def empty_validator(empty_node: Node) -> bool:
+    layout = empty_node.kwargs.get("layout", None)
+    pin_memory = empty_node.kwargs.get("pin_memory", None)
+    memory_format = empty_node.kwargs.get("memory_format", None)
+    if layout is not None:
+        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
+        return False
+    if pin_memory is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying pin_memory, got {pin_memory}."
+        )
+        return False
+    if memory_format is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying memory_format, got {memory_format}."
+        )
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.empty.memory_format, capability_validator=empty_validator
+)
+def aten_ops_empty(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if kwargs.get("device") is not None:
+        # np.ndarray has no .to(); build the array, wrap it, then move it
+        return torch.from_numpy(
+            np.empty(
+                tuple(args[0]),
+                dtype=unified_dtype_converter(kwargs.get("dtype"), Frameworks.NUMPY),
+            )
+        ).to(device=kwargs.get("device"))
+    return np.empty(
+        tuple(args[0]),
+        dtype=unified_dtype_converter(kwargs.get("dtype"), Frameworks.NUMPY),
+    )
+
+
+def empty_validator(empty_node: Node) -> bool:
+    layout = empty_node.kwargs.get("layout", None)
+    if layout is not None:
+        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.empty.memory_format, capability_validator=empty_validator
+)
+def aten_ops_empty(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    empty_np_tensor = None
+    memory_format = kwargs.get("memory_format")
+    if kwargs.get("dtype") is not None:
+        empty_np_tensor = np.empty(
+            tuple(args[0]),
+            dtype=unified_dtype_converter(kwargs.get("dtype"), Frameworks.NUMPY),
+        )
+    else:
+        # default returns np.float64. Verify the correctness of this
+        empty_np_tensor = np.empty(tuple(args[0]))
+
+    empty_tensor = torch.Tensor(empty_np_tensor)
+    # device
+    if kwargs.get("device") is not None:
+        empty_tensor = empty_tensor.to(device=kwargs.get("device"))
+
+    # memory_format. default is torch.contiguous_format
+    if memory_format == torch.channels_last:
+        # args[0] must describe a 4-D shape
+        empty_tensor = empty_tensor.to(memory_format=torch.channels_last)
+    elif memory_format == torch.channels_last_3d:
+        # args[0] must describe a 5-D shape
+        empty_tensor = empty_tensor.to(memory_format=torch.channels_last_3d)
+
+    return empty_tensor
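Note that this hunk defines empty_validator and aten_ops_empty twice; at module scope the second pair of definitions is the one that takes effect. That second converter evaluates the op on the host rather than building TensorRT layers: it allocates an uninitialized NumPy buffer, wraps it as a torch.Tensor, and then applies the optional device and memory_format arguments. A minimal standalone sketch of that path (the shape, dtype, and format below are illustrative, not part of the converter API):

    import numpy as np
    import torch

    shape = (1, 2, 2, 1)  # channels_last requires a 4-D shape
    t = torch.Tensor(np.empty(shape, dtype=np.float32))
    t = t.to(device="cuda") if torch.cuda.is_available() else t
    t = t.to(memory_format=torch.channels_last)
    print(t.shape, t.is_contiguous(memory_format=torch.channels_last))  # True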
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from harness import DispatchTestCase
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+empty_ops = [
+    (
+        "empty_one_dimension",
+        [1],
+        None,
+        None,
+        None,
+    ),
+    (
+        "empty_two_dimension",
+        [1, 2],
+        None,
+        None,
+        None,
+    ),
+    (
+        "empty_three_dimension",
+        [2, 3, 4],
+        None,
+        None,
+        None,
+    ),
+    (
+        "empty_one_dimension_dtype",
+        [1],
+        torch.float32,
+        None,
+        None,
+    ),
+    (
+        "empty_two_dimension_dtype",
+        [2, 3],
+        torch.float32,
+        None,
+        None,
+    ),
+    (
+        "empty_one_dimension_dtype_device",
+        [1],
+        torch.float32,
+        "cuda",
+        None,
+    ),
+    (
+        "empty_two_dimension_dtype_device",
+        [2, 3],
+        torch.float32,
+        "cuda",
+        None,
+    ),
+    (
+        "empty_four_dimension_memformat",
+        [1, 2, 2, 1],
+        torch.float32,
+        "cuda",
+        torch.channels_last,
+    ),
+    (
+        "empty_five_dimension_memformat",
+        [1, 2, 2, 2, 1],
+        torch.float32,
+        "cuda",
+        torch.channels_last_3d,
+    ),
+]
+
+
+class TestEmptyConverter(DispatchTestCase):
+    @parameterized.expand(
+        [(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
+    )
+    def test_empty(self, name, shape_or_input, data_type, device):
+        class TestModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # tie the first dimension to the runtime input's shape
+                shape_or_input[0] = x.shape[0]
+                return torch.empty(shape_or_input)
+
+        empty_model = TestModule()
+
+        inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
+        # always compare shape and stride; compare dtype/device only when requested
+        comparator_shape_dtype_device = (
+            lambda x, y, check_dtype, check_device: x.shape == y.shape
+            and (x.stride() == y.stride())
+            and (x.dtype == y.dtype if check_dtype else True)
+            and (x.get_device() == y.get_device() if check_device else True)
+        )
+        expected_ops = []
+        if "device" in name:
+            self.run_test_compare_tensor_attributes_only(
+                empty_model,
+                inputs,
+                expected_ops,
+                [(comparator_shape_dtype_device, [True, True])],
+                use_dynamo_tracer=True,
+            )
+        elif "dtype" in name:
+            self.run_test_compare_tensor_attributes_only(
+                empty_model,
+                inputs,
+                expected_ops,
+                [(comparator_shape_dtype_device, [True, False])],
+                use_dynamo_tracer=True,
+            )
+        else:
+            self.run_test_compare_tensor_attributes_only(
+                empty_model,
+                inputs,
+                expected_ops,
+                [(comparator_shape_dtype_device, [False, False])],
+                use_dynamo_tracer=True,
+            )
+
+
+if __name__ == "__main__":
+    run_tests()
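The tests compare tensor attributes rather than values, since the contents of torch.empty are uninitialized. A quick standalone check of the comparator logic used above (the tensors here are arbitrary; no CUDA is needed when check_device is False):

    import torch

    comparator = (
        lambda x, y, check_dtype, check_device: x.shape == y.shape
        and (x.stride() == y.stride())
        and (x.dtype == y.dtype if check_dtype else True)
        and (x.get_device() == y.get_device() if check_device else True)
    )

    a = torch.empty(2, 3, dtype=torch.float32)
    b = torch.empty(2, 3, dtype=torch.float32)
    print(comparator(a, b, True, False))  # True: same shape, stride, and dtype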
