     _select_rt_device,
     multi_gpu_device_check,
 )
+from torch_tensorrt.logging import TRT_LOGGER
 
 logger = logging.getLogger(__name__)
 
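Side note on the new import: TRT_LOGGER is presumably the tensorrt.ILogger instance that torch_tensorrt.logging exposes, and the usual place a runtime module needs one is engine deserialization. A minimal, illustrative sketch (the plan-file path and local variable names are hypothetical, not taken from this change):

import tensorrt as trt
from torch_tensorrt.logging import TRT_LOGGER

# A TensorRT Runtime requires an ILogger; TRT_LOGGER is assumed to satisfy that interface.
runtime = trt.Runtime(TRT_LOGGER)
with open("engine.plan", "rb") as f:  # hypothetical path, for illustration only
    engine = runtime.deserialize_cuda_engine(f.read())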
@@ -64,35 +65,19 @@ def _initialize(self) -> None:
         ) == (len(self.input_names) + len(self.output_names))
 
         self.input_dtypes = [
-            dtype._from(self.engine.get_binding_dtype(idx))
-            for idx in self.input_binding_indices_in_order
+            dtype._from(self.engine.get_tensor_dtype(input_name))
+            for input_name in self.input_names
         ]
         self.input_shapes = [
             self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
         self.output_dtypes = [
-            dtype._from(self.engine.get_binding_dtype(idx))
-            for idx in self.output_binding_indices_in_order
+            dtype._from(self.engine.get_tensor_dtype(output_name))
+            for output_name in self.output_names
         ]
         self.output_shapes = [
-            (
-                tuple(self.engine.get_binding_shape(idx))
-                if self.engine.has_implicit_batch_dimension
-                else tuple()
-            )
-            for idx in self.output_binding_indices_in_order
-        ]
-        self.hidden_output_dtypes = [
-            dtype._from(self.engine.get_binding_dtype(idx))
-            for idx in self.hidden_output_binding_indices_in_order
-        ]
-        self.hidden_output_shapes = [
-            (
-                tuple(self.engine.get_binding_shape(idx))
-                if self.engine.has_implicit_batch_dimension
-                else tuple()
-            )
-            for idx in self.hidden_output_binding_indices_in_order
+            self.engine.get_tensor_shape(output_name)
+            for output_name in self.output_names
         ]
 
     def _check_initialized(self) -> None:
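For context, this hunk swaps the deprecated binding-index queries (get_binding_dtype, get_binding_shape, and the *_binding_indices_in_order bookkeeping) for TensorRT's name-based tensor API, which is the only form left in TensorRT 10. A minimal standalone sketch of that API, assuming engine is an already-deserialized tensorrt.ICudaEngine (the local variable names are illustrative, not the module's attributes):

import tensorrt as trt

# Partition I/O tensor names by mode, then query metadata by name rather than by binding index.
input_names, output_names = [], []
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
        input_names.append(name)
    else:
        output_names.append(name)

input_dtypes = [engine.get_tensor_dtype(n) for n in input_names]   # trt.DataType values
input_shapes = [tuple(engine.get_tensor_shape(n)) for n in input_names]
output_dtypes = [engine.get_tensor_dtype(n) for n in output_names]
output_shapes = [tuple(engine.get_tensor_shape(n)) for n in output_names]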
@@ -234,15 +219,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     bindings.append(output.data_ptr())
                     outputs.append(output)
 
-                for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                    shape = tuple(self.context.get_binding_shape(idx))
-
-                    output = torch.empty(
-                        size=shape,
-                        dtype=self.hidden_output_dtypes[i].to(torch.dtype),
-                        device=torch.cuda.current_device(),
-                    )
-                    bindings[idx] = output.data_ptr()
+                # Assign tensor address appropriately
+                for idx in range(self.engine.num_io_tensors):
+                    self.context.set_tensor_address(
+                        self.engine.get_tensor_name(idx), bindings[idx]
+                    )
 
             with (
                 torch.autograd.profiler.record_function(
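For context, the forward path now registers each I/O tensor's device address on the execution context by name via set_tensor_address, rather than relying on the position of a pointer in a bindings list. A minimal sketch of how that name-based execution flow is typically driven, assuming engine is a deserialized tensorrt.ICudaEngine and tensors maps each I/O tensor name to a pre-allocated CUDA torch.Tensor (both assumptions are for illustration; this is not the module's code):

import tensorrt as trt
import torch

context = engine.create_execution_context()
stream = torch.cuda.current_stream()

# Register every input/output address by tensor name on the execution context.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    context.set_tensor_address(name, tensors[name].data_ptr())

# execute_async_v3 consumes the addresses registered above instead of a positional bindings list.
context.execute_async_v3(stream_handle=stream.cuda_stream)

Addressing tensors by name decouples enqueueing from binding order and removes the need to track binding indices at all.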