Skip to content

Commit 29e1380

Browse files
peri044 and zewenli98 committed
chore: Upgrade TensorRT version to TRT 10 EA (#2699)
Co-authored-by: Evan Li <[email protected]>
1 parent 30f5094 commit 29e1380

File tree

7 files changed

+25
-36
lines changed

7 files changed

+25
-36
lines changed

.github/workflows/build-test.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,7 @@ jobs:
264264
pre-script: ${{ matrix.pre-script }}
265265
script: |
266266
export USE_HOST_DEPS=1
267+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
267268
pushd .
268269
cd tests/py/core
269270
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
116116
These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
117117

118118
- Bazel 5.2.0
119-
- Libtorch 2.4.0.dev (latest nightly) (built with CUDA 12.1)
119+
- Libtorch 2.3.0 (built with CUDA 12.1)
120120
- CUDA 12.1
121121
- cuDNN 8.9.5
122122
- TensorRT 10.0.0.6

py/torch_tensorrt/_enums.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def _from(
107107
return dtype.f16
108108
elif t == trt.float32:
109109
return dtype.f32
110-
elif trt.__version__ >= "7.0" and t == trt.bool:
110+
elif t == trt.bool:
111111
return dtype.b
112112
else:
113113
raise TypeError(

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,7 @@ def run(
313313
)
314314
timing_cache = self._create_timing_cache(builder_config, existing_cache)
315315

316-
engine = self.builder.build_engine(self.ctx.net, builder_config)
316+
engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
317317
assert engine
318318

319319
serialized_cache = (

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from torch_tensorrt.dynamo._SourceIR import SourceIR
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
12+
cast_trt_tensor,
1213
get_positive_dim,
1314
get_trt_tensor,
1415
)
@@ -38,6 +39,12 @@ def shape(
3839
"""
3940
shape_layer = ctx.net.add_shape(input_val)
4041
input_shape = shape_layer.get_output(0)
42+
input_shape = cast_trt_tensor(
43+
ctx,
44+
input_shape,
45+
trt.int32,
46+
name + "_shape_casted",
47+
)
4148
set_layer_name(shape_layer, target, name + "_shape", source_ir)
4249

4350
n_dims = len(input_val.shape)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 12 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
_select_rt_device,
1616
multi_gpu_device_check,
1717
)
18+
from torch_tensorrt.logging import TRT_LOGGER
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -64,35 +65,19 @@ def _initialize(self) -> None:
6465
) == (len(self.input_names) + len(self.output_names))
6566

6667
self.input_dtypes = [
67-
dtype._from(self.engine.get_binding_dtype(idx))
68-
for idx in self.input_binding_indices_in_order
68+
dtype._from(self.engine.get_tensor_dtype(input_name))
69+
for input_name in self.input_names
6970
]
7071
self.input_shapes = [
7172
self.engine.get_tensor_shape(input_name) for input_name in self.input_names
7273
]
7374
self.output_dtypes = [
74-
dtype._from(self.engine.get_binding_dtype(idx))
75-
for idx in self.output_binding_indices_in_order
75+
dtype._from(self.engine.get_tensor_dtype(output_name))
76+
for output_name in self.output_names
7677
]
7778
self.output_shapes = [
78-
(
79-
tuple(self.engine.get_binding_shape(idx))
80-
if self.engine.has_implicit_batch_dimension
81-
else tuple()
82-
)
83-
for idx in self.output_binding_indices_in_order
84-
]
85-
self.hidden_output_dtypes = [
86-
dtype._from(self.engine.get_binding_dtype(idx))
87-
for idx in self.hidden_output_binding_indices_in_order
88-
]
89-
self.hidden_output_shapes = [
90-
(
91-
tuple(self.engine.get_binding_shape(idx))
92-
if self.engine.has_implicit_batch_dimension
93-
else tuple()
94-
)
95-
for idx in self.hidden_output_binding_indices_in_order
79+
self.engine.get_tensor_shape(output_name)
80+
for output_name in self.output_names
9681
]
9782

9883
def _check_initialized(self) -> None:
@@ -234,15 +219,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
234219
bindings.append(output.data_ptr())
235220
outputs.append(output)
236221

237-
for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
238-
shape = tuple(self.context.get_binding_shape(idx))
239-
240-
output = torch.empty(
241-
size=shape,
242-
dtype=self.hidden_output_dtypes[i].to(torch.dtype),
243-
device=torch.cuda.current_device(),
244-
)
245-
bindings[idx] = output.data_ptr()
222+
# Assign tensor address appropriately
223+
for idx in range(self.engine.num_io_tensors):
224+
self.context.set_tensor_address(
225+
self.engine.get_tensor_name(idx), bindings[idx]
226+
)
246227

247228
with (
248229
torch.autograd.profiler.record_function(

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ requires = [
88
"cffi>=1.15.1",
99
"typing-extensions>=4.7.0",
1010
"future>=0.18.3",
11-
"tensorrt>=8.6,<8.7",
11+
"tensorrt",
1212
"torch==2.3.0",
1313
"pybind11==2.6.2",
1414
"numpy",
@@ -42,7 +42,7 @@ requires-python = ">=3.8"
4242
keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"]
4343
dependencies = [
4444
"torch==2.3.0",
45-
"tensorrt>=8.6,<8.7",
45+
"tensorrt",
4646
"packaging>=23",
4747
"numpy",
4848
"typing-extensions>=4.7.0",

0 commit comments

Comments
 (0)