Plugin TensorRT EP using ORT EP ABI #527
Conversation
…e multiple GPU devices
…rs built from ORT repo
…' in CMake for Windows debug build
…dencies if it's release build
do we need all of these helper files? this one doesn't seem to be compiled, with the suffix ".ccc".
thanks for catching that, i removed them.
: severity == Severity::kWARNING ? "WARNING"
: severity == Severity::kINFO ? " INFO"
: "UNKNOWN");
if (severity <= Severity::kERROR) {
would be good to actually log something
Added the ORT default logger for the TRT logger to print/log messages.
Will also add back the default logger for the plugin TRT EP.
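For context, a minimal sketch of what forwarding TRT log messages to an ORT logger could look like (an illustration, not the PR's actual code; it assumes the EP keeps references to the OrtApi and an OrtLogger obtained from ORT, and uses the existing OrtApi::Logger_LogMessage call plus the ORT_FILE macro from the C API header):

#include <NvInfer.h>
#include "onnxruntime_c_api.h"

// Illustrative TRT logger that forwards messages to an ORT logger.
class TensorrtToOrtLogger : public nvinfer1::ILogger {
 public:
  TensorrtToOrtLogger(const OrtApi& ort_api, const OrtLogger& ort_logger)
      : ort_api_(ort_api), ort_logger_(ort_logger) {}

  void log(Severity severity, const char* msg) noexcept override {
    OrtLoggingLevel level = severity <= Severity::kERROR     ? ORT_LOGGING_LEVEL_ERROR
                            : severity == Severity::kWARNING ? ORT_LOGGING_LEVEL_WARNING
                            : severity == Severity::kINFO    ? ORT_LOGGING_LEVEL_INFO
                                                             : ORT_LOGGING_LEVEL_VERBOSE;
    // Logging failures should not abort the EP, so the returned status is released.
    OrtStatus* status = ort_api_.Logger_LogMessage(&ort_logger_, level, msg,
                                                   ORT_FILE, __LINE__, __FUNCTION__);
    if (status != nullptr) {
      ort_api_.ReleaseStatus(status);
    }
  }

 private:
  const OrtApi& ort_api_;
  const OrtLogger& ort_logger_;
};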
general comment: can we put all the code that doesn't need to be in the global namespace into a top-level namespace? maybe trt_ep or something. there is some existing code in onnxruntime but we probably should change that too.
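For illustration only, the suggested structure would look roughly like this (trt_ep is just the placeholder name from the comment above, and the helper shown is hypothetical):

// trt_ep_utils.h (hypothetical header)
namespace trt_ep {

// Helpers, the TRT logger, CUDA utilities, etc. live here instead of the
// global namespace, so the plugin's symbols don't collide with the host app.
int GetNumCudaDevices();

}  // namespace trt_ep

// Callers then qualify the name explicitly:
// int device_count = trt_ep::GetNumCudaDevices();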
// char hostname[HOST_NAME_MAX];
// if (gethostname(hostname, HOST_NAME_MAX) != 0)
//   strcpy(hostname, "?");
// #endif
general: there seems to be quite a lot of commented out code in this PR. it's not ideal because it can easily get out of date. can we avoid adding commented out code?
@@ -0,0 +1,161 @@
# usage:
# cd build/
# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 -DCMAKE_POSITION_INDEPENDENT_CODE=ON (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits")
nit: perhaps should replace lochi with a generic user or something like it
Could it be put in the c_cxx folder along with other C/C++ examples?
/*
std::vector<const OrtOpAttr*> node_attributes(num_node_attributes);
RETURN_IF_ERROR(ort_api.Node_GetAttributes(node, node_attributes.data(), node_attributes.size()));
*/
nit: not needed anymore?
auto node = nodes[0];

size_t num_node_attributes = 0;
looks like this is not used.
const OrtOpAttr* node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "embed_mode", &node_attr));
const int64_t embed_mode = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(node_attr)->i();
Since this EP is largely an example of how to develop an EP, should we try to use the public C APIs to get the attribute values (i.e., ReadOpAttr) when possible? I think we want to show that an EP doesn't necessarily have to build with onnx to use these APIs.
Perhaps this wasn't done initially because the C API is cumbersome. But now that we have the C++ ORT APIs, getting the attribute values should hopefully be a one-liner.
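For reference, a rough sketch of what that could look like with the existing OrtApi::ReadOpAttr call (reusing the ort_api, node, and RETURN_IF_ERROR helpers from the snippet above; an illustration, not the PR's code):

const OrtOpAttr* node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "embed_mode", &node_attr));

// Read the INT attribute through the public API instead of casting the opaque
// handle to ONNX_NAMESPACE::AttributeProto.
int64_t embed_mode = 0;
size_t bytes_read = 0;
RETURN_IF_ERROR(ort_api.ReadOpAttr(node_attr, ORT_OP_ATTR_INT, &embed_mode,
                                   sizeof(embed_mode), &bytes_read));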
// Get engine from byte stream.
node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr));
const std::string& context_binary = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(node_attr)->s();
Same here. Could potentially use the C++ ORT API to get attr value?
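A sketch of the same idea for the string-valued ep_cache_context attribute: ReadOpAttr is typically called twice, once to query the required buffer size and once to copy the data (this is illustrative; the exact size/terminator semantics should be double-checked against the OrtApi documentation):

const OrtOpAttr* node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr));

// First call with an empty buffer to learn the required size; the returned
// "buffer too small" status is expected and released.
size_t required_size = 0;
OrtStatus* size_status = ort_api.ReadOpAttr(node_attr, ORT_OP_ATTR_STRING, nullptr, 0, &required_size);
if (size_status != nullptr) {
  ort_api.ReleaseStatus(size_status);
}

// Second call copies the attribute bytes into the buffer.
std::string context_binary(required_size, '\0');
RETURN_IF_ERROR(ort_api.ReadOpAttr(node_attr, ORT_OP_ATTR_STRING, context_binary.data(),
                                   context_binary.size(), &required_size));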
} else {
  output_tensors[i] = ctx.GetOutput(output_index, output_shapes);
  auto& output_tensor = output_tensors[i];
  const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount();
C++ API functions like this one can throw exceptions. Are these exceptions caught/handled somewhere in the EP (and maybe converted to a Status that can be returned to ORT)?
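One common pattern (sketched here as an illustration, not the PR's actual implementation) is to wrap the compute body in a try/catch and convert exceptions into an OrtStatus* via OrtApi::CreateStatus before returning across the ABI boundary:

#include "onnxruntime_cxx_api.h"

// Hypothetical outline of a compute callback body that never lets C++ API
// exceptions escape across the C ABI.
OrtStatus* ComputeSafely(const OrtApi& ort_api, Ort::KernelContext& ctx) noexcept {
  try {
    // ... ctx.GetOutput(...), GetTensorTypeAndShapeInfo(), etc. may throw ...
    return nullptr;  // success
  } catch (const Ort::Exception& ex) {
    return ort_api.CreateStatus(ex.GetOrtErrorCode(), ex.what());
  } catch (const std::exception& ex) {
    return ort_api.CreateStatus(ORT_EP_FAIL, ex.what());
  } catch (...) {
    return ort_api.CreateStatus(ORT_EP_FAIL, "Unknown exception in TensorRT EP compute");
  }
}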
}

// Create (optional) fusion options for the supported nodes to fuse.
OrtNodeFusionOptions node_fusion_options = {};
At runtime, does TRT need access to the original initializers in the onnx model? If TRT copies the weights into its compiled binary, then there's an opportunity to allow ORT to free the weights if they are not used by another EP.
  // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider";
} else if (number_of_trt_nodes == nodes.size()) {
  // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
} else {
  // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
nit: should the log statements be uncommented? (or maybe remove the if statements).
@chilo-ms Can you please also guide the changes needed to the wheel creation for using the Python APIs independently with the ORT TRT EP (or any standalone custom EP code), against the latest API/ABI interfaces offered by ORT Core (available from ORT version 1.23.0)? Since we have decoupled the TRT EP from the ORT source code, we can no longer access & compile the file below -
You don't need to make any changes for creating the ORT GPU wheel. Here is the reference code:

import onnxruntime as onnxrt
import numpy as np

ep_lib_path = "C:\\path\\to\\plugin_trt_ep\\TensorRTEp.dll"
ep_name = "TensorRTEp"
ep_registration_name = ep_name

onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path)

ep_devices = onnxrt.get_ep_devices()
trt_ep_device = None
for ep_device in ep_devices:
    if ep_device.ep_name == ep_name:
        trt_ep_device = ep_device
assert trt_ep_device is not None

sess_options = onnxrt.SessionOptions()
sess_options.add_provider_for_devices([trt_ep_device], {'trt_engine_cache_enable': '1'})
assert sess_options.has_providers()

# Run sample model and check output
sess = onnxrt.InferenceSession("C:\\models\\mul_1.onnx", sess_options=sess_options)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
input_name = sess.get_inputs()[0].name
res = sess.run([], {input_name: x})
output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)

onnxrt.unregister_execution_provider_library(ep_registration_name)
Description
This plugin TRT EP is migrated from the original TRT EP and provides the implementations of OrtEpFactory, OrtEp, OrtNodeComputeInfo, OrtDataTransferImpl, ... that are required for a plugin EP to interact with ONNX Runtime via the EP ABI (introduced in ORT 1.23.0).

A plugin EP should be built independently, without the ORT source code, as it relies only on the API/ABI provided by ORT. Therefore, it should reside in a separate repository outside the main ORT repository.
This plugin TRT EP can be built on Linux and Windows in "Debug" and "Release" modes.
Build plugin TRT EP on Windows:
(Note: ORT_HOME should contain the include and lib folders, as shown below)
Build plugin TRT EP on Linux:
Run the plugin TRT EP:
Please use onnxruntime_perf_test or onnx_test_runner.
TODO
- Currently GetCapability assumes the whole graph is TRT eligible. Will have another PR to add the TRT parser call for partitioning.
- Add simple unit test