Skip to content

Commit 96520a7

Browse files
authored
chore: cherry pick of #2709 (#2850)
1 parent ade63e7 commit 96520a7

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

core/runtime/execute_engine.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
152152
std::vector<int32_t> inputs_cpu_vec(
153153
input_cpu.data_ptr<int32_t>(), input_cpu.data_ptr<int32_t>() + input_cpu.numel());
154154
inputShapeTensorValues.emplace_back(inputs_cpu_vec);
155-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data());
155+
TORCHTRT_CHECK(
156+
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
157+
"Error while setting the tensor address for shape inputs");
156158
} else {
157-
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
158-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
159+
TORCHTRT_CHECK(
160+
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
161+
TORCHTRT_CHECK(
162+
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr()),
163+
"Error while setting the input tensor address for inputs");
159164
}
160165
}
161166

@@ -188,7 +193,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
188193
auto dims = core::util::toVec(out_shape);
189194
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
190195
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
191-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr());
196+
TORCHTRT_CHECK(
197+
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
198+
"Error while setting the output tensor address");
192199
}
193200
}
194201

0 commit comments

Comments
 (0)