From 36c0dc17c1ffc9847c90dbd79305360929324bef Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 22 Jun 2025 19:49:41 -0700 Subject: [PATCH 01/60] plugin TRT EP init --- .../tensorrt/CMakeLists.txt | 119 + .../cuda/cu_inc/unary_elementwise_impl.cuh | 78 + .../cuda/unary_elementwise_ops_impl.cu | 93 + .../cuda/unary_elementwise_ops_impl.h | 54 + .../tensorrt/ep_abi_utils.cc | 12 + .../tensorrt/ep_abi_utils.h | 77 + .../tensorrt/nv_includes.h | 19 + .../tensorrt/onnx_ctx_model_helper.ccc | 258 ++ .../tensorrt/onnx_ctx_model_helper.h | 66 + .../tensorrt/ort_trt_int8_cal_table.fbs.h | 144 + .../tensorrt/tensorrt_cuda_allocator.cc | 79 + .../tensorrt/tensorrt_cuda_allocator.h | 82 + .../tensorrt/tensorrt_execution_provider.cc | 3822 +++++++++++++++++ .../tensorrt/tensorrt_execution_provider.h | 406 ++ .../tensorrt_execution_provider_info.cc | 339 ++ .../tensorrt_execution_provider_info.h | 65 + .../tensorrt_execution_provider_utils.h | 397 ++ .../tensorrt/tensorrt_provider_factory.cc | 152 + .../tensorrt/tensorrt_provider_factory.h | 58 + .../tensorrt/utils/code_location.h | 58 + .../tensorrt/utils/common.h | 169 + .../tensorrt/utils/cuda/cuda_call.h | 69 + .../tensorrt/utils/cuda/cuda_common.h | 14 + .../tensorrt/utils/endian.h | 27 + .../tensorrt/utils/exceptions.h | 91 + .../tensorrt/utils/helper.cc | 59 + .../tensorrt/utils/make_string.h | 126 + .../tensorrt/utils/murmurhash3.cc | 349 ++ .../tensorrt/utils/murmurhash3.h | 16 + .../tensorrt/utils/parse_string.h | 85 + .../tensorrt/utils/path_string.h | 70 + .../tensorrt/utils/provider_options.h | 18 + .../tensorrt/utils/provider_options_utils.h | 164 + .../tensorrt/utils/status.cc | 91 + .../tensorrt/utils/status.h | 192 + 35 files changed, 7918 insertions(+) create mode 100644 plugin_execution_providers/tensorrt/CMakeLists.txt create mode 100644 plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh create mode 100644 plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu create mode 100644 plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h create mode 100644 plugin_execution_providers/tensorrt/ep_abi_utils.cc create mode 100644 plugin_execution_providers/tensorrt/ep_abi_utils.h create mode 100644 plugin_execution_providers/tensorrt/nv_includes.h create mode 100644 plugin_execution_providers/tensorrt/onnx_ctx_model_helper.ccc create mode 100644 plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h create mode 100644 plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h create mode 100644 plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc create mode 100644 plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider.h create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h create mode 100644 plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc create mode 100644 plugin_execution_providers/tensorrt/tensorrt_provider_factory.h create mode 100644 plugin_execution_providers/tensorrt/utils/code_location.h create mode 100644 plugin_execution_providers/tensorrt/utils/common.h create mode 100644 plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h create mode 100644 
plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h create mode 100644 plugin_execution_providers/tensorrt/utils/endian.h create mode 100644 plugin_execution_providers/tensorrt/utils/exceptions.h create mode 100644 plugin_execution_providers/tensorrt/utils/helper.cc create mode 100644 plugin_execution_providers/tensorrt/utils/make_string.h create mode 100644 plugin_execution_providers/tensorrt/utils/murmurhash3.cc create mode 100644 plugin_execution_providers/tensorrt/utils/murmurhash3.h create mode 100644 plugin_execution_providers/tensorrt/utils/parse_string.h create mode 100644 plugin_execution_providers/tensorrt/utils/path_string.h create mode 100644 plugin_execution_providers/tensorrt/utils/provider_options.h create mode 100644 plugin_execution_providers/tensorrt/utils/provider_options_utils.h create mode 100644 plugin_execution_providers/tensorrt/utils/status.cc create mode 100644 plugin_execution_providers/tensorrt/utils/status.h diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt new file mode 100644 index 00000000..bf2d7b0c --- /dev/null +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -0,0 +1,119 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DORT_HOME=/home/lochi/repos/ort -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake --build ./ --config Debug +cmake_minimum_required(VERSION 3.26) +project(TensorRTEp VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) +enable_language(CUDA) +file(TO_CMAKE_PATH CUDAToolkit_ROOT "/usr/local/cuda") +find_package(CUDAToolkit REQUIRED) + +add_definitions(-DONNX_NAMESPACE=onnx) +add_definitions(-DONNX_ML) +add_definitions(-DNV_TENSORRT_MAJOR=10) +add_definitions(-DNOMINMAX) +file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu") +add_library(TensorRTEp SHARED ${tensorrt_src}) + +if (NOT ORT_HOME) + message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/") +endif() + +if (NOT TENSORRT_HOME) + message(FATAL_ERROR "Please specify TENSORRT_HOME, e.g. 
-DTENSORRT_HOME=/path/to/trt/") +endif() + +# Use release mode if not specified +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release") +endif() + +# Add dependencies +include(FetchContent) + +# Add GSL +FetchContent_Declare( + gsl + GIT_REPOSITORY https://github.com/microsoft/GSL.git + GIT_TAG v4.0.0 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(gsl) + +# Add flatbuffers +FetchContent_Declare( + flatbuffers + GIT_REPOSITORY https://github.com/google/flatbuffers.git + GIT_TAG v23.5.26 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(flatbuffers) + +if (WIN32) + set(PLATFORM "Windows") + set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/${CMAKE_BUILD_TYPE}/onnxruntime.lib") + set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps") + set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib" + "${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib" + "${TENSORRT_HOME}/lib/nvonnxparser_10.lib") + set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib") + else() + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib") + endif() +else() + set(PLATFORM "Linux") + set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/libonnxruntime.so") + set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps") + set(TRT_LIBS "${TENSORRT_HOME}/lib/libnvinfer.so" + "${TENSORRT_HOME}/lib/libnvinfer_plugin.so" + "${TENSORRT_HOME}/lib/libnvonnxparser.so") + set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/libflatbuffers.a" + "${DEPS_PATH}/onnx-build/libonnx.a" + "${DEPS_PATH}/onnx-build/libonnx_proto.a") + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/libprotobufd.a" + "${DEPS_PATH}/protobuf-build/libprotocd.a") + else() + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/libprotobuf.a" + "${DEPS_PATH}/protobuf-build/libprotoc.a") + endif() +endif() + +MESSAGE(STATUS "Looking for following dependencies ...") +MESSAGE(STATUS "Platform : ${PLATFORM}") +MESSAGE(STATUS "ORT home : ${ORT_HOME}") +MESSAGE(STATUS "ORT lib : ${ORT_LIB}") +MESSAGE(STATUS "Deps path: ${DEPS_PATH}") +MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}") +MESSAGE(STATUS "TRT libs : ${TRT_LIBS}") + +target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include/onnxruntime/core/session/" + "./utils" + "/usr/local/cuda/include" + ${TENSORRT_HOME}/include + "${DEPS_PATH}/flatbuffers-src/include" + "${DEPS_PATH}/gsl-src/include" + "${DEPS_PATH}/onnx-src" + "${DEPS_PATH}/onnx-build" + "${DEPS_PATH}/protobuf-src/src" +) + +target_link_libraries(TensorRTEp PUBLIC ${ORT_LIB} + ${TRT_LIBS} + CUDA::cudart + ${DEPS_LIBS} + GSL + flatbuffers) diff --git a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh new file mode 100644 index 00000000..87cf7c83 --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
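Note: the CMake configuration above links the EP against a prebuilt onnxruntime library plus the TensorRT and CUDA runtimes and produces a standalone shared library (libTensorRTEp.so on Linux, TensorRTEp.dll on Windows). Below is a minimal sketch of loading that library from a host process with the same dlopen/dlsym pattern that tensorrt_execution_provider.cc wraps in its OPENLIB/LIBFUNC macros; the "CreateEpFactory" symbol name and the library path are placeholders for illustration only, the real entry point is whatever tensorrt_provider_factory.cc exports (build with -ldl on Linux).

#include <cstdio>
#include <dlfcn.h>

int main() {
  // Path is an assumption; point this at the library produced by the build above.
  void* lib = dlopen("./libTensorRTEp.so", RTLD_LAZY);
  if (lib == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // "CreateEpFactory" is a hypothetical symbol used only to show the lookup step.
  void* factory_fn = dlsym(lib, "CreateEpFactory");
  std::printf("library loaded, factory symbol %s\n", factory_fn != nullptr ? "found" : "not found");
  dlclose(lib);
  return 0;
}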
+ +#pragma once +#include + +namespace onnxruntime { +namespace cuda { + +// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer +// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. +#ifndef CUDA_LONG +#define CUDA_LONG int32_t +#endif + +template <typename INT, typename INT2> +inline __host__ __device__ INT CeilDiv(INT a, INT2 b) // ceil(a/b) +{ + return (INT)(((size_t)a + (size_t)b - 1) / (size_t)b); // these size_t casts are necessary since b may be INT_MAX (for maxGridSize[]) +} + +struct GridDim { + enum : CUDA_LONG { + maxThreadsPerBlock = 256, // max threads per block + maxElementsPerThread = 4, // max element processed per thread + }; +}; + +template <typename InT, typename OutT, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread> +__global__ void _UnaryElementWise( + const InT* input_data, + OutT* output_data, + const FuncT functor, + CUDA_LONG N) { + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + InT value[NumElementsPerThread]; + + CUDA_LONG id = start; + #pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + value[i] = input_data[id]; + id += NumThreadsPerBlock; + } + } + + id = start; + #pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + output_data[id] = functor(value[i]); + id += NumThreadsPerBlock; + } + } +} + +template <typename InT, typename OutT, typename FuncT> +void UnaryElementWiseImpl( + cudaStream_t stream, + const InT* input_data, + OutT* output_data, + const FuncT& func, + size_t count) { + if (count == 0) // special case where there's a dim value of 0 in the shape + return; + + int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + CUDA_LONG N = static_cast<CUDA_LONG>(count); + _UnaryElementWise<InT, OutT, FuncT, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread> + <<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>( + input_data, + output_data, + func, + N); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu new file mode 100644 index 00000000..ad515a23 --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
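Note: in the kernel above, each block owns NumThreadsPerBlock * NumElementsPerThread consecutive elements: a thread starts at NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x, loads up to NumElementsPerThread values with stride NumThreadsPerBlock, then writes the functor results with the same stride, and UnaryElementWiseImpl launches CeilDiv(count, maxThreadsPerBlock * maxElementsPerThread) blocks. A small host-only C++ sketch of the same index arithmetic, useful for checking that every element in [0, N) is visited exactly once:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int kThreadsPerBlock = 256;  // GridDim::maxThreadsPerBlock
  constexpr int kElementsPerThread = 4;  // GridDim::maxElementsPerThread
  const int32_t N = 1000;                // deliberately not a multiple of 256 * 4
  const int blocks_per_grid =
      (N + kThreadsPerBlock * kElementsPerThread - 1) / (kThreadsPerBlock * kElementsPerThread);

  std::vector<int> visits(N, 0);
  for (int block = 0; block < blocks_per_grid; ++block) {
    for (int thread = 0; thread < kThreadsPerBlock; ++thread) {
      // Same start/stride as _UnaryElementWise.
      int32_t id = kElementsPerThread * kThreadsPerBlock * block + thread;
      for (int i = 0; i < kElementsPerThread; ++i) {
        if (id < N) {
          ++visits[id];
          id += kThreadsPerBlock;
        }
      }
    }
  }
  for (int v : visits) assert(v == 1);  // full coverage, no element touched twice
  return 0;
}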
+ +#include +#include "cu_inc/unary_elementwise_impl.cuh" + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 +#include "cuda_fp8.h" +#endif +#include + +namespace onnxruntime { + +namespace cuda { + +// the postfix of means the types supported by the op: +// B: uint8_t +// W: uint16_t +// U: uint32_t +// Z: uint64_t +// C: int8_t +// S: int16_t +// I: int32_t +// L: int64_t +// H: float16 +// F: float +// D: double +// O: bool +// X: BFloat16 + +// When casting, half needs to be converted via float type from most other types +template +struct ViaTypeMap { + typedef T ViaT; +}; + +template <> +struct ViaTypeMap { + typedef float ViaT; +}; + +template +struct OP_Cast { + __device__ __inline__ OutT operator()(const InT& a) const { + const bool any_float16 = std::is_same::value || std::is_same::value; + typedef typename std::conditional::type T; + typedef typename ViaTypeMap::ViaT ViaT; + return (OutT)((ViaT)a); + } +}; + +#define IMPL_CAST_IMPL(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ + } + +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \ + size_t /*count*/) { \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ + } + +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ + //IMPL_CAST_IMPL(T, BFloat16) + +IMPL_CAST_IMPL_FROM(half) +IMPL_CAST_IMPL_FROM(float) +IMPL_CAST_IMPL_FROM(double) +IMPL_CAST_IMPL_FROM(int8_t) +IMPL_CAST_IMPL_FROM(int16_t) +IMPL_CAST_IMPL_FROM(int32_t) +IMPL_CAST_IMPL_FROM(int64_t) +IMPL_CAST_IMPL_FROM(uint8_t) +IMPL_CAST_IMPL_FROM(uint16_t) +IMPL_CAST_IMPL_FROM(uint32_t) +IMPL_CAST_IMPL_FROM(uint64_t) +IMPL_CAST_IMPL_FROM(bool) +//IMPL_CAST_IMPL_FROM(BFloat16) + +} // namespace cuda +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h new file mode 100644 index 00000000..392cf46f --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
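Note: the IMPL_CAST_IMPL_FROM macros above stamp out Explicit_Impl_Cast for every supported (source, destination) pair, with half routed through float via ViaTypeMap, and the Impl_Cast template declared in the header below simply forwards to the matching overload. A minimal sketch of driving a float-to-half cast on device buffers through that entry point (the include path is the one used by tensorrt_execution_provider.cc; allocation and error handling omitted):

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "cuda/unary_elementwise_ops_impl.h"

// Casts `count` floats already resident on the GPU to half precision on `stream`.
void CastFloatToHalfOnDevice(const float* d_src, half* d_dst, size_t count, cudaStream_t stream) {
  // Resolves to Explicit_Impl_Cast(stream, const float*, half*, size_t), which
  // launches _UnaryElementWise with an OP_Cast functor for this type pair.
  onnxruntime::cuda::Impl_Cast<float, half>(stream, d_src, d_dst, count);
}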
+ +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +// Cast + +#define DECL_IMPL_CAST(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count); + +#define DECL_IMPL_CAST_FROM(T) \ + DECL_IMPL_CAST(T, half) \ + DECL_IMPL_CAST(T, float) \ + DECL_IMPL_CAST(T, double) \ + DECL_IMPL_CAST(T, int8_t) \ + DECL_IMPL_CAST(T, int16_t) \ + DECL_IMPL_CAST(T, int32_t) \ + DECL_IMPL_CAST(T, int64_t) \ + DECL_IMPL_CAST(T, uint8_t) \ + DECL_IMPL_CAST(T, uint16_t) \ + DECL_IMPL_CAST(T, uint32_t) \ + DECL_IMPL_CAST(T, uint64_t) \ + DECL_IMPL_CAST(T, bool) \ + //DECL_IMPL_CAST(T, BFloat16) + +DECL_IMPL_CAST_FROM(half) +DECL_IMPL_CAST_FROM(float) +DECL_IMPL_CAST_FROM(double) +DECL_IMPL_CAST_FROM(int8_t) +DECL_IMPL_CAST_FROM(int16_t) +DECL_IMPL_CAST_FROM(int32_t) +DECL_IMPL_CAST_FROM(int64_t) +DECL_IMPL_CAST_FROM(uint8_t) +DECL_IMPL_CAST_FROM(uint16_t) +DECL_IMPL_CAST_FROM(uint32_t) +DECL_IMPL_CAST_FROM(uint64_t) +DECL_IMPL_CAST_FROM(bool) +//DECL_IMPL_CAST_FROM(BFloat16) + +template +void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { + Explicit_Impl_Cast(stream, input_data, output_data, count); +} + +} // namespace cuda + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/ep_abi_utils.cc b/plugin_execution_providers/tensorrt/ep_abi_utils.cc new file mode 100644 index 00000000..bc0b6eeb --- /dev/null +++ b/plugin_execution_providers/tensorrt/ep_abi_utils.cc @@ -0,0 +1,12 @@ +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include +#include +#include +#include +#include +#include +#include + diff --git a/plugin_execution_providers/tensorrt/ep_abi_utils.h b/plugin_execution_providers/tensorrt/ep_abi_utils.h new file mode 100644 index 00000000..308a49da --- /dev/null +++ b/plugin_execution_providers/tensorrt/ep_abi_utils.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include + +#include "onnxruntime_c_api.h" + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + return status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ + } \ + } while (0) + +#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn, ort_api) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + std::cerr << (ort_api).GetErrorMessage(status) << std::endl ; \ + return false; \ + } \ + } while (0) + +struct OrtArrayOfConstObjects { + OrtArrayOfConstObjects() = default; + explicit OrtArrayOfConstObjects(OrtTypeTag object_type) : object_type(object_type) {} + OrtArrayOfConstObjects(OrtTypeTag object_type, size_t size, const void* initial_val = nullptr) + : object_type(object_type), storage(size, initial_val) {} + + OrtTypeTag object_type = OrtTypeTag::ORT_TYPE_TAG_Void; + std::vector storage; +}; + +// Convert an OrtArrayOfConstObjects into a span of Ort___ pointers. 
+template +static void GetSpanFromArrayOfConstObjects(const OrtArrayOfConstObjects* ort_array, + /*out*/ gsl::span& span) { + const OrtApi& ort_api = Ort::GetApi(); + + size_t size = 0; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetSize(ort_array, &size)); + + const void* const* raw_data = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetData(ort_array, &raw_data)); + + auto data = reinterpret_cast(raw_data); + span = gsl::span(data, size); +} + +// Helper to release a C API Ort object at the end of its scope. +// Useful when not using the public C++ API. +// Example: +// { +// OrtTensorTypeAndShapeInfo* info = nullptr; +// DeferOrtRelease defer_release(&info, c_api.ReleaseTensorTypeAndShapeInfo); +// ... +// } /* Release is called at end of scope*/ +template +struct DeferOrtRelease { + DeferOrtRelease(T** obj_ptr, std::function release_func) : obj_ptr_(obj_ptr), release_func_(release_func) {} + ~DeferOrtRelease() { + if (obj_ptr_ != nullptr && *obj_ptr_ != nullptr) { + release_func_(*obj_ptr_); + *obj_ptr_ = nullptr; + } + } + T** obj_ptr_ = nullptr; + std::function release_func_ = nullptr; +}; \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/nv_includes.h b/plugin_execution_providers/tensorrt/nv_includes.h new file mode 100644 index 00000000..047f325f --- /dev/null +++ b/plugin_execution_providers/tensorrt/nv_includes.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +// File to include the required TRT headers with workarounds for warnings we can't fix or not fixed yet. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) // Ignore warning C4100: unreferenced formal parameter +#pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::IPluginV2' was declared deprecated +#endif + +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.ccc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.ccc new file mode 100644 index 00000000..1b29f626 --- /dev/null +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.ccc @@ -0,0 +1,258 @@ +#include +#include +#include +#include "onnx_ctx_model_helper.h" +#include "tensorrt_execution_provider.h" +#include "path_string.h" + +namespace onnxruntime { + +bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); + int maxNodeIndex = 0; + graph_api->OrtGraph_MaxNodeIndex(graph_viewer, &maxNodeIndex); + for (int i = 0; i < maxNodeIndex; ++i) { + const OrtNode* node = nullptr; + graph_api->OrtGraph_GetOrtNode(graph_viewer, i, &node); + if (node == nullptr) { + continue; + } + const char* opType = nullptr; + graph_api->OrtNode_GetOpType(node, &opType); + if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) { + return true; + } + } + return false; +} + +/* + * Return the directory where the ep context model locates + */ +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) { + if (ep_context_file_path.empty()) { + return std::filesystem::path(); + } + std::filesystem::path ctx_path(ep_context_file_path); + if (std::filesystem::is_directory(ep_context_file_path)) { + return ctx_path; + } else { + return ctx_path.parent_path(); + } +} + +std::string GetCtxModelPath(const std::string& ep_context_file_path, + 
const std::string& original_model_path) { + std::string ctx_model_path; + + if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) { + ctx_model_path = ep_context_file_path; + } else { + std::filesystem::path model_path = original_model_path; + std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name + std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx"; + + if (std::filesystem::is_directory(ep_context_file_path)) { + std::filesystem::path model_directory = ep_context_file_path; + ctx_model_path = model_directory.append(ctx_model_name).string(); + } else { + ctx_model_path = ctx_model_name; + } + } + return ctx_model_path; +} + +bool IsAbsolutePath(const std::string& path_string) { +#ifdef _WIN32 + onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + return path.is_absolute(); +#else + if (!path_string.empty() && path_string[0] == '/') { + return true; + } + return false; +#endif +} + +// Like "../file_path" +bool IsRelativePathToParentPath(const std::string& path_string) { +#ifdef _WIN32 + onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + auto relative_path = path.lexically_normal().make_preferred().wstring(); + if (relative_path.find(L"..", 0) != std::string::npos) { + return true; + } + return false; +#else + if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) { + return true; + } + return false; +#endif +} + +/* + * Get the weight-refitted engine cache path from a weight-stripped engine cache path + * + * Weight-stipped engine: + * An engine with weights stripped and its size is smaller than a regualr engine. + * The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine + * + * Weight-refitted engine: + * An engine that its weights have been refitted and it's simply a regular engine. + * The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine + */ +std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { + std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); + std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; + return refitted_engine_cache_path; +} + +bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { + // The weight-stripped engine cache has the naming of xxx.stripped.engine + return engine_cache_path.stem().extension().string() == ".stripped"; +} + +OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphViewer* graph_viewer) { + if (!ValidateEPCtxNode(graph_viewer)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "It's not a valid EP Context node"); + } + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); + + int64_t embed_mode = -1; + graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); + if (embed_mode) { + // Get engine from byte stream. 
+ const char* context_binary_cstr = nullptr; + size_t size; + graph_api_->OrtNode_GetAttributeStrWithSize(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &size); + std::string context_binary(context_binary_cstr, size); + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), + static_cast(context_binary.length()))); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; + if (!(*trt_engine_)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); + } + } else { + // Get engine from cache file. + const char* cache_path_cstr = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &cache_path_cstr); + std::string cache_path(cache_path_cstr); + + // For security purpose, in the case of running context model, TRT EP won't allow + // engine cache path to be the relative path like "../file_path" or the absolute path. + // It only allows the engine cache to be in the same directory or sub directory of the context model. + if (IsAbsolutePath(cache_path)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path).c_str()); + } + if (IsRelativePathToParentPath(cache_path)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory."); + } + + // The engine cache and context model (current model) should be in the same directory + std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); + auto engine_cache_path = ctx_model_dir.append(cache_path); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + + // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled + if (!weight_stripped_engine_refit_) { + weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); + } + + // If the serialized refitted engine is present, use it directly without refitting the engine again + if (weight_stripped_engine_refit_) { + const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); + if (std::filesystem::exists(refitted_engine_cache_path)) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; + engine_cache_path = refitted_engine_cache_path.string(); + weight_stripped_engine_refit_ = false; + } + } + + if (!std::filesystem::exists(engine_cache_path)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("TensorRT EP can't find engine cache: " + engine_cache_path.string() + + ". 
Please make sure engine cache is in the same directory or sub-directory of context model.").c_str()); + } + + std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*trt_engine_)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()).c_str()); + } +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); + + if (weight_stripped_engine_refit_) { + const char* onnx_model_filename_cstr = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str(), &onnx_model_filename_cstr); + const std::string onnx_model_filename(onnx_model_filename_cstr); + std::string weight_stripped_engine_cache = engine_cache_path.string(); + auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + weight_stripped_engine_cache, + true /* path check for security */, + (*trt_engine_).get(), + true /* serialize refitted engine to disk */, + detailed_build_log_); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + } + return nullptr; +} + +bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_viewer) { + int node_count = 0; + graph_api_->OrtGraph_NumberOfNodes(graph_viewer, &node_count); + assert(node_count == 1); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); + const char* opType = nullptr; + graph_api_->OrtNode_GetOpType(node, &opType); + assert(strcmp(opType, EPCONTEXT_OP.c_str()) == 0); + + size_t key_count = 0; + graph_api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str(), &key_count); + // Show the warning if compute capability is not matched + if (key_count > 0) { + const char* model_compute_capability = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str(), &model_compute_capability); + // Verify if engine was compiled with ampere+ hardware compatibility enabled + if (strcmp(model_compute_capability, "80+") == 0) { +// if (std::stoi(compute_capability_) < 80) { +// LOGS_DEFAULT(WARNING) << "[TensorRT EP] However, this GPU doesn't match. 
The compute capability of the GPU: " << compute_capability_; +// } + } else if (strcmp(model_compute_capability, compute_capability_.c_str()) != 0) { +// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal"; +// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability; +// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_; + } + } + + // "embed_mode" attr and "ep_cache_context" attr should be present + graph_api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str(), &key_count); + assert(key_count > 0); + graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str(), &key_count); + assert(key_count > 0); + + int64_t embed_mode = -1; + graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); + if (embed_mode == 1) { + // engine binary data +// LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + } + + return true; +} +} diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h new file mode 100644 index 00000000..77efc11f --- /dev/null +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include "onnxruntime_c_api.h" +#include "nv_includes.h" + +namespace onnxruntime { + +static const std::string EPCONTEXT_OP = "EPContext"; +static const std::string EMBED_MODE = "embed_mode"; +static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; +static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; +static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename"; +static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; +static const std::string EPCONTEXT_WARNING = + "It's suggested to set the ORT graph optimization level to 0 and \ + make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ + for the best model loading time"; + +bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer); +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); +std::string GetCtxModelPath(const std::string& ep_context_file_path, + const std::string& original_model_path); +bool IsAbsolutePath(const std::string& path_string); +bool IsRelativePathToParentPath(const std::string& path_string); + +class TensorRTCacheModelHandler { + public: + TensorRTCacheModelHandler(std::unique_ptr* trt_engine, + nvinfer1::IRuntime* trt_runtime, + std::string ep_context_model_path, + std::string compute_capability, + bool weight_stripped_engine_refit, + std::string onnx_model_folder_path, + bool detailed_build_log) + : trt_engine_(trt_engine), + trt_runtime_(trt_runtime), + ep_context_model_path_(ep_context_model_path), + compute_capability_(compute_capability), + weight_stripped_engine_refit_(weight_stripped_engine_refit), + onnx_model_folder_path_(onnx_model_folder_path), + detailed_build_log_(detailed_build_log) { + api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); + graph_api_ = api_->GetGraphApi(ORT_API_VERSION); + } + bool ValidateEPCtxNode(const OrtGraphViewer* graph_viewer); + + OrtStatusPtr GetEpContextFromGraph(const OrtGraphViewer* graph_viewer); + + private: + std::unique_ptr* trt_engine_; + nvinfer1::IRuntime* trt_runtime_; + std::string ep_context_model_path_; // If using context model, it 
implies context model and engine cache is in the same directory + std::string compute_capability_; + bool weight_stripped_engine_refit_; + std::string onnx_model_folder_path_; + bool detailed_build_log_; + const OrtApi* api_; + const OrtGraphApi* graph_api_; +}; // TRTCacheModelHandler +} diff --git a/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h b/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h new file mode 100644 index 00000000..9e4324fb --- /dev/null +++ b/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h @@ -0,0 +1,144 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ +#define FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace CalTableFlatBuffers { + +struct KeyValue; +struct KeyValueBuilder; + +struct TrtTable; +struct TrtTableBuilder; + +struct KeyValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef KeyValueBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEY = 4, + VT_VALUE = 6 + }; + const flatbuffers::String* key() const { + return GetPointer(VT_KEY); + } + bool KeyCompareLessThan(const KeyValue* o) const { + return *key() < *o->key(); + } + int KeyCompareWithValue(const char* val) const { + return strcmp(key()->c_str(), val); + } + const flatbuffers::String* value() const { + return GetPointer(VT_VALUE); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffsetRequired(verifier, VT_KEY) && + verifier.VerifyString(key()) && + VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyString(value()) && + verifier.EndTable(); + } +}; + +struct KeyValueBuilder { + typedef KeyValue Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset key) { + fbb_.AddOffset(KeyValue::VT_KEY, key); + } + void add_value(flatbuffers::Offset value) { + fbb_.AddOffset(KeyValue::VT_VALUE, value); + } + explicit KeyValueBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + KeyValueBuilder& operator=(const KeyValueBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, KeyValue::VT_KEY); + return o; + } +}; + +inline flatbuffers::Offset CreateKeyValue( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset key = 0, + flatbuffers::Offset value = 0) { + KeyValueBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_key(key); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateKeyValueDirect( + flatbuffers::FlatBufferBuilder& _fbb, + const char* key = nullptr, + const char* value = nullptr) { + auto key__ = key ? _fbb.CreateString(key) : 0; + auto value__ = value ? 
_fbb.CreateString(value) : 0; + return CalTableFlatBuffers::CreateKeyValue( + _fbb, + key__, + value__); +} + +struct TrtTable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TrtTableBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DICT = 4 + }; + const flatbuffers::Vector>* dict() const { + return GetPointer>*>(VT_DICT); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DICT) && + verifier.VerifyVector(dict()) && + verifier.VerifyVectorOfTables(dict()) && + verifier.EndTable(); + } +}; + +struct TrtTableBuilder { + typedef TrtTable Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_dict(flatbuffers::Offset>> dict) { + fbb_.AddOffset(TrtTable::VT_DICT, dict); + } + explicit TrtTableBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TrtTableBuilder& operator=(const TrtTableBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTrtTable( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset>> dict = 0) { + TrtTableBuilder builder_(_fbb); + builder_.add_dict(dict); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTrtTableDirect( + flatbuffers::FlatBufferBuilder& _fbb, + std::vector>* dict = nullptr) { + auto dict__ = dict ? _fbb.CreateVectorOfSortedTables(dict) : 0; + return CalTableFlatBuffers::CreateTrtTable( + _fbb, + dict__); +} + +} // namespace CalTableFlatBuffers + +#endif // FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ diff --git a/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc b/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc new file mode 100644 index 00000000..89e62dae --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
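Note: the generated schema above models the INT8 calibration table as a vector of string key/value pairs sorted by key (CreateTrtTableDirect uses CreateVectorOfSortedTables, which relies on KeyValue::KeyCompareLessThan). A rough sketch of serializing a tensor-name-to-dynamic-range map into this format; the std::to_string value encoding is an assumption for illustration, the EP's own calibration-cache code defines the real encoding:

#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "ort_trt_int8_cal_table.fbs.h"

// Writes a {tensor name -> dynamic range} map using the generated builders.
void WriteCalTable(const std::unordered_map<std::string, float>& ranges, const std::string& path) {
  flatbuffers::FlatBufferBuilder builder(1024);
  std::vector<flatbuffers::Offset<CalTableFlatBuffers::KeyValue>> entries;
  entries.reserve(ranges.size());
  for (const auto& kv : ranges) {
    entries.push_back(CalTableFlatBuffers::CreateKeyValueDirect(
        builder, kv.first.c_str(), std::to_string(kv.second).c_str()));
  }
  // CreateTrtTableDirect sorts the KeyValue tables by key before writing the vector.
  builder.Finish(CalTableFlatBuffers::CreateTrtTableDirect(builder, &entries));
  std::ofstream out(path, std::ios::binary);
  out.write(reinterpret_cast<const char*>(builder.GetBufferPointer()), builder.GetSize());
}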
+ +#include +#include +#include "tensorrt_cuda_allocator.h" + +void CUDA_RETURN_IF_ERROR(cudaError_t res); + +namespace onnxruntime { +void CUDAAllocator::CheckDevice(bool throw_when_fail) const { +#ifndef NDEBUG + // check device to match at debug build + // if it's expected to change, call cudaSetDevice instead of the check + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + assert(current_device == CUDAAllocator::GetDeviceId()); + } else if (throw_when_fail) { + CUDA_RETURN_IF_ERROR(cuda_err); + } +#else +// ORT_UNUSED_PARAMETER(throw_when_fail); +#endif +} + +void CUDAAllocator::SetDevice(bool throw_when_fail) const { + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + int allocator_device_id = CUDAAllocator::GetDeviceId(); + if (current_device != allocator_device_id) { + cuda_err = cudaSetDevice(allocator_device_id); + } + } + + if (cuda_err != cudaSuccess && throw_when_fail) { + CUDA_RETURN_IF_ERROR(cuda_err); + } +} + +void* CUDAAllocator::Alloc(size_t size) { + SetDevice(true); + CheckDevice(true); + void* p = nullptr; + if (size > 0) { + // BFCArena was updated recently to handle the exception and adjust the request size + CUDA_RETURN_IF_ERROR(cudaMalloc((void**)&p, size)); + } + return p; +} + +void CUDAAllocator::Free(void* p) { + SetDevice(false); + CheckDevice(false); // ignore CUDA failure when free + cudaFree(p); // do not throw error since it's OK for cudaFree to fail during shutdown +} + +const OrtMemoryInfo* CUDAAllocator::Info() const { + return mem_info_; +} + +void* CUDAPinnedAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + CUDA_RETURN_IF_ERROR(cudaMallocHost((void**)&p, size)); + } + return p; +} + +void CUDAPinnedAllocator::Free(void* p) { + CUDA_RETURN_IF_ERROR(cudaFreeHost(p)); +} + +const OrtMemoryInfo* CUDAPinnedAllocator::Info() const { + return mem_info_; +} + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h b/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h new file mode 100644 index 00000000..64767d8e --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
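Note: CUDAAllocator and CUDAPinnedAllocator (declared in the header that follows) expose device and pinned host memory to ORT through the plain C OrtAllocator struct: their constructors fill in the version field and the Alloc/Free/Info function pointers, so any caller holding an OrtAllocator* can drive them without knowing the concrete type. A minimal usage sketch through that C interface, assuming the onnxruntime library is linked so OrtGetApiBase() resolves:

#include "tensorrt_cuda_allocator.h"

// Exercise the device allocator the way a caller holding only an OrtAllocator* would.
void AllocatorSmokeTest() {
  onnxruntime::CUDAAllocator cuda_allocator(/*device_id*/ 0);
  OrtAllocator* base = &cuda_allocator;  // CUDAAllocator derives from OrtAllocator

  void* device_mem = base->Alloc(base, 1024 * sizeof(float));  // cudaMalloc on the allocator's device
  const OrtMemoryInfo* info = base->Info(base);                // OrtMemoryInfo created in the constructor
  (void)info;
  base->Free(base, device_mem);                                // cudaFree; failures are not thrown
}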
+ +#pragma once +#include +#include "onnxruntime_c_api.h" +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" + +namespace onnxruntime { + +// Following names are originally defined in allocator.h +constexpr const char* CUDA_ALLOCATOR = "Cuda"; +constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned"; + +using DeviceId = int16_t; + +struct CUDAAllocator : OrtAllocator { + CUDAAllocator(DeviceId device_id, const char* name = onnxruntime::CUDA_ALLOCATOR) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; + OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + + device_id_ = device_id; + + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + api->CreateMemoryInfo(name, + OrtAllocatorType::OrtDeviceAllocator, + static_cast(device_id), + OrtMemType::OrtMemTypeDefault, + &mem_info_); + } + //~CUDAAllocator(); + + void* Alloc(size_t size); + void Free(void* p); + const OrtMemoryInfo* Info() const; + DeviceId GetDeviceId() const { return device_id_; }; + + private: + CUDAAllocator(const CUDAAllocator&) = delete; + CUDAAllocator& operator=(const CUDAAllocator&) = delete; + + void CheckDevice(bool throw_when_fail) const; + void SetDevice(bool throw_when_fail) const; + + DeviceId device_id_; + OrtMemoryInfo* mem_info_ = nullptr; +}; + +struct CUDAPinnedAllocator : OrtAllocator { + CUDAPinnedAllocator(const char* name = onnxruntime::CUDA_PINNED_ALLOCATOR) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; + OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + api->CreateMemoryInfo(name, + OrtAllocatorType::OrtDeviceAllocator, + 0 /* CPU device always with id 0 */, + OrtMemType::OrtMemTypeDefault, + &mem_info_); + } + //~CUDAPinnedAllocator(); + + void* Alloc(size_t size); + void Free(void* p); + const OrtMemoryInfo* Info() const; + + DeviceId GetDeviceId() const { return device_id_; }; + + private: + CUDAPinnedAllocator(const CUDAPinnedAllocator&) = delete; + CUDAPinnedAllocator& operator=(const CUDAPinnedAllocator&) = delete; + + DeviceId device_id_ = 0; + OrtMemoryInfo* mem_info_ = nullptr; +}; + + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc new file mode 100644 index 00000000..879115f7 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -0,0 +1,3822 @@ +#include +#include +#include +#include +#include +#include + +#include "ep_abi_utils.h" +#include "tensorrt_execution_provider.h" +#include "tensorrt_execution_provider_utils.h" +#include "tensorrt_cuda_allocator.h" +//#include "onnx_ctx_model_helper.h" +#include "onnx/onnx_pb.h" +#include "cuda/unary_elementwise_ops_impl.h" + +#ifdef _WIN32 +#include +#define LIBTYPE HINSTANCE +#define OPENLIB(libname) LoadLibrary(libname) +#define LIBFUNC(lib, fn) GetProcAddress((lib), (fn)) +#else +#include +#define LIBTYPE void* +#define OPENLIB(libname) dlopen((libname), RTLD_LAZY) +#define LIBFUNC(lib, fn) dlsym((lib), (fn)) +#endif + +void 
CUDA_RETURN_IF_ERROR(cudaError_t res) { + if (res != cudaSuccess) abort(); +} + +namespace onnxruntime { + +static const std::string tensorrtEp = "tensorrtEp"; +const OrtApi& ort_api = Ort::GetApi(); + +struct MemcpyFromHost : OrtCustomOp { + MemcpyFromHost() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::GetName = [](const struct OrtCustomOp* op) { return "MemcpyFromHost"; }; + OrtCustomOp::GetExecutionProviderType = [](const struct OrtCustomOp* op) { return tensorrtEp.c_str(); }; + OrtCustomOp::CreateKernelV2 = [](const struct OrtCustomOp* op, const OrtApi* api, const OrtKernelInfo* info, void** kernel) -> OrtStatusPtr { + return nullptr; + }; + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + void* stream = nullptr; + api->KernelContext_GetGPUComputeStream(context, &stream); + + const OrtValue* input = nullptr; + api->KernelContext_GetInput(context, 0, &input); + OrtTensorTypeAndShapeInfo* shape_info; + api->GetTensorTypeAndShape(input, &shape_info); + size_t dim_count = 0; + api->GetDimensionsCount(shape_info, &dim_count); + std::vector dim(dim_count, 0); + api->GetDimensions(shape_info, dim.data(), dim_count); + + OrtValue* output = nullptr; + api->KernelContext_GetOutput(context, 0, dim.data(), dim.size(), &output); + + void *input_raw = nullptr, *output_raw = nullptr; + api->GetTensorMutableData(const_cast(input), &input_raw); + api->GetTensorMutableData(output, &output_raw); + + size_t count = dim[0]; + for (size_t i = 1; i < dim_count; i++) count *= dim[i]; + cudaMemcpyAsync(output_raw, input_raw, count * sizeof(float), cudaMemcpyHostToDevice, static_cast(stream)); // TODO(leca): other data type + + return nullptr; + }; + OrtCustomOp::GetInputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; }; + OrtCustomOp::GetOutputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; }; + OrtCustomOp::GetInputMemoryType = [](const struct OrtCustomOp* op, size_t index) { return OrtMemType::OrtMemTypeCPUInput; }; + OrtCustomOp::GetInputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + OrtCustomOp::GetOutputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; // TODO(leca): other data type + OrtCustomOp::GetStartVersion = [](const struct OrtCustomOp* op) { return 1; }; + } +}; + +template +using IAllocatorUniquePtr = std::unique_ptr>; + +// Check if cycle exists in the graph after partitioning +bool FindCycleHelper(size_t i, const std::list* adjacency_map, bool visited[], bool* st, std::vector& cycles) { + if (!visited[i]) { + visited[i] = true; + st[i] = true; + for (auto iter = adjacency_map[i].begin(); iter != adjacency_map[i].end(); ++iter) { + if (!visited[*iter] && FindCycleHelper(*iter, adjacency_map, visited, st, cycles)) { + cycles.push_back(*iter); + return true; + } else if (st[*iter]) { + cycles.push_back(*iter); + return true; + } + } + } + st[i] = false; + return false; +} + +bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { + size_t alloc_size = size; + if (alignment == 0) { + *out = alloc_size * nmemb; + } else { + size_t alignment_mask = alignment - 1; + *out = (alloc_size * nmemb + alignment_mask) & ~static_cast(alignment_mask); + } + return true; +} + +template +IAllocatorUniquePtr 
MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) { + size_t alloc_size = count_or_bytes; + // if T is not void, 'count_or_bytes' == number of items so allow for that + if constexpr (!std::is_void::value) { + // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't + // reachable if T is void. use std::conditional to 'use' void* in the sizeof call + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + CalcMemSizeForArrayWithAlignment(count_or_bytes, size, 0, &alloc_size); + } + + T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + + return IAllocatorUniquePtr{p, + [ort_allocator](T* p) { + ort_allocator->Free(ort_allocator, p); + }}; +} + +bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map& dynamic_range_map) { + // Set dynamic range for input tensors + for (int i = 0; i < network.getNbInputs(); ++i) { + const std::string tensor_name = network.getInput(i)->getName(); + auto dynamic_range_iter = dynamic_range_map.find(tensor_name); + if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!network.getInput(i)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for network input " << tensor_name; + return false; + } + } + } + + // Set dynamic range for activations and weights + for (int i = 0; i < network.getNbLayers(); ++i) { + auto trt_layer = network.getLayer(i); + for (int j = 0, e = trt_layer->getNbOutputs(); j < e; ++j) { + const std::string tensor_name = trt_layer->getOutput(j)->getName(); + auto dynamic_range_iter = dynamic_range_map.find(tensor_name); + if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_layer->getOutput(j)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for tensor " << tensor_name; + return false; + } + } else if (trt_layer->getType() == nvinfer1::LayerType::kCONSTANT) { + nvinfer1::IConstantLayer* const_layer = static_cast(trt_layer); + const std::string const_layer_name = const_layer->getName(); + auto trt_weights = const_layer->getWeights(); + double max_weight = std::numeric_limits::min(); + for (int64_t k = 0, end = trt_weights.count; k < end; ++k) { + double weight{}; + switch (trt_weights.type) { + case nvinfer1::DataType::kFLOAT: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kBOOL: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kINT8: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kHALF: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kINT32: + weight = static_cast(trt_weights.values)[k]; + break; +#if NV_TENSORRT_MAJOR >= 10 + case nvinfer1::DataType::kINT64: + weight = static_cast(static_cast(trt_weights.values)[k]); + break; +#endif // NV_TENSORRT_MAJOR >= 10 + default: + // LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name; + return false; + } + max_weight = std::max(max_weight, std::abs(weight)); + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) 
+#endif + if (!trt_layer->getOutput(j)->setDynamicRange(static_cast(-max_weight), static_cast(max_weight))) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for layer " << const_layer_name; + return false; + } + } + } + } + return true; +} + +std::vector SplitToStringVec(std::string const& s, char separator) { + std::vector splitted; + + for (size_t start = 0; start < s.length();) { + size_t separatorIndex = s.find(separator, start); + if (separatorIndex == std::string::npos) { + separatorIndex = s.length(); + } + splitted.emplace_back(s.substr(start, separatorIndex - start)); + start = separatorIndex + 1; + } + + return splitted; +} + +nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { + nvinfer1::TacticSources disabledTactics = 0; + nvinfer1::TacticSources enabledTactics = 0; + std::vector tacticList = SplitToStringVec(tactic_string, ','); + for (auto& t : tacticList) { + bool enable{false}; + if (t.front() == '+') { + enable = true; + } else if (t.front() != '-') { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source must be prefixed with + or - skipping: " << t; + } + t.erase(0, 1); + + const auto toUpper = [](std::string& sourceName) { + std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), + [](char c) { return static_cast(std::toupper(c)); }); + return sourceName; + }; + + nvinfer1::TacticSource source{}; + t = toUpper(t); + if (t == "CUBLAS") { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 + source = nvinfer1::TacticSource::kCUBLAS; +#endif + } else if (t == "CUBLASLT" || t == "CUBLAS_LT") { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0"; +#if NV_TENSORRT_MAJOR < 9 + source = nvinfer1::TacticSource::kCUBLAS_LT; +#endif + } else if (t == "CUDNN") { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 + source = nvinfer1::TacticSource::kCUDNN; +#endif + } else if (t == "EDGE_MASK_CONVOLUTIONS") { + source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS; + } else if (t == "JIT_CONVOLUTIONS") { + source = nvinfer1::TacticSource::kJIT_CONVOLUTIONS; + } else { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source was not found with name: " << t; + } + + uint32_t sourceBit = 1U << static_cast(source); + + if (enable) { + enabledTactics |= sourceBit; + } else { + disabledTactics |= sourceBit; + } + } + return enabledTactics & ~disabledTactics; +} + +inline std::vector loadTimingCacheFile(const std::string inFileName) { + std::ifstream iFile(inFileName, std::ios::in | std::ios::binary); + if (!iFile) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName + // << ". 
A new timing cache will be generated and written."; + return std::vector(); + } + iFile.seekg(0, std::ifstream::end); + size_t fsize = iFile.tellg(); + iFile.seekg(0, std::ifstream::beg); + std::vector content(fsize); + iFile.read(content.data(), fsize); + iFile.close(); + return content; +} + +inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::IHostMemory* blob) { + std::ofstream oFile(outFileName, std::ios::out | std::ios::binary); + if (!oFile) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName; + return; + } + oFile.write((char*)blob->data(), blob->size()); + oFile.close(); +} + +#if NV_TENSORRT_MAJOR >= 10 +void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. + return outputPtr; +} +#else +// Only override this method when TensorRT <= 8.6 +void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. + return outputPtr; +} +#endif + +void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { + output_shapes.clear(); + output_shapes.reserve(dims.nbDims); + for (int i = 0; i < dims.nbDims; i++) { + output_shapes.push_back(dims.d[i]); + } +} + +TensorrtLogger& GetTensorrtLogger(bool verbose_log) { + const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING; + static TensorrtLogger trt_logger(log_level); + if (log_level != trt_logger.get_level()) { + trt_logger.set_level(verbose_log ? 
nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING); + } + return trt_logger; +} + +std::unique_lock TensorrtExecutionProvider::GetApiLock() const { + static std::mutex singleton; + return std::unique_lock(singleton); +} + +template +void GetShapeOfShapeTensor(Ort::ConstValue& input_tensor, + void* shape_values, + int shape_size, + cudaStream_t stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(shape_values, + input_tensor.GetTensorData(), + shape_size * sizeof(T), + cudaMemcpyDeviceToHost, + stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); +} + +bool ApplyProfileShapesFromProviderOptions(std::vector& trt_profiles, + nvinfer1::ITensor* input, + std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes, + ShapeRangesMap& input_explicit_shape_ranges) { + if (trt_profiles.size() == 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Number of optimization profiles should be greater than 0, but it's 0."; + return false; + } + + const std::string& input_name = input->getName(); + if (profile_min_shapes.find(input_name) == profile_min_shapes.end()) { + return false; + } + + if (input_explicit_shape_ranges.find(input_name) == input_explicit_shape_ranges.end()) { + std::unordered_map>> inner_map; + input_explicit_shape_ranges[input_name] = inner_map; + } + + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Begin to apply profile shapes ..."; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Input tensor name is '" << input_name << "', number of profiles found is " << trt_profiles.size(); + + for (size_t i = 0; i < trt_profiles.size(); i++) { + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + auto trt_profile = trt_profiles[i]; + + // Shape tensor + if (input->isShapeTensor()) { + int shape_size = nb_dims == 0 ? 
1 : static_cast(profile_min_shapes[input_name][i].size()); + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shape size of this shape tensor is " << shape_size; + + for (int j = 0; j < shape_size; j++) { + auto min_value = profile_min_shapes[input_name][i][j]; + auto max_value = profile_max_shapes[input_name][i][j]; + auto opt_value = profile_opt_shapes[input_name][i][j]; + shapes_min[j] = static_cast(min_value); + shapes_max[j] = static_cast(max_value); + shapes_opt[j] = static_cast(opt_value); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_min.d[" << j << "] is " << shapes_min[j]; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_max.d[" << j << "] is " << shapes_max[j]; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_opt.d[" << j << "] is " << shapes_opt[j]; + + if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { + std::vector> profile_vector(trt_profiles.size()); + input_explicit_shape_ranges[input_name][j] = profile_vector; + } + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(min_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(max_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(opt_value); + } + + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + } + // Execution tensor + else { + nvinfer1::Dims dims_min, dims_opt, dims_max; + dims_min.nbDims = nb_dims; + dims_max.nbDims = nb_dims; + dims_opt.nbDims = nb_dims; + + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] number of dimension of this execution tensor is " << nb_dims; + + for (int j = 0; j < nb_dims; j++) { + if (dims.d[j] == -1) { + auto min_value = profile_min_shapes[input_name][i][j]; + auto max_value = profile_max_shapes[input_name][i][j]; + auto opt_value = profile_opt_shapes[input_name][i][j]; + dims_min.d[j] = static_cast(min_value); + dims_max.d[j] = static_cast(max_value); + dims_opt.d[j] = static_cast(opt_value); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_min.d[" << j << "] is " << dims_min.d[j]; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_max.d[" << j << "] is " << dims_max.d[j]; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_opt.d[" << j << "] is " << dims_opt.d[j]; + + if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { + std::vector> profile_vector(trt_profiles.size()); + input_explicit_shape_ranges[input_name][j] = profile_vector; + } + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(min_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(max_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(opt_value); + } else { + dims_min.d[j] = dims.d[j]; + dims_max.d[j] = dims.d[j]; + dims_opt.d[j] = dims.d[j]; + } + } + + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + } + } + return true; +} + +OrtStatusPtr 
ApplyProfileShapesFromInputTensorValue(std::vector& trt_profiles, + Ort::KernelContext ctx, + nvinfer1::ITensor* input, + ShapeRangesMap& shape_ranges, + const std::unordered_map& input_indexes, + std::unordered_map>& shape_tensor_values, + std::unordered_map>& shape_tensor_values_int64, + cudaStream_t stream, + bool* engine_update) { + for (size_t i = 0; i < trt_profiles.size(); i++) { + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + size_t input_index = 0; + const auto& iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + auto& shape_ranges_per_input = shape_ranges[input_name]; + + auto trt_profile = trt_profiles[i]; + + // If there are multiple profiles, for second and rest of profiles, simply copy the min/max/opt profile values from the first profile. + // Following "if statement" won't be executed since TRT EP currently only allows single profile for non-explicit profiles case. + if (i > 0) { + if (input->isShapeTensor()) { + // shape tensor + int shape_size = nb_dims == 0 ? 1 : static_cast(tensor_shapes[0]); + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + for (int j = 0; j < shape_size; j++) { + shapes_min[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN)); + shapes_max[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX)); + shapes_opt[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT)); + } + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + } else { + // execution tensor + nvinfer1::Dims dims_min, dims_opt, dims_max; + dims_min = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN); + dims_max = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX); + dims_opt = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + } + continue; + } + + // Create shape profile + if (input->isShapeTensor()) { + // Get shape values for shape tensor input + const auto tensor_type = tensor_info.GetElementType(); + // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension + int shape_size = dims.nbDims == 0 ? 1 : static_cast(tensor_shapes[0]); + // For setting TRT optimization profile. 
(Note: the min/opt/max profile values are still int32 even though int64 is supported after TRT 10) + std::vector values(shape_size); + + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto buffer = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream); + shape_tensor_values[input_name].resize(shape_size); + for (int j = 0; j < shape_size; ++j) { + shape_tensor_values[input_name][j] = buffer[j]; + values[j] = buffer[j]; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + auto buffer = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream); + shape_tensor_values_int64[input_name].resize(shape_size); + for (int j = 0; j < shape_size; ++j) { + shape_tensor_values_int64[input_name][j] = buffer[j]; + values[j] = static_cast(buffer[j]); + } + break; + } + default: { + return ort_api.CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); + } + } + + // Update shape ranges + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + int shape_range_size = static_cast(shape_ranges_per_input.size()); + if (shape_size == shape_range_size) { + // If shape size matches, check/update shape range + for (int j = 0; j < shape_size; ++j) { + auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile + shapes_min[j] = static_cast(shape_range[0]); + shapes_max[j] = static_cast(shape_range[1]); + shapes_opt[j] = static_cast(shape_range[2]); + + const auto& tensor_shape_value = values[j]; + // Update shape range lower bound + if (tensor_shape_value < shape_range[0]) { + shape_range[0] = tensor_shape_value; + shapes_min[j] = tensor_shape_value; + *engine_update = true; + } + // Update shape range upper bound + if (tensor_shape_value > shape_range[1]) { + shape_range[1] = tensor_shape_value; + shape_range[2] = tensor_shape_value; + shapes_max[j] = tensor_shape_value; + shapes_opt[j] = tensor_shape_value; + *engine_update = true; + } + } + } else { + // If shape size doesn't match, initialize shape_range with the new shape value + shape_ranges_per_input.clear(); + for (int j = 0; j < shape_size; ++j) { + const auto& tensor_shape_value = values[j]; + std::vector> profile_vector; + std::vector shape_vector{tensor_shape_value, tensor_shape_value, tensor_shape_value}; + profile_vector.push_back(shape_vector); // only one profile needed + shape_ranges_per_input[j] = profile_vector; + shapes_min[j] = tensor_shape_value; + shapes_opt[j] = tensor_shape_value; + shapes_max[j] = tensor_shape_value; + } + *engine_update = true; + } + + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + } else { // Execution tensor + nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + const auto& tensor_shape = tensor_shapes[j]; + if (shape_ranges_per_input.find(j) != shape_ranges_per_input.end()) { + auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile + dims_min.d[j] = static_cast(shape_range[0]); + dims_max.d[j] = static_cast(shape_range[1]); + dims_opt.d[j] = static_cast(shape_range[2]); + + // Update 
minimum dimension + if (tensor_shape < shape_range[0]) { + shape_range[0] = tensor_shape; + dims_min.d[j] = static_cast(tensor_shape); + *engine_update = true; + } + // Update maximum dimension + if (tensor_shape > shape_range[1]) { + shape_range[1] = tensor_shape; + shape_range[2] = tensor_shape; + dims_max.d[j] = static_cast(tensor_shape); + dims_opt.d[j] = static_cast(tensor_shape); + *engine_update = true; + } + } + } + + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + } + } + return nullptr; +} + +#define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + data = const_cast(input_tensor_ptr); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + data = scratch_buffers.back().get(); \ + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + buffers[output_name] = output_tensor_ptr; \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = static_cast(elem_cnt); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = 1; \ + } \ + break; \ + } + +#define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \ + } \ + break; \ + } + +#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ + } \ + break; \ + } + +OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, + nvinfer1::ICudaEngine* trt_engine, + nvinfer1::IExecutionContext* trt_context, + const char* input_name, + 
size_t input_index, + std::unordered_map>& shape_tensor_values, + std::unordered_map>& shape_tensor_values_int64, + std::vector>& scratch_buffers, + OrtAllocator* alloc, + cudaStream_t stream) { + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + const auto tensor_type = tensor_info.GetElementType(); + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ + const auto elem_cnt = tensor_info.GetElementCount(); + + if (trt_engine->isShapeInferenceIO(input_name)) { + // Bind "shape tensor" input buffer + + // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension + int shape_size = trt_engine->getTensorShape(input_name).nbDims == 0 ? 1 : static_cast(tensor_shapes[0]); + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + // get shape tensor value if not present + if (shape_tensor_values.find(input_name) == shape_tensor_values.end()) { + auto input = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, input.get(), shape_size, stream); + shape_tensor_values[input_name].resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + shape_tensor_values[input_name][i] = input[i]; + } + } + + if (!trt_context->setTensorAddress(input_name, &shape_tensor_values[input_name][0])) { + std::string error_input_name = input_name; + std::string error_msg = + "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + error_input_name + "'"; + return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // get shape tensor value if not present + if (shape_tensor_values_int64.find(input_name) == shape_tensor_values_int64.end()) { + auto input = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, input.get(), shape_size, stream); + shape_tensor_values_int64[input_name].resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + shape_tensor_values_int64[input_name][i] = input[i]; + } + } + + if (!trt_context->setTensorAddress(input_name, &shape_tensor_values_int64[input_name][0])) { + std::string error_input_name = input_name; + std::string error_msg = + "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + error_input_name + "'"; + return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + break; + } + default: { + std::string error_input_name = input_name; + return ort_api.CreateStatus(ORT_EP_FAIL, std::string("The data type of shape tensor should be INT32 or INT64. Please check the data type of " + error_input_name).c_str()); + } + } + } else { + // Set shape for input tensor which is execution tensor + nvinfer1::Dims dims = trt_context->getTensorShape(input_name); + int nb_dims = dims.nbDims; + for (int j = 0, end = nb_dims; j < end; ++j) { + dims.d[j] = static_cast(tensor_shapes[j]); + } + if (!trt_context->setInputShape(input_name, dims)) { + std::string error_input_name = input_name; + return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'").c_str()); + } + + // Bind "execution tensor" input buffer + // + // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses. + // Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte. 
+ // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors + void* data = nullptr; + switch (tensor_type) { + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // Cast int64 input to int32 input because TensorRT < 10 doesn't support int64 + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) +#endif + // Cast double input to float because TensorRT doesn't support double + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + default: { + return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); + } + } + trt_context->setTensorAddress(input_name, data); + } + + return nullptr; +} + +OrtStatusPtr BindContextOutput(Ort::KernelContext& ctx, + nvinfer1::IExecutionContext* trt_context, + const char* output_name, + size_t output_index, + size_t output_type, + size_t i, + std::unordered_map& output_tensors, + std::unordered_map& output_dim_sizes, + DDSOutputAllocatorMap& dds_output_allocator_map, + std::vector>& scratch_buffers, + OrtAllocator* alloc, + std::unordered_map& buffers) { + // Get output shape + nvinfer1::Dims dims = trt_context->getTensorShape(output_name); + int nb_dims = dims.nbDims; + bool is_DDS = false; + std::vector output_shapes(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + // data-dependent shape + if (dims.d[j] == -1) { + is_DDS = true; + break; + } + output_shapes[j] = dims.d[j]; + } + + auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end(); + + // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer. + // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output. + // (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output, + // which we defer allocation until the size is known and don't call IExecution::setTensorAddress) + // + // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3. 
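+  //
+  // Illustration (hypothetical shapes): an output that TensorRT reports as [2, -1, 4] is data-dependent, so it is
+  // routed through the OutputAllocator below and copied into the ORT output once notifyShape() reports the real
+  // dimensions; an output reported as [2, 3, 4] is bound directly to a pre-allocated ORT buffer via setTensorAddress().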
+ if (is_DDS || known_DDS) { + if (!known_DDS) { + auto allocatorPtr = std::make_unique(); + trt_context->setOutputAllocator(output_name, allocatorPtr.get()); + dds_output_allocator_map[output_name] = std::move(allocatorPtr); + } + } else { + output_tensors[i] = ctx.GetOutput(output_index, output_shapes); + auto& output_tensor = output_tensors[i]; + const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + + switch (output_type) { + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // Allocate int32 CUDA memory for int64 output type because TensorRT < 10 doesn't support int64 + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) +#endif + // Allocate float CUDA memory for double output type because TensorRT doesn't support double + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + default: { + return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + } + } + trt_context->setTensorAddress(output_name, buffers[output_name]); + } + + return nullptr; +} + +OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, + OrtMemoryInfo* /*mem_info*/, + DDSOutputAllocatorMap& allocator_map, + char const* output_name, + size_t output_index, + size_t output_type, + cudaStream_t stream) { + auto allocator = allocator_map[output_name].get(); + auto& shape = allocator->getOutputShape(); + auto output_tensor = ctx.GetOutput(output_index, shape); + + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ + auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + + /* + * Copy output data from allocation buffer to ORT kernel context output location or + * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. + * + * Note: + * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, + * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. + * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we + * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, + * and within the same stream, operations are guaranteed to be executed in order. + */ + switch (output_type) { + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. +// CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) +#endif + // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. + // CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) + default: { + return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + } + } + return nullptr; +} + +/* +// Detect and remove cycles from supported node list +bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles) const { + const size_t* nodes_index = nullptr; + size_t node_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index, &node_count); + bool trt_cycle = true, cycle_detected = false; + while (trt_cycle) { + trt_cycle = false; + std::unordered_map node_to_index_map; + std::unordered_map index_to_node_map; + std::unordered_map> input_to_nodes_map, node_to_outputs_map; + std::unordered_set non_trt_node_index; + for (size_t i = 0; i < node_count; ++i) { + non_trt_node_index.insert(nodes_index[i]); + } + size_t id = 0; + int subgraph_index = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + // Construct subgraph from node list + std::unique_ptr subgraph = GetSubGraph(group, graph, model_hash, subgraph_index); + + // Create node to inputs/outputs/index maps + const std::string node_name = subgraph->meta_def->name; + if (node_to_index_map.find(node_name) == node_to_index_map.end()) { + index_to_node_map[id] = node_name; + node_to_index_map[node_name] = id++; + } + + if (subgraph->meta_def != nullptr) { + for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { + input_to_nodes_map[std::string(subgraph->meta_def->inputs[j])].insert(node_name); + } + for (size_t j = 0; j < 
subgraph->meta_def->output_len; j++) { + node_to_outputs_map[node_name].insert(std::string(subgraph->meta_def->outputs[j])); + } + } + + // Remove TensorRT nodes from node index list + for (const auto& index : group.first) { + non_trt_node_index.erase(nodes_index[index]); + } + subgraph_index++; + } + } + + // Add non TensorRT nodes to the maps + for (const auto& index : non_trt_node_index) { + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph, index, &node); + const char* node_name_char = nullptr; + graph_api_->OrtNode_GetName(node, &node_name_char); + const std::string node_name(node_name_char); + if (node_to_index_map.find(node_name) == node_to_index_map.end()) { + index_to_node_map[id] = node_name; + node_to_index_map[node_name] = id++; + } + + size_t input_count = 0; + graph_api_->OrtNode_GetNumInputs(node, &input_count); + for (size_t i = 0; i < input_count; ++i) { + const char* input_name_char = nullptr; + graph_api_->OrtNode_GetIthInputName(node, i, &input_name_char); + input_to_nodes_map[std::string(input_name_char)].insert(node_name); + } + + size_t implicit_input_count = 0; + graph_api_->OrtNode_GetImplicitInputSize(node, &implicit_input_count); + for (size_t i = 0; i < implicit_input_count; ++i) { + const char* input_name_char = nullptr; + graph_api_->OrtNode_GetIthImplicitInputName(node, i, &input_name_char); + input_to_nodes_map[std::string(input_name_char)].insert(node_name); + } + + size_t output_count = 0; + graph_api_->OrtNode_GetNumOutputs(node, &output_count); + for (size_t i = 0; i < output_count; ++i) { + const char* output_name_char = nullptr; + graph_api_->OrtNode_GetIthOutputName(node, i, &output_name_char); + node_to_outputs_map[node_name].insert(std::string(output_name_char)); + } + } + + // Create adjacency list + size_t graph_size = node_to_index_map.size(); + std::list* adjacency_map = new std::list[graph_size]; + for (const auto& node : node_to_outputs_map) { + for (auto iter = node.second.begin(); iter != node.second.end(); ++iter) { + const auto& loc = input_to_nodes_map.find(*iter); + if (loc != input_to_nodes_map.end()) { + size_t parent_node_index = node_to_index_map.find(node.first)->second; + for (auto child_node : loc->second) { + size_t child_node_index = node_to_index_map.find(child_node)->second; + adjacency_map[parent_node_index].push_back(child_node_index); + } + } + } + } + + // Check cycle in the graph + bool* visited = new bool[graph_size]; + bool* st = new bool[graph_size]; + for (size_t i = 0; i < graph_size; ++i) { + visited[i] = false; + st[i] = false; + } + + std::vector cycles; + bool has_cycle = false; + for (size_t i = 0; i < graph_size; ++i) { + if (FindCycleHelper(i, adjacency_map, visited, st, cycles)) { + has_cycle = true; + cycle_detected = true; + break; + } + } + + // Remove TensorRT subgraph from the supported node list if it's part of the cycle + if (has_cycle && remove_cycles) { + for (size_t i = 0; i < cycles.size(); ++i) { + auto loc = index_to_node_map.find(cycles[i]); + if (loc != index_to_node_map.end() && loc->second.find("TRTKernel") != std::string::npos) { + supported_nodes_vector.erase(supported_nodes_vector.begin() + cycles[i]); + trt_cycle = true; + break; + } + } + } + + delete[] adjacency_map; + delete[] visited; + delete[] st; + } + return cycle_detected; +} + +// Check the graph is the subgraph of control flow op +bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const { + bool is_subgraph = false; + graph_api_->OrtGraph_IsSubgraph(graph, &is_subgraph); 
+ if (is_subgraph) { + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetParenNode(graph, &node); + const char* node_op_type = nullptr; + graph_api_->OrtNode_GetOpType(node, &node_op_type); + if (control_flow_op_set_.find(std::string(node_op_type)) != control_flow_op_set_.end()) { + return true; + } + } + return false; +} + +// Check whether all the nodes of the graph are assigned to specific ep +bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const { + size_t num_nodes = ort_api_.Graph_NumNodes(graph); + std::vector nodes(num_nodes, nullptr); + RETURN_IF_ERROR(ort_api_.Graph_GetNodes(graph, 1, nodes.data(), nodes.size())); + + for (const OrtNode* node : nodes) { + const char* node_ep_type = ort_api_.Node_GetExecutionProviderType(node); + if (strcmp(node_ep_type, provider_type.c_str())) { + return false; + } + } + return num_nodes != 0; +} + +// Check whether all the nodes of subgraph are supported +bool TensorrtExecutionProvider::IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const { + int number_of_trt_nodes = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + number_of_trt_nodes += static_cast(group.first.size()); + } + } + + return number_of_trt_nodes == number_of_ort_nodes; +} + +std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const OrtGraphViewer* graph, const HashValue& model_hash, int subgraph_index) const { + const size_t* node_index = nullptr; + size_t nodes_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index, &nodes_count); + std::unordered_set node_set; + node_set.reserve(graph_nodes_index.first.size()); + for (const auto& index : graph_nodes_index.first) { + node_set.insert(node_index[index]); + } + + // Get parent graph output names + std::unordered_set graph_output_names; + size_t graph_output_size = 0; + graph_api_->OrtGraph_GetOutputSize(graph, &graph_output_size); + for (size_t i = 0; i < graph_output_size; i++) { + char const* output_name = nullptr; + graph_api_->OrtGraph_GetIthOutputName(graph, i, &output_name); + graph_output_names.insert(output_name); + } + + // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->node_index_len = graph_nodes_index.first.size(); + sub_graph->node_index = new size_t[sub_graph->node_index_len]; + sub_graph->meta_def = new OrtMetaDef(); + std::unordered_set erased; + std::unordered_map input_to_order; + std::unordered_map output_to_order; + int input_order = 0; + int output_order = 0; + + std::vector initializers; + int i = 0; + for (const auto& index : graph_nodes_index.first) { + sub_graph->node_index[i++] = node_index[index]; + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph, node_index[index], &node); + size_t input_size = 0; + graph_api_->OrtNode_GetNumInputs(node, &input_size); + for (size_t j = 0; j < input_size; j++) { + const char* input_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node, j, &input_name); + bool is_initializer = false; + graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_initializer); + if (is_initializer) { + initializers.push_back(input_name); + continue; + } + const OrtNode* producer = nullptr; + graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); + // If the input is not produced by any node, it is a graph input + if (producer == nullptr) { + 
input_to_order[input_name] = input_order++; + continue; + } + size_t producer_index = -1; + graph_api_->OrtNode_GetIndex(producer, &producer_index); + // If the producer node is not in the subgraph, the input is a graph input + if (node_set.find(producer_index) == node_set.end()) { + input_to_order[input_name] = input_order++; + } + } + + size_t implicit_input_size = 0; + graph_api_->OrtNode_GetImplicitInputSize(node, &implicit_input_size); + for (size_t j = 0; j < implicit_input_size; j++) { + const char* input_name = nullptr; + graph_api_->OrtNode_GetIthImplicitInputName(node, j, &input_name); + bool is_initializer = false; + graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_initializer); + if (is_initializer) { + initializers.push_back(input_name); + continue; + } + const OrtNode* producer = nullptr; + graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); + // If the input is not produced by any node, it is a graph input + if (producer == nullptr) { + input_to_order[input_name] = input_order++; + continue; + } + size_t producer_index = -1; + graph_api_->OrtNode_GetIndex(producer, &producer_index); + // If the producer node is not in the subgraph, the input is a graph input + if (node_set.find(producer_index) == node_set.end()) { + input_to_order[input_name] = input_order++; + } + } + + size_t output_size = 0; + graph_api_->OrtNode_GetNumOutputs(node, &output_size); + for (size_t j = 0; j < output_size; j++) { + const char* output_name = nullptr; + graph_api_->OrtNode_GetIthOutputName(node, j, &output_name); + // If the output is the graph output, it is a subgraph output + if (graph_output_names.find(output_name) != graph_output_names.end()) { + output_to_order[output_name] = output_order++; + continue; + } + const OrtNode** consumers = nullptr; + size_t consumer_count = 0; + graph_api_->OrtGraph_GetNodesConsumingInput(graph, output_name, &consumers, &consumer_count); + for (size_t k = 0; k < consumer_count; k++) { + size_t consumer_index = -1; + graph_api_->OrtNode_GetIndex(consumers[k], &consumer_index); + // If the consumer node is not in the subgraph, the output is a subgraph output + if (node_set.find(consumer_index) == node_set.end()) { + output_to_order[output_name] = output_order++; + break; + } + } + graph_api_->ReleaseOrtNodeArray(consumers); + } + } + + // Sort inputs and outputs based on their order + std::multimap ordered_inputs, ordered_outputs; + for (const auto& input : input_to_order) { + ordered_inputs.insert(std::pair(input.second, input.first)); + } + for (const auto& output : output_to_order) { + ordered_outputs.insert(std::pair(output.second, output.first)); + } + + // Generate unique kernel name for TRT subgraph + std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); + bool is_subgraph = false; + graph_api_->OrtGraph_IsSubgraph(graph, &is_subgraph); + const std::string graph_type = is_subgraph ? 
"subgraph" : "graph"; + const char* graph_name = nullptr; + graph_api_->OrtGraph_GetName(graph, &graph_name); + std::string meta_def_name = "TRTKernel_" + graph_type + "_" + std::string(graph_name) + subgraph_id; + sub_graph->meta_def->name = new char[meta_def_name.length() + 1]; + strcpy(sub_graph->meta_def->name, meta_def_name.c_str()); + + // Assign inputs and outputs to subgraph's meta_def + sub_graph->meta_def->input_len = ordered_inputs.size(); + sub_graph->meta_def->inputs = new char*[sub_graph->meta_def->input_len]; + i = 0; + for (const auto& input : ordered_inputs) { + sub_graph->meta_def->inputs[i] = new char[input.second.length() + 1]; + strcpy(sub_graph->meta_def->inputs[i++], input.second.c_str()); + } + + sub_graph->meta_def->initializer_len = initializers.size(); + sub_graph->meta_def->constant_initializers = new char*[sub_graph->meta_def->initializer_len]; + i = 0; + for (const auto& initializer : initializers) { + sub_graph->meta_def->constant_initializers[i] = new char[initializer.length() + 1]; + strcpy(sub_graph->meta_def->constant_initializers[i++], initializer.c_str()); + } + + sub_graph->meta_def->output_len = ordered_outputs.size(); + sub_graph->meta_def->outputs = new char*[sub_graph->meta_def->output_len]; + i = 0; + for (const auto& output : ordered_outputs) { + sub_graph->meta_def->outputs[i] = new char[output.second.length() + 1]; + strcpy(sub_graph->meta_def->outputs[i++], output.second.c_str()); + } + + sub_graph->meta_def->domain = "com.microsoft"; + sub_graph->meta_def->since_version = 1; + + return sub_graph; +} +*/ + +static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) { + TensorrtExecutionProvider* ep = static_cast(this_ptr); + const OrtApi& ort_api = ep->ort_api; + /* + // Get ModelPath + const std::filesystem::path* model_path = nullptr; + graph_api_->OrtGraph_GetModelPath(graph, reinterpret_cast(&model_path)); + const auto& path_string = model_path->string(); +#ifdef _WIN32 + strncpy_s(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); +#else + strncpy(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); +#endif + p->model_path_[sizeof(p->model_path_) - 1] = '\0'; + + int node_count = 0; + graph_api_->OrtGraph_NumberOfNodes(graph, &node_count); + if (node_count == 1 && GraphHasCtxNode(graph)) { + SubGraph_t supported_node_vector = {{0}, true}; + std::unique_ptr sub_graph = p->GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); + *cnt = 1; + *indexed_sub_graph = new OrtIndexedSubGraph*[1]; + (*indexed_sub_graph)[0] = sub_graph.release(); + return; + } + */ + + // Generate unique kernel name for TRT graph + // HashValue model_hash = TRTGenerateId(ort_api, graph, std::to_string(trt_version_), std::to_string(cuda_version_)); + + // Get pre-excluded op list from provider options + auto get_exclude_ops_set = [&](std::string node_list_to_exclude) -> std::set { + std::set set; + if (!node_list_to_exclude.empty()) { + std::stringstream node_list(node_list_to_exclude); + std::string node; + while (std::getline(node_list, node, ',')) { + set.insert(node); + } + } + return set; + }; + + //auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); + auto exclude_ops_set = get_exclude_ops_set(""); + + // Get all Ort nodes + OrtArrayOfConstObjects* nodes_container = nullptr; + DeferOrtRelease release_nodes(&nodes_container, + ep->ort_api.ReleaseArrayOfConstObjects); + RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, 
&nodes_container)); + + gsl::span nodes{}; + GetSpanFromArrayOfConstObjects(nodes_container, nodes); + // using ORT's priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of 0, 1, ... n-1 + // RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, /*order*/ 1, nodes.data(), nodes.size())); + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; + bool new_subgraph = true; + + /* Iterate all the nodes and exclude the node if: + * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. + * 2. Its op type is in the exclusion list. + */ + for (size_t index = 0; index < nodes.size(); index++) { + const OrtNode* node = nodes[index]; + bool supported_node = true; + + /* If current node is control flow op, we take different approach based on following four cases: + * + * (1) control flow op is supported by TRT, and its subgraphs are all supported by TRT. Assign this node to TRT. + * (2) control flow op is supported by TRT, but not all its subgraphs supported by TRT. Don't assign this node to TRT. + * (3) control flow op is not supported by TRT, but its subgraphs all supported by TRT. Don't assign this node to TRT. + * (4) control flow op is not supported by TRT, and not all its subgraphs supported by TRT. Don't assign this node to TRT. + * + * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. + */ + const char* op_type = nullptr; + RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); + + if (ep->control_flow_op_set_.find(op_type) != ep->control_flow_op_set_.end()) { + auto supported_control_flow_op = [&](const OrtNode* node) { + OrtStatus* status = nullptr; + size_t num_subgraphs = 0; + OrtArrayOfConstObjects* node_subgraphs_container = nullptr; + DeferOrtRelease release_node_subgraphs(&node_subgraphs_container, + ep->ort_api.ReleaseArrayOfConstObjects); + + RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.Node_GetSubgraphs(node, &node_subgraphs_container), ep->ort_api); + RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(node_subgraphs_container, &num_subgraphs), ep->ort_api); + + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* subgraph = nullptr; + RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(node_subgraphs_container, subgraph_idx, + reinterpret_cast(&subgraph)), + ep->ort_api); + + // Get number of subgraph's nodes + size_t num_subgraph_nodes = 0; + OrtArrayOfConstObjects* subgraph_nodes_container = nullptr; + DeferOrtRelease release_subgraph_nodes(&subgraph_nodes_container, + ep->ort_api.ReleaseArrayOfConstObjects); + RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.Graph_GetNodes(subgraph, &subgraph_nodes_container), ep->ort_api); + RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(subgraph_nodes_container, &num_subgraph_nodes), ep->ort_api); + + // TRT EP should consider the empty subgraph is fully supported by TRT. 
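+          // e.g. an "If" branch that simply forwards a constant or outer-scope value may have no nodes at all;
+          // skipping such an empty body here keeps the parent control flow op eligible for TRT.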
+ if (num_subgraph_nodes == 0) { + continue; + } + + /* + if (!ep->AllNodesAssignedToSpecificEP(*(subgraph->CreateGraphViewer()), kTensorrtExecutionProvider)) { + // if not all its subgraphs are supported, we need to exclude this control flow op + return false; + } + */ + } + return true; + }; + supported_node = supported_control_flow_op(node); + } + + // Exclude any ops, if applicable + if (exclude_ops_set.find(op_type) != exclude_ops_set.end()) { + supported_node = false; + } + + if (supported_node) { + if (new_subgraph) { + parser_nodes_vector.emplace_back(); + // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser + parser_nodes_vector.back().second = false; + new_subgraph = false; + } + parser_nodes_vector.back().first.emplace_back(index); + } else { + new_subgraph = true; + } + } + + bool early_termination = false; + supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, p->max_partition_iterations_, graph, &early_termination); + if (early_termination) { + supported_nodes_vector.clear(); + } + + // Remove subgraphs if its size is less than the predefined minimal size + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + const size_t subgraph_size = it->first.size(); + if (subgraph_size < p->min_subgraph_size_) { + supported_nodes_vector.erase(it--); + } + } + + // Detect and remove cycles from supported node list + //p->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); + + // Consolidate supported node list + if (supported_nodes_vector.size() > 1) { + nodes_vector.clear(); + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end()); + } + } + SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}}; + if (p->DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph"; + supported_nodes_vector = consolidated_supported_nodes_vector; + } + } + + std::vector cache; + // Handle the case where the graph is subgraph of control flow op. + // The purpose is to make control flow op as well as its subgraphs run on TRT. + // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. + if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { + bool all_subgraphs_are_supported = true; + + // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. + // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. 
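+    // e.g. if this graph is the "then" body of an "If" node, the sibling "else" body must also be fully supported
+    // by TRT before the parent "If" node and both bodies are handed over to TRT as a whole.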
+ const OrtNode* parent_node = nullptr; + graph_api_->OrtGraph_GetParenNode(graph, &parent_node); + const char* parent_node_op_type = nullptr; + graph_api_->OrtNode_GetOpType(parent_node, &parent_node_op_type); + if (strcmp(parent_node_op_type, "If") == 0) { + all_subgraphs_are_supported = false; + SubGraphCollection_t subgraph_supported_nodes_vector; + const OrtGraphViewer** subgraphs = nullptr; + size_t subgraph_count = 0; + graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count); + for (size_t i = 0; i < subgraph_count; i++) { + bool same_graph = false; + graph_api_->OrtGraph_IsSameGraph(graph, subgraphs[i], &same_graph); + if (same_graph) { + continue; + } + int number_of_ort_subgraph_nodes = 0; + graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes); + std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); + std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); + SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; + bool subgraph_early_termination = false; + + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. + if (number_of_ort_subgraph_nodes == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. + else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. + // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) + else if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "")) { + all_subgraphs_are_supported = false; + break; + } + + // Another subgraph of "If" control flow has not yet been parsed by GetCapability. 
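+        // In that case the TRT parser is run on the sibling subgraph right here, and the parent "If" node is only
+        // considered TRT-eligible when IsSubGraphFullySupported() reports that every node of that sibling parses.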
+ subgraph_supported_nodes_vector = p->GetSupportedList(parser_subgraph_nodes_vector, 0, p->max_partition_iterations_, subgraphs[i], &subgraph_early_termination); + all_subgraphs_are_supported = p->IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); + break; + } + graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); + } + + if (all_subgraphs_are_supported) { + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + for (const auto& index : group.first) { + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->node_index_len = 1; + sub_graph->node_index = new size_t[sub_graph->node_index_len]; + sub_graph->node_index[0] = nodes_index[index]; + cache.push_back(sub_graph.release()); + } + } + } + *cnt = cache.size(); + *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt]; + for (size_t i = 0; i < *cnt; i++) { + (*indexed_sub_graph)[i] = cache[i]; + } + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + return; + } + } + + int number_of_trt_nodes = 0, subgraph_index = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + std::unique_ptr sub_graph = p->GetSubGraph(group, graph, model_hash, subgraph_index); + cache.push_back(sub_graph.release()); + number_of_trt_nodes += static_cast(group.first.size()); + subgraph_index++; + } + } + + const size_t number_of_subgraphs = supported_nodes_vector.size(); + if (number_of_trt_nodes == 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider"; + } else if (number_of_trt_nodes == number_of_ort_nodes) { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; + } + + *cnt = cache.size(); + *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt]; + for (size_t i = 0; i < *cnt; i++) { + (*indexed_sub_graph)[i] = cache[i]; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL CompileImpl(OrtEp* this_ptr, const OrtGraph** graphs, const OrtNode** fused_nodes, + size_t count, OrtNodeComputeInfo** node_compute_infos) { + + TensorrtExecutionProvider* ep = static_cast(this_ptr); + + for (size_t graph_idx = 0; graph_idx < count; graph_idx++) { + auto fused_node = fused_nodes[graph_idx]; + + // Gets node's inputs and outputs as pointer array + OrtArrayOfConstObjects* inputs_array = nullptr; + OrtArrayOfConstObjects* outputs_array = nullptr; + DeferOrtRelease release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects); + DeferOrtRelease release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects); + + RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(fused_node, &inputs_array)); + RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(fused_node, &outputs_array)); + + // Gets node's inputs and outputs as OrtValueInfo in gsl::span + gsl::span node_inputs{}; + gsl::span node_outputs{}; + + RETURN_IF_ERROR(GetSpanFromConstPointerArray(inputs_array, node_inputs)); + RETURN_IF_ERROR(GetSpanFromConstPointerArray(outputs_array, node_outputs)); + + // Gets number of node's inputs and outputs + size_t num_node_inputs = 0; + size_t num_node_outputs = 0; + RETURN_IF_ERROR(ep->ort_api.ConstPointerArray_GetSize(inputs_array, &num_node_inputs)); + RETURN_IF_ERROR(ep->ort_api.ConstPointerArray_GetSize(outputs_array, &num_node_outputs)); + + // Builds map from input name 
to its index in input list
+    std::unordered_map<std::string, size_t> input_map;
+    input_map.reserve(num_node_inputs);
+    for (size_t i = 0; i < num_node_inputs; i++) {
+      const std::string& name = node_inputs[i]->GetName();
+      input_map[name] = i;
+    }
+
+    // Builds map from output name to its index in output list
+    std::unordered_map<std::string, size_t> output_map;
+    output_map.reserve(num_node_outputs);
+    for (size_t i = 0; i < num_node_outputs; i++) {
+      const std::string& name = node_outputs[i]->GetName();
+      output_map[name] = i;
+    }
+
+    Status status;
+    if (GraphHasCtxNode(graphs[graph_idx])) {
+      status = ep->CreateNodeComputeInfoFromPrecompiledEngine(graphs[graph_idx],
+                                                              fused_node,
+                                                              input_map,
+                                                              output_map,
+                                                              node_compute_infos);
+    } else {
+      status = ep->CreateNodeComputeInfoFromGraph(graphs[graph_idx], fused_node, input_map, output_map, node_compute_infos);
+    }
+    if (!status.IsOK()) {
+      return ep->ort_api.CreateStatus(ORT_EP_FAIL, status.ErrorMessage().c_str());
+    }
+
+    /*
+    OrtArrayOfConstObjects* nodes_array = nullptr;
+    DeferOrtRelease<OrtArrayOfConstObjects> release_nodes(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects);
+    size_t num_nodes = 0;
+    RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[graph_idx], &nodes_array));
+    RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes));
+    */
+  }
+
+  return nullptr;
+}
+
+static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) {
+  const auto* ep = static_cast<const TensorrtExecutionProvider*>(this_ptr);
+  return ep->name_.c_str();
+}
+
+/// <summary>
+/// Constructor of Plugin TensorRT EP
+/// </summary>
+TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device,
+                                                     const OrtSessionOptions& session_options, const OrtLogger& logger)
+    : ApiPtrs(apis), name_{name}, hardware_device_{device}, session_options_{session_options}, logger_{logger} {
+  // Initialize the execution provider.
+  auto status = ort_api.Logger_LogMessage(&logger_,
+                                          OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
+                                          ("Plugin EP has been created with name " + name_).c_str(),
+                                          ORT_FILE, __LINE__, __FUNCTION__);
+  // ignore status for now
+  (void)status;
+
+  // Implementation of OrtEp interfaces
+  ort_version_supported = ORT_API_VERSION;  // set to the ORT version we were compiled with.
+  GetName = GetNameImpl;
+  GetCapability = GetCapabilityImpl;
+  Compile = CompileImpl;
+  // ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl;
+
+  // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to
+  // the session option configurations with the key prefix "ep.<ep_name>.".
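+  // For illustration (the option name is hypothetical for this plugin): a session option entry such as
+  //   "ep.<ep_name>.engine_cache_enable" = "1"
+  // loses the "ep.<ep_name>." prefix below and reaches TensorrtExecutionProviderInfo::FromProviderOptions()
+  // as the provider option "engine_cache_enable" = "1".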
+ const std::string key_prefix = OrtSessionOptions::GetProviderOptionPrefix(name_.c_str()); + const std::unordered_map& config_options_map = config_options.GetConfigOptionsMap(); + + // Get provider options as key-value pair strings + ProviderOptions provider_options; + for (const auto& [key, value] : config_options_map) { + if (key.rfind(key_prefix, 0) == 0) { + provider_options[key.substr(key_prefix.size())] = value; + } + } + + // Provider options to TensorrtExecutionProviderInfo + info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options); + if (ep_info.size() > 0) info_.has_trt_options = true; + device_id_ = info_.device_id; + api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); + + std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; + + // incase the EP context is dumped the engine cache has to be enabled + auto enable_engine_cache_for_ep_context_model = [this]() { + if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { + engine_cache_enable_ = true; + } + }; + + // Get environment variables + if (info_.has_trt_options) { + max_partition_iterations_ = info_.max_partition_iterations; + min_subgraph_size_ = info_.min_subgraph_size; + max_workspace_size_ = info_.max_workspace_size; + fp16_enable_ = info_.fp16_enable; + int8_enable_ = info_.int8_enable; + if (int8_enable_) { + int8_calibration_cache_name_ = info_.int8_calibration_table_name; + int8_use_native_tensorrt_calibration_table_ = info_.int8_use_native_calibration_table; + } + if (fp16_enable_ || int8_enable_) { // DLA can only be enabled with FP16 or INT8 + dla_enable_ = info_.dla_enable; + dla_core_ = info_.dla_core; + } + dump_subgraphs_ = info_.dump_subgraphs; + engine_cache_enable_ = info_.engine_cache_enable; + weight_stripped_engine_enable_ = info_.weight_stripped_engine_enable; + onnx_model_folder_path_ = info_.onnx_model_folder_path; + timing_cache_enable_ = info_.timing_cache_enable; + force_timing_cache_match_ = info_.force_timing_cache; + detailed_build_log_ = info_.detailed_build_log; + dump_ep_context_model_ = info_.dump_ep_context_model; + ep_context_file_path_ = info_.ep_context_file_path; + ep_context_embed_mode_ = info_.ep_context_embed_mode; + enable_engine_cache_for_ep_context_model(); + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + cache_path_ = info_.engine_cache_path; + cache_prefix_ = info_.engine_cache_prefix; + } + // use a more global cache if given + if (timing_cache_enable_) { + if (!info_.timing_cache_path.empty()) { + global_cache_path_ = info_.timing_cache_path; + } else { + global_cache_path_ = cache_path_; + } + } + engine_decryption_enable_ = info_.engine_decryption_enable; + if (engine_decryption_enable_) { + engine_decryption_lib_path_ = info_.engine_decryption_lib_path; + } + force_sequential_engine_build_ = info_.force_sequential_engine_build; + context_memory_sharing_enable_ = info_.context_memory_sharing_enable; + if (fp16_enable_) { + layer_norm_fp32_fallback_ = info_.layer_norm_fp32_fallback; + } + build_heuristics_enable_ = info_.build_heuristics_enable; + sparsity_enable_ = info_.sparsity_enable; + builder_optimization_level_ = info_.builder_optimization_level; + auxiliary_streams_ = info_.auxiliary_streams; + tactic_sources_ = info_.tactic_sources; + profile_min_shapes = info_.profile_min_shapes; + profile_max_shapes = info_.profile_max_shapes; + profile_opt_shapes = info_.profile_opt_shapes; + cuda_graph_enable_ = 
info_.cuda_graph_enable; + engine_hw_compatible_ = info_.engine_hw_compatible; + } else { + try { + // const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); + // if (!max_partition_iterations_env.empty()) { + // max_partition_iterations_ = std::stoi(max_partition_iterations_env); + // } + + // const std::string min_subgraph_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMinSubgraphSize); + // if (!min_subgraph_size_env.empty()) { + // min_subgraph_size_ = std::stoi(min_subgraph_size_env); + // } + + // const std::string max_workspace_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxWorkspaceSize); + // if (!max_workspace_size_env.empty()) { + // max_workspace_size_ = std::stoull(max_workspace_size_env); + // } + + // const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kFP16Enable); + // if (!fp16_enable_env.empty()) { + // fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); + // } + + // const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8Enable); + // if (!int8_enable_env.empty()) { + // int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + // } + + // if (int8_enable_) { + // const std::string int8_calibration_cache_name_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8CalibrationTableName); + // if (!int8_calibration_cache_name_env.empty()) { + // int8_calibration_cache_name_ = int8_calibration_cache_name_env; + // } + + // const std::string int8_use_native_tensorrt_calibration_table_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8UseNativeTensorrtCalibrationTable); + // if (!int8_use_native_tensorrt_calibration_table_env.empty()) { + // int8_use_native_tensorrt_calibration_table_ = (std::stoi(int8_use_native_tensorrt_calibration_table_env) == 0 ? false : true); + // } + // } + + // if (fp16_enable_ || int8_enable_) { // DLA can only be enabled with FP16 or INT8 + // const std::string dla_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLAEnable); + // if (!dla_enable_env.empty()) { + // dla_enable_ = (std::stoi(dla_enable_env) == 0 ? false : true); + // } + + // if (dla_enable_) { + // const std::string dla_core_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLACore); + // if (!dla_core_env.empty()) { + // dla_core_ = std::stoi(dla_core_env); + // } + // } + // } + + // const std::string dump_subgraphs_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpSubgraphs); + // if (!dump_subgraphs_env.empty()) { + // dump_subgraphs_ = (std::stoi(dump_subgraphs_env) == 0 ? false : true); + // } + + // const std::string engine_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCacheEnable); + // if (!engine_cache_enable_env.empty()) { + // engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? 
false : true); + // } + + // const std::string weight_stripped_engine_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kWeightStrippedEngineEnable); + // if (!weight_stripped_engine_enable_env.empty()) { + // weight_stripped_engine_enable_ = std::stoi(weight_stripped_engine_enable_env) != 0; + // } + + // const std::string onnx_model_folder_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOnnxModelFolderPath); + // if (!onnx_model_folder_path_env.empty()) { + // onnx_model_folder_path_ = onnx_model_folder_path_env; + // } + + // const std::string timing_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCacheEnable); + // if (!timing_cache_enable_env.empty()) { + // timing_cache_enable_ = (std::stoi(timing_cache_enable_env) == 0 ? false : true); + // } + + // const std::string detailed_build_log_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDetailedBuildLog); + // if (!detailed_build_log_env.empty()) { + // detailed_build_log_ = (std::stoi(detailed_build_log_env) == 0 ? false : true); + // } + + // const std::string timing_force_match_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kForceTimingCache); + // if (!timing_force_match_env.empty()) { + // force_timing_cache_match_ = (std::stoi(timing_force_match_env) == 0 ? false : true); + // } + + // const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel); + // if (!dump_ep_context_model_env.empty()) { + // dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true); + // } + + // const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable); + // if (!ep_context_file_path_env.empty()) { + // ep_context_file_path_ = ep_context_file_path_env; + // } + + // const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode); + // if (!ep_context_embed_mode_env.empty()) { + // ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); + // } + // // incase the EP context is dumped the engine cache has to be enabled + // if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { + // engine_cache_enable_ = true; + // } + + // enable_engine_cache_for_ep_context_model(); + + // if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + // const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); + // cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath); + // cache_prefix_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePrefix); + // if (!engine_cache_path.empty() && cache_path_.empty()) { + // cache_path_ = engine_cache_path; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path"; + // } + // } + // if (timing_cache_enable_) { + // std::string timing_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCachePath); + // // use a more global cache if given + // if (!timing_cache_path.empty()) { + // global_cache_path_ = timing_cache_path; + // } else { + // global_cache_path_ = cache_path_; + // } + // } + + // const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable); + // if (!engine_decryption_enable_env.empty()) { + // engine_decryption_enable_ = (std::stoi(engine_decryption_enable_env) == 0 ? 
false : true); + // } + + // if (engine_decryption_enable_) { + // engine_decryption_lib_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath); + // } + + // const std::string force_sequential_engine_build_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kForceSequentialEngineBuild); + // if (!force_sequential_engine_build_env.empty()) { + // force_sequential_engine_build_ = (std::stoi(force_sequential_engine_build_env) == 0 ? false : true); + // } + + // const std::string context_memory_sharing_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kContextMemorySharingEnable); + // if (!context_memory_sharing_enable_env.empty()) { + // context_memory_sharing_enable_ = (std::stoi(context_memory_sharing_enable_env) == 0 ? false : true); + // } + + // const std::string layer_norm_fp32_fallback_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kLayerNormFP32Fallback); + // if (!layer_norm_fp32_fallback_env.empty()) { + // layer_norm_fp32_fallback_ = (std::stoi(layer_norm_fp32_fallback_env) == 0 ? false : true); + // } + + // const std::string build_heuristics_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBuildHeuristics); + // if (!build_heuristics_env.empty()) { + // build_heuristics_enable_ = (std::stoi(build_heuristics_env) == 0 ? false : true); + // } + + // const std::string sparsity_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kSparsityEnable); + // if (!sparsity_enable_env.empty()) { + // sparsity_enable_ = (std::stoi(sparsity_enable_env) == 0 ? false : true); + // } + + // const std::string builder_optimization_level_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBuilderOptimizationLevel); + // if (!builder_optimization_level_env.empty()) { + // builder_optimization_level_ = std::stoi(builder_optimization_level_env); + // } + + // const std::string auxiliary_streams_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kAuxiliaryStreams); + // if (!auxiliary_streams_env.empty()) { + // auxiliary_streams_ = std::stoi(auxiliary_streams_env); + // } + + // const std::string tactic_sources_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTacticSources); + // if (!tactic_sources_env.empty()) { + // tactic_sources_ = tactic_sources_env; + // } + + // profile_min_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMinShapes); + // profile_max_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMaxShapes); + // profile_opt_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesOptShapes); + + // const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCudaGraphEnable); + // if (!cuda_graph_enable_env.empty()) { + // cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true); + // } + + } catch (const std::invalid_argument& ex) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what(); + } catch (const std::out_of_range& ex) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Out Of Range Error (from environment variables): " << ex.what(); + } catch (...) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Unknown Exception (from environment variables)"; + } + } + + // Validate setting + if (max_partition_iterations_ <= 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_partition_iterations must be a positive integer value. 
Set it to 1000";
+    max_partition_iterations_ = 1000;
+  }
+  if (min_subgraph_size_ <= 0) {
+    // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_min_subgraph_size must be a positive integer value. Set it to 1";
+    min_subgraph_size_ = 1;
+  }
+  if (max_workspace_size_ <= 0) {
+    // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_workspace_size must be a positive integer value. Set it to 1073741824 (1GB)";
+    max_workspace_size_ = 1 << 30;
+  }
+  if (dla_core_ < 0) {
+    // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_dla_core must be a non-negative integer value. Set it to 0";
+    dla_core_ = 0;
+  }
+
+  // If ep_context_file_path_ is provided as a directory, create it if it does not exist
+  if (dump_ep_context_model_ && !ep_context_file_path_.empty() && std::filesystem::path(ep_context_file_path_).extension().empty() && !std::filesystem::is_directory(ep_context_file_path_)) {
+    if (!std::filesystem::create_directory(ep_context_file_path_)) {
+      throw std::runtime_error("Failed to create directory " + ep_context_file_path_);
+    }
+  }
+
+  // If dump_ep_context_model_ is enabled, TRT EP forces cache_path_ to be relative to ep_context_file_path_.
+  // For example,
+  //   - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir"
+  //   - original cache path = "" -> new cache path = "./context_model_dir"
+  // The new cache path will be saved as the "ep_cache_context" node attribute of the EP context node.
+  // For security reasons, the engine cache must be saved inside the context model directory.
+  if (dump_ep_context_model_ && engine_cache_enable_) {
+    if (IsAbsolutePath(cache_path_)) {
+      // LOGS_DEFAULT(ERROR) << "In the case of dumping the context model and for security purposes, trt_engine_cache_path should be set to a relative path, but it is an absolute path: " << cache_path_;
+    }
+    if (IsRelativePathToParentPath(cache_path_)) {
+      // LOGS_DEFAULT(ERROR) << "In the case of dumping the context model and for security purposes, trt_engine_cache_path contains '..'; it is not allowed to point outside the directory.";
+    }
+
+    // Engine cache path relative to the context model directory.
+    // It's used when dumping the "ep_cache_context" node attribute.
+    engine_cache_relative_path_to_context_model_dir = cache_path_;
+
+    // Make cache_path_ relative to ep_context_file_path_
+    cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string();
+  }
+
+  // Hardware compatibility: pre-check on environment
+  if (engine_cache_enable_ && engine_hw_compatible_) {
+#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
+    if (std::stoi(compute_capability_) < 80) {
+      // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled as GPU arch < 80. ";
+      engine_hw_compatible_ = false;
+    } else if (std::stoi(compute_capability_) == 87) {
+      // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled on Jetson Orin. ";
+      engine_hw_compatible_ = false;
+    }
+#else
+    // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled as TRT < 8.6. 
"; + engine_hw_compatible_ = false; +#endif + } + + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + if (!cache_path_.empty() && !fs::is_directory(cache_path_)) { + if (!fs::create_directory(cache_path_)) { + throw std::runtime_error("Failed to create directory " + cache_path_); + } + } + if (!global_cache_path_.empty() && !fs::is_directory(global_cache_path_)) { + if (!fs::create_directory(global_cache_path_)) { + throw std::runtime_error("Failed to create directory " + global_cache_path_); + } + } + } + + if (engine_decryption_enable_) { + LIBTYPE handle = OPENLIB(engine_decryption_lib_path_.c_str()); + if (handle == nullptr) { + // TODO(yang) + // ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + // "TensorRT EP could not open shared library from " + engine_decryption_lib_path_)); + } + engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); + engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); + if (engine_decryption_ == nullptr) { + // TODO(yang) + // ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + // "TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); + } + } + + if (int8_enable_) { + int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + } + + /* + * Parse explicit min/max/opt profile shapes from provider options. + * + * The format of min/max/opt profile shapes is defined as below: + * "input1:dim1xdim2...,input2:dim1xdim2...,...,input1:dim3xdim4...,input2:dim3xdim4...,..." + * + * (Note: if multiple shapes with same input name are specified, TRT EP will consider them as multiple profiles. + * Please refer to ParserProfileShapes() for more details) + * + */ + bool status = true; + // if (status) { + // status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_); + // if (!status) { + // profile_min_shapes_.clear(); + // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // } + // } + + // if (status) { + // status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_); + // if (!status) { + // profile_max_shapes_.clear(); + // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // } + // } + + // if (status) { + // status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_); + // if (!status) { + // profile_opt_shapes_.clear(); + // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // } + // } + + // if (status) { + // status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + // if (!status) { + // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Profile shapes validation failed. 
Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile."; + // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you."; + // profile_min_shapes_.clear(); + // profile_max_shapes_.clear(); + // profile_opt_shapes_.clear(); + // } + // } + + // cuda graph: + // cudaStreamSynchronize() is not allowed in cuda graph capture. + // + // external stream: + // If user provides "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently. + // So, no need to synchronize different streams after enqueueV3. + if (cuda_graph_enable_ || external_stream_) { + sync_stream_after_enqueue_ = false; + } + + { + auto lock = GetApiLock(); + runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); + } + + // EP Context setting + if (dump_ep_context_model_) { + extra_attr_keys_.push_back(k_ep_ctx_hardware_architecture.c_str()); + extra_attr_keys_.push_back(k_ep_ctx_onnx_model_filename.c_str()); + + if (engine_cache_enable_ && engine_hw_compatible_) { + extra_attr_values_.push_back(k_cc_hw_compatible.c_str()); + } else { + extra_attr_values_.push_back(compute_capability_.c_str()); + } + extra_attr_values_.push_back(model_path_); + } +} + +TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { + OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* { + ProviderOptions options; + for (size_t i = 0; i < option_size; i++) options[ep_option_keys[i]] = ep_option_values[i]; + std::unique_ptr ret = std::make_unique(tensorrtEp.c_str(), std::move(options)); + return ret.release(); + }; +} + +nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const { + if (!builder_) { + { + auto lock = GetApiLock(); + builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + } + } + return builder_.get(); +} + +OrtStatusPtr TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, + bool detailed_build_log) { +#if NV_TENSORRT_MAJOR >= 10 + std::filesystem::path onnx_model_path{onnx_model_folder_path}; + onnx_model_path.append(onnx_model_filename); + if (path_check && IsAbsolutePath(onnx_model_path.string())) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("For security purpose, the ONNX model path should be set with " + "a relative path, but it is an absolute path: " + + onnx_model_path.string()) + .c_str()); + } + if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + "The ONNX model path has '..'. 
For security purpose, it's not " + "allowed to point outside the directory."); + } + + if (!std::filesystem::exists(onnx_model_path)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("The ONNX model " + onnx_model_path.string() + + " does not exist.") + .c_str()); + } + + // weight-stripped engine refit logic + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log); + auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); + auto parser_refitter = std::unique_ptr( + nvonnxparser::createParserRefitter(*refitter, trt_logger)); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()).c_str()); + } + if (refitter->refitCudaEngine()) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine."; + } else { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()).c_str()); + } + + // serialize the refitted engine to disk + if (serialize_refitted_engine) { + std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); + nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); + std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); + engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialize the refitted engine to " << refitted_engine_cache; + } + return nullptr; +#else + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards."); +#endif +} + + +OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, + const OrtGraph** graphs, + const OrtNode** fused_nodes, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_infos) { + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= fp16_enable_ || int8_enable_ ? 
0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#endif + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + void* buf_data = nullptr; + size_t buf_size = 0; + graph_api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data, &buf_size); + trt_parser->parse(buf_data, buf_size, model_path_); + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); + graph_api_->OrtFreeMem(buf_data); + + // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow + if (fp16_enable_ && layer_norm_fp32_fallback_) { + for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { + auto layer = trt_network->getLayer(idx); + auto next_layer = trt_network->getLayer(idx + 1); + if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + layer->setPrecision(nvinfer1::DataType::kFLOAT); + next_layer->setPrecision(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + } + } + } + + int num_inputs = trt_network->getNbInputs(); + int num_outputs = trt_network->getNbOutputs(); + std::unordered_map input_indexes(num_inputs); + std::unordered_map output_indexes(num_outputs); + std::unordered_map output_types(num_outputs); + + /* + * Initialize shape range for each dynamic shape input tensor: + * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles during EP compile time. + * It won't make adjustment for profile values during EP compute time. + * + * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, INT_MIN, INT_MIN]. + * Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, max_input_value] based on input tensor value. + * + * + * Once the TRT profiles are created: + * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles will be applied to TRT builder config + * and the engine will be built at EP compile time. + * + * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create default shape as described above, + * and all the profiles won't be applied and engine won't be built until EP compute time. + */ + bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false. + bool has_explicit_profile = false; + bool apply_explicit_profile = false; + int num_profiles = 0; + std::vector trt_profiles; + + // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape dimension(s) and min/max/opt values for dynamic shape input tensor. + // + // (1) Single profile case: + // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b + // has one dynamic shape dimension: dim_1. 
The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape, max_shape, opt_shape]], + // dim_2: [[min_shape, max_shape, opt_shape]] + // }, + // tensor_b: { + // dim_1: [[min_shape, max_shape, opt_shape]] + // } + // } + // + // (2) Multiple profiles case: + // For example, assume tensor_a has one dynamic shap dimension: dim 0, and tensor_b has one dynamic shape dimension: dim_1, + // and both of the tensors have two profiles. The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + // }, + // tensor_b: { + // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + // } + // } + ShapeRangesMap input_explicit_shape_ranges; + ShapeRangesMap input_implicit_shape_ranges; + + if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { + has_explicit_profile = true; + num_profiles = GetNumProfiles(profile_min_shapes_); + for (int i = 0; i < num_profiles; i++) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); + } + } + + // Iterate all input tensors to check dynamic shape + for (unsigned int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + // Apply explicit optimization profiles provided by user + if (has_explicit_profile) { + apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + } + + // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on input tensor values at EP compute time + if (!apply_explicit_profile) { + if (input->isShapeTensor()) { + // Shape tensor + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][0] = profile_vector; + has_dynamic_shape = true; + } else { + // Execution tensor + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dims.d[j] == -1) { + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][j] = profile_vector; + has_dynamic_shape = true; + } + } + } + apply_explicit_profile = false; + } + } + + // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user + if (has_explicit_profile) { + // TRT EP has a constraint here. + // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify profiles through provider options. 
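+    // Illustrative example only (input names and shapes are hypothetical): for a model with two
+    // dynamic-shape inputs "x" and "mask", all three options must cover both inputs, e.g.
+    //   trt_profile_min_shapes = "x:1x3x224x224,mask:1x16"
+    //   trt_profile_opt_shapes = "x:4x3x224x224,mask:4x16"
+    //   trt_profile_max_shapes = "x:8x3x224x224,mask:8x16"
+    // If any dynamic-shape input (e.g. "mask") is left out of these options, has_dynamic_shape stays
+    // true and the error branch below is taken.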
+ if (has_dynamic_shape) { + std::ostringstream msg; + msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options.\n"; + msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n"; + msg << "Following input(s) has no associated shape profiles provided: "; + auto begin = input_implicit_shape_ranges.begin(); + auto end = input_implicit_shape_ranges.end(); + auto it = begin; + if (it != end) { + msg << it->first; + ++it; + } + for (; it != end; ++it) { + msg << "," << it->first; + } + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, msg.str().c_str()); + } else { + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } + } + } + // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default. + // It will later set proper min/max/opt shape values duing EP compute time. + else if (!has_explicit_profile && has_dynamic_shape) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); + } + + // Check platform availability for low precision + if (fp16_enable_) { + if (!trt_builder->platformHasFastFp16()) { + fp16_enable_ = false; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16"; + } + } + + if (int8_enable_) { + if (!trt_builder->platformHasFastInt8()) { + int8_enable_ = false; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; + } + } + + const char* node_name = nullptr; + graph_api_->OrtNode_GetName(fused_node, &node_name); + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if (int8_enable_ && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + } + } + dynamic_range_map_[node_name] = dynamic_range_map; + + // Set precision flags + std::string trt_node_name_with_precision(node_name); + if (fp16_enable_ && int8_enable_) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + trt_node_name_with_precision += "_fp16_int8"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; + } else if (fp16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + trt_node_name_with_precision += "_fp16"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } else if (int8_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + trt_node_name_with_precision += "_int8"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } + + // Set DLA + if (fp16_enable_ || int8_enable_) { + if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 + int number_of_dla_core = trt_builder->getNbDLACores(); + if (number_of_dla_core == 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + dla_enable_ = false; + } else { + if (dla_core_ >= number_of_dla_core) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to 
use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead."; + dla_core_ = 0; + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(dla_core_); + trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); + } + } + } + trt_node_name_with_precision_[node_name] = trt_node_name_with_precision; + + // enable sparse weights + if (sparsity_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + if (build_heuristics_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled." + // << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder optimization level as 2 to enable builder heuristics."; + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 + if (build_heuristics_enable_) { + if (builder_optimization_level_ == 2) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."; + } else { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set builder optimization level as 2 to enable builder heuristics."; + } + } +#endif + +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // switch optimizaion level + if (builder_optimization_level_ != 3) { + trt_config->setBuilderOptimizationLevel(builder_optimization_level_); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } + + // limit auxiliary streams + if (auxiliary_streams_ >= 0) { + trt_config->setMaxAuxStreams(auxiliary_streams_); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; + } +#else + if (builder_optimization_level_ != 3) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (auxiliary_streams_ >= 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } +#endif + + if (weight_stripped_engine_enable_) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; +#else + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; +#endif + } + + // limit used tactic sources + if (!tactic_sources_.empty()) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= GetTacticSourceFromString(tactic_sources_); + trt_config->setTacticSources(tactics); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; + } + + // Build TRT engine (if needed) and load TRT engine if: + // 
(1) Graph has no dynamic shape input
+  //   (2) All the dynamic shape inputs have associated explicit profiles specified by user
+  //
+  // Otherwise engine will be handled at inference time.
+  std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
+  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+
+  std::string cache_path = "";
+  std::string cache_suffix = "";
+  // Customize cache prefix if assigned
+  if (!cache_prefix_.empty()) {
+    // Generate cache suffix in case user would like to customize cache prefix
+    cache_suffix = "_" + GetCacheSuffix(node_name, trt_node_name_with_precision);
+    cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix;
+  } else {
+    cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
+  }
+  cache_suffix_[node_name] = cache_suffix;
+
+  std::string cache_hw_compat = "_sm" + compute_capability_;
+  // Enable hardware compatibility mode if assigned
+  if (engine_cache_enable_ && engine_hw_compatible_) {
+    trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS);
+    cache_hw_compat = "_sm80+";
+    // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache.";
+  }
+
+  // Name the engine cache based on GPU compute capability and reduce the chance of loading an incompatible cache
+  // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capability
+  const std::string cache_path_prefix = cache_path + cache_hw_compat;
+  std::string engine_cache_path = cache_path_prefix + ".engine";
+  const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
+  const std::string profile_cache_path = cache_path_prefix + ".profile";
+
+  // If weight-stripped engine is enabled and refitted engine cache is not present,
+  // TRT EP will use the engine cache with ".stripped.engine" appended to the end.
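+  // Illustrative example only (node and path names are hypothetical): for a fused node whose
+  // trt_node_name_with_precision is "TRTKernel_graph_main_0_fp16" on an sm_86 GPU, cache_path_prefix
+  // becomes "<cache_path_>/TRTKernel_graph_main_0_fp16_sm86", giving
+  //   "..._sm86.engine", "..._sm86.engine.encrypted" and "..._sm86.profile".
+  // With hardware compatibility enabled the suffix is "_sm80+" instead, and a weight-stripped build
+  // falls back to "..._sm86.stripped.engine" when the refitted "..._sm86.engine" is not present yet.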
+ const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit_ = true; + } + + auto create_ep_context_model = [this](const OrtGraphViewer* graph_body_viewer, + std::string& engine_cache_path, + std::string& engine_cache_relative_path_to_context_model_dir, + const char* ep_context_node_name, + char* serialized_engine, + size_t serialized_engine_size) { + // if ep context model name is not given, create a model name based on original model name + if (ctx_model_path_.empty()) { + ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); + } + + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + } + + graph_api_->OrtGraph_CreateOrUpdateEpCtxGraph(graph_body_viewer, + ep_context_node_name, + 1, // main_context + ep_context_embed_mode_, + ep_cache_context_attr_.c_str(), + serialized_engine, + serialized_engine_size, + extra_attr_keys_.data(), + extra_attr_values_.data(), + extra_attr_keys_.size(), + &ep_ctx_graph_); + }; + + if (!has_dynamic_shape) { + std::string timing_cache_path = ""; + bool engine_update = false; + if (timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); + } + { + // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading. + auto lock = GetApiLock(); + + // If explicit profile flag is on and engine cache enable flag is on, + // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine. 
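+      // Illustrative example only (shapes are hypothetical): if the cached ".profile" file recorded a
+      // max shape of 4x3x224x224 for input "x" but trt_profile_max_shapes now requests 8x3x224x224,
+      // CompareProfiles() reports a mismatch, engine_update becomes true and the engine is rebuilt;
+      // otherwise the serialized engine cache below is reused as-is.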
+ if (has_explicit_profile && engine_cache_enable_) { + engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + if (engine_update) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; + } else { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; + } + } + + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + if (trt_engine == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path).c_str()); + } + + } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { + // Decrypt engine + size_t engine_size = 0; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + if (trt_engine == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str()); + } + } else { + // Set INT8 per tensor dynamic range + if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (!SetDynamicRange(*trt_network, dynamic_range_map)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not set INT8 dynamic range for fused node: " + std::string(node_name)).c_str()); + } + } + + // Load timing cache from file. 
Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (timing_cache_enable_) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str()); + } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + } + } + + // Build engine + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); + } + std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; + if (serialized_engine == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP failed to create engine from network for fused node: " + std::string(node_name)).c_str()); + } + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); + if (trt_engine == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP failed to deserialize engine for fused node: " + std::string(node_name)).c_str()); + } + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + // LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + } + if (engine_cache_enable_) { + // Serialize engine profile if it has explicit profiles + if (has_explicit_profile) { + SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + } + + if (engine_decryption_enable_) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. + if (engine_encryption_ != nullptr) { + if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP call to engine encryption library failed"); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; + } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; + } + } + // serialize and save timing cache + if (timing_cache_enable_) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str()); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + } + } + + // create and dump ep context model + if (dump_ep_context_model_) { + create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); + graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); + } + } + } + + if (weight_stripped_engine_refit_) { + auto status = RefitEngine(model_path_, + onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + trt_engine.get(), + true /* serialize refitted engine to disk */, + detailed_build_log_); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } +#if NV_TENSORRT_MAJOR < 10 + trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not build execution context for fused node: " + std::string(node_name)).c_str()); + } + } + + // Create input to index map + for (int i = 0; i < num_inputs; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + const auto& iter = input_map.find(input_name); + if (iter != input_map.end()) { + input_indexes[input_name] = iter->second; + } + } + + // Create output to index and type maps + for (int i = 0; i < num_outputs; ++i) { + const std::string& output_name = trt_network->getOutput(i)->getName(); + const auto& iter = output_map.find(output_name); + if (iter != output_map.end()) { + output_indexes[output_name] = iter->second; + } + int32_t output_type = 0; + graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i, &output_type); + output_types[output_name] = output_type; + } + + // Save TRT engine, other TRT objects and input/output info to map + 
parsers_.emplace(node_name, std::move(trt_parser)); + engines_.emplace(node_name, std::move(trt_engine)); + contexts_.emplace(node_name, std::move(trt_context)); + networks_.emplace(node_name, std::move(trt_network)); + input_info_[node_name].push_back(input_indexes); + output_info_[node_name].push_back(output_indexes); + output_info_[node_name].push_back(output_types); + input_shape_ranges_[node_name] = input_implicit_shape_ranges; + profiles_.emplace(node_name, std::move(trt_profiles)); + + // Create ep context model if the model has dynamic shape, + // dump the model is embed mode is 0, otherwise update and dump the model at runtime. + if (has_dynamic_shape && dump_ep_context_model_) { + create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, nullptr, 0); + if (ep_context_embed_mode_ == 0) { + graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); + graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); + } + } + + // Create function state + node_compute_funcs->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + std::unique_ptr p = std::make_unique(); + + // translate tactic sources string to nvinfer1::TacticSources + nvinfer1::TacticSources tactics = 0; + if (!this_->tactic_sources_.empty()) { + tactics = GetTacticSourceFromString(this_->tactic_sources_); + } + *p = {context->AllocateFunc, context->DestroyFunc, context->allocator_handle, context->node_name, this_->builder_.get(), + &(this_->parsers_[context->node_name]), &(this_->engines_[context->node_name]), &(this_->contexts_[context->node_name]), + &(this_->networks_[context->node_name]), this_->input_info_[context->node_name], this_->output_info_[context->node_name], + this_->input_shape_ranges_[context->node_name], &this_->tensorrt_mu_, this_->fp16_enable_, this_->int8_enable_, this_->int8_calibration_cache_available_, + this_->dla_enable_, this_->dla_core_, &(this_->max_workspace_size_), this_->trt_node_name_with_precision_[context->node_name], + this_->engine_cache_enable_, this_->cache_path_, this_->runtime_.get(), this_->profiles_[context->node_name], + this_->context_memory_sharing_enable_, &(this_->max_ctx_mem_size_), this_->dynamic_range_map_[context->node_name], this_->engine_decryption_enable_, + this_->engine_decryption_, this_->engine_encryption_, this_->timing_cache_enable_, this_->global_cache_path_, this_->force_timing_cache_match_, + this_->detailed_build_log_, this_->build_heuristics_enable_, this_->sparsity_enable_, this_->builder_optimization_level_, + this_->auxiliary_streams_, !(this_->tactic_sources_.empty()), tactics, this_->cuda_graph_enable_, this_->cache_prefix_, this_->cache_suffix_[context->node_name], this_->engine_hw_compatible_}; + *state = p.release(); + return 0; + }; + + // Release function state + node_compute_funcs->DestroyFunctionStateFunc = [](void* state) { + delete static_cast(state); + }; + + // Create compute function + node_compute_funcs->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { + Ort::KernelContext ctx(context); + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + TensorrtFuncState* trt_state = reinterpret_cast(state); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, + // 
save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto fused_node_name = trt_state->fused_node_name; + // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. + // The info is used for both shape tensor and execution tensor: + // tensor name->(dimension->[min, max, opt]) + auto& shape_ranges = trt_state->input_shape_ranges; + std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run + std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input + auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name]; + auto trt_builder = trt_state->builder; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto trt_profiles = trt_state->profiles; + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + int num_inputs = static_cast(input_indexes.size()); + int num_outputs = static_cast(output_indexes.size()); + bool engine_update = false; + bool context_update = false; + std::unordered_set input_names; + + OrtMemoryInfo* mem_info = nullptr; + api_->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); + if (this_->alloc_ == nullptr) { + Ort::ThrowOnError(api_->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); + } + OrtAllocator* alloc = this_->alloc_; + + void* cuda_stream; + Ort::ThrowOnError(api_->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity + // Prepare cache name + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!this_->cache_prefix_.empty()) { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; + } else { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); + } + + // Enable hardware compatility mode if assigned + std::string cache_hw_compat = "_sm" + this_->compute_capability_; + if (this_->engine_cache_enable_ && this_->engine_hw_compatible_) { + cache_hw_compat = "_sm80+"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + } + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity + const std::string cache_path_prefix = cache_path + cache_hw_compat; + std::string engine_cache_path = cache_path_prefix + ".engine"; + const std::string 
encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path_prefix + ".profile"; + std::string timing_cache_path = ""; + if (this_->timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(this_->global_cache_path_, this_->compute_capability_); + } + + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. + const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (this_->weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + this_->weight_stripped_engine_refit_ = true; + } + + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && !trt_state->engine_decryption_enable && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + + } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str()); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and 
DeSerialized " + encrypted_engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } + } + + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); + + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. + // TRT EP will help determine the min/max/opt profile values based on current input tensor value. + if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); + } + } + } + + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. + trt_state->context->reset(); + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr)); + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } + + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); + } + } + + // Set precision + if (trt_state->fp16_enable && trt_state->int8_enable) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + } else if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + } else if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + } + + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } + + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + 
trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } + + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + } +#else + if (trt_state->builder_optimization_level != 3) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (trt_state->auxiliary_streams >= 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } +#endif + if (this_->weight_stripped_engine_enable_) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; +#else + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; +#endif + } + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str()); + } + trt_config->setTimingCache(*timing_cache, this_->force_timing_cache_match_); + if (this_->detailed_build_log_) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + } + } + + // Enable hardware compatility mode if assigned + if (trt_state->engine_hw_compatible) { + trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; + } + + // Build engine + std::unique_ptr serialized_engine; + { + auto lock = this_->GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (this_->detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); + } + serialized_engine = std::unique_ptr( + trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); + if (!serialized_engine) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create engine from network."); + } + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); + if (!(*(trt_state->engine))) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to deserialize engine."); + } + if (this_->detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + // LOGS_DEFAULT(INFO) << 
"TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + } + } + if (!(*(trt_state->engine))) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + + // Serialize engine + if (trt_state->engine_decryption_enable) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. + if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine encryption function encrypt"); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } + } + + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str()); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (this_->detailed_build_log_) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + } + } + + // dump ep context model + if (this_->dump_ep_context_model_ && this_->ep_context_embed_mode_) { + graph_api_->OrtGraph_CreateOrUpdateEpCtxGraph(nullptr, + fused_node_name.c_str(), + 1, // main_context + this_->ep_context_embed_mode_, + this_->ep_cache_context_attr_.c_str(), + reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + this_->extra_attr_keys_.data(), + this_->extra_attr_values_.data(), + this_->extra_attr_keys_.size(), + &this_->ep_ctx_graph_); + graph_api_->OrtGraph_DumpOnnxModel(this_->ep_ctx_graph_, this_->ctx_model_path_.c_str()); + graph_api_->OrtGraph_ReleaseGraph(this_->ep_ctx_graph_); + } + context_update = true; + + if (this_->weight_stripped_engine_refit_) { + auto status = RefitEngine(this_->model_path_, + this_->onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + trt_engine, + true /* serialize refitted engine to disk */, + this_->detailed_build_log_); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + } + + if (context_update) { + if (trt_state->context_memory_sharing_enable) { +#if NV_TENSORRT_MAJOR < 10 + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); +#else + *(trt_state->context) = std::unique_ptr( + 
trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext()); + } + if (!(*(trt_state->context))) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create context."); + } + trt_context = trt_state->context->get(); + } + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, + dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + } + trt_context->setDeviceMemory(MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + } + + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
+ // if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + // cuda_graph_.SetStream(stream); + // CaptureBegin(0); + // } + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP execution context enqueue failed."); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, + * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. + * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, + * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. + * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. + * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream + * is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. + * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (this_->sync_stream_after_enqueue_) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + } + + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } else { + auto& output_tensor = output_tensors[i]; +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } +#endif + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } + } + } + + // // End CUDA graph capture. 
+ // // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture + // // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. + // // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. + // if (cuda_graph_enable_ && !IsGraphCaptured(0)) { + // if (IsGraphCaptureAllowed()) { + // CaptureEnd(0); + // // CUDA work issued to a capturing stream doesn’t actually run on the GPU, + // // so run the captured graph here to actually execute the work. + // ORT_RETURN_IF_ERROR(ReplayGraph(0)); + // } else { + // IncrementRegularRunCountBeforeGraphCapture(); + // } + // } + // std::cout << "end of ComputeFunc in TRTEp's CreateNodeComputeInfoFromGraph()\n"; + return nullptr; + }; + + return nullptr; +} + +OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo* node_compute_funcs) { + std::unique_ptr trt_engine; + std::unique_ptr trt_context; + std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index + std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index + std::unordered_map output_types; // TRT engine output name -> ORT output tensor type + + // Get engine binary data and deserialize it + auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, + runtime_.get(), + model_path_, + compute_capability_, + weight_stripped_engine_enable_, + onnx_model_folder_path_, + detailed_build_log_); + auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + + // Build context + // + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } +#if NV_TENSORRT_MAJOR < 10 + trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + + const char* fused_node_name = nullptr; + graph_api_->OrtNode_GetName(fused_node, &fused_node_name); + if (!trt_context) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("TensorRT EP could not build execution context for fused node: " + std::string(fused_node_name)).c_str()); + } + + // Create input/output to index maps + for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + const auto& iter = input_map.find(name); + if (iter != input_map.end()) { + input_indexes[name] = iter->second; + } + } else { + const 
auto& iter = output_map.find(name); + if (iter != output_map.end()) { + output_indexes[name] = iter->second; + } + } + } + + // Create output to type map + size_t graph_output_size = 0; + graph_api_->OrtGraph_GetOutputSize(graph_body_viewer, &graph_output_size); + for (size_t i = 0; i < graph_output_size; i++) { + char const* output_name = nullptr; + graph_api_->OrtGraph_GetIthOutputName(graph_body_viewer, i, &output_name); + int32_t output_type = 0; + graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i, &output_type); + output_types[output_name] = output_type; + } + + // Save TRT engine, TRT context and input/output info to map + engines_.emplace(fused_node_name, std::move(trt_engine)); + contexts_.emplace(fused_node_name, std::move(trt_context)); + input_info_[fused_node_name].push_back(input_indexes); + output_info_[fused_node_name].push_back(output_indexes); + output_info_[fused_node_name].push_back(output_types); + + // Create function state + node_compute_funcs->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + std::unique_ptr p = std::make_unique(); + *p = {context->AllocateFunc, + context->DestroyFunc, + context->allocator_handle, + context->node_name, + &(this_->engines_[context->node_name]), + &(this_->contexts_[context->node_name]), + this_->input_info_[context->node_name], + this_->output_info_[context->node_name], + this_->context_memory_sharing_enable_, + &this_->max_ctx_mem_size_, + &this_->tensorrt_mu_}; + *state = p.release(); + return 0; + }; + + // Release function state + node_compute_funcs->DestroyFunctionStateFunc = [](void* state) { + delete reinterpret_cast(state); + }; + + // Create compute function + node_compute_funcs->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + TensorrtShortFuncState* trt_state = reinterpret_cast(state); + Ort::KernelContext ctx(context); + + // The whole compute_function should be considered the critical section. 
+ // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto fused_node_name = trt_state->fused_node_name; + std::cout << fused_node_name << std::endl; + auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name]; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + int num_outputs = static_cast(output_indexes.size()); + std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run + std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input + + OrtMemoryInfo* mem_info = nullptr; + api_->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); + if (this_->alloc_ == nullptr) { + Ort::ThrowOnError(api_->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); + } + OrtAllocator* alloc = this_->alloc_; + + void* cuda_stream; + Ort::ThrowOnError(api_->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + + OrtStatusPtr status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, 
output_tensors, output_dim_sizes, + dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + } + trt_context->setDeviceMemory(MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + } + + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. + if (this_->cuda_graph_enable_ && this_->IsGraphCaptureAllowed() && !this_->IsGraphCaptured(0)) { + // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + // cuda_graph_.SetStream(stream); + // CaptureBegin(0); + } + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + return api_->CreateStatus(OrtErrorCode::ORT_FAIL, "TensorRT EP execution context enqueue failed."); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, + * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. + * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, + * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. + * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. + * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream + * is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. + * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (this_->sync_stream_after_enqueue_) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + } + + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. 
(It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + OrtStatusPtr status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_FAIL, api_->GetErrorMessage(status)); + } + } else { + auto& output_tensor = output_tensors[i]; +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } +#endif + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } + } + } + + // End CUDA graph capture. + // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture + // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. + // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. + if (this_->cuda_graph_enable_ && !this_->IsGraphCaptured(0)) { + // if (IsGraphCaptureAllowed()) { + // CaptureEnd(0); + // // CUDA work issued to a capturing stream doesn’t actually run on the GPU, + // // so run the captured graph here to actually execute the work. 
+ // ORT_RETURN_IF_ERROR(ReplayGraph(0)); + // } else { + // IncrementRegularRunCountBeforeGraphCapture(); + // } + } + + return nullptr; + }; + + return nullptr; +} + +SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, + const OrtGraph* graph, bool* early_termination) const { + // Return if iterations are exceeding predefined number + SubGraphCollection_t nodes_list_output; + if (iterations > max_iterations) { + *early_termination = true; + return nodes_list_output; + } + + iterations++; + for (const auto& group : nodes_vector_input) { + // Construct subgraph + if (!group.first.empty()) { + if (group.second) { + nodes_list_output.push_back(group); + } else { + //const OrtGraphViewer* sub_graph_viewer = nullptr; + //graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); + + void* buf_data = nullptr; + size_t buf_size = 0; + graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data, &buf_size); + + // Get supported node list recursively + SubGraphCollection_t parser_nodes_list; + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#endif + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + + auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); + graph_api_->OrtFreeMem(buf_data); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + SubGraphCollection_t next_nodes_list; + const size_t* subgraph_node_index = nullptr; + size_t subgraph_node_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index, &subgraph_node_count); + next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); + for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; + } + nodes_list_output.push_back(next_nodes_list[i]); + } + graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); + } + } + } + return nodes_list_output; +} + +} // namespace onnxruntime + +#ifdef __cplusplus +extern "C" { +#endif +OrtExecutionProviderFactory* RegisterCustomEp() { + std::unique_ptr ret = std::make_unique(); + return ret.release(); +} +#ifdef __cplusplus +} +#endif diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h new file mode 100644 index 00000000..0e1c24c3 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -0,0 +1,406 @@ +#pragma once + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include "utils/provider_options.h" +#include "tensorrt_execution_provider_info.h" +#include "nv_includes.h" + +#include +#include +#include +#include + 
+#ifdef _WIN32 +#define EXPORT_API __declspec(dllexport) +#else +#define EXPORT_API +#endif + +namespace onnxruntime { + +namespace tensorrt_env_vars { +static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_ITERATIONS"; +static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; +static const std::string kMaxWorkspaceSize = "ORT_TENSORRT_MAX_WORKSPACE_SIZE"; +static const std::string kFP16Enable = "ORT_TENSORRT_FP16_ENABLE"; +static const std::string kINT8Enable = "ORT_TENSORRT_INT8_ENABLE"; +static const std::string kINT8CalibrationTableName = "ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"; +static const std::string kINT8UseNativeTensorrtCalibrationTable = "ORT_TENSORRT_INT8_USE_NATIVE_CALIBRATION_TABLE"; +static const std::string kDLAEnable = "ORT_TENSORRT_DLA_ENABLE"; +static const std::string kDLACore = "ORT_TENSORRT_DLA_CORE"; +static const std::string kDumpSubgraphs = "ORT_TENSORRT_DUMP_SUBGRAPHS"; +static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"; +static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH"; +static const std::string kWeightStrippedEngineEnable = "ORT_TENSORRT_WEIGHT_STRIPPED_ENGINE_ENABLE"; +static const std::string kOnnxModelFolderPath = "ORT_TENSORRT_ONNX_MODEL_FOLDER_PATH"; +// As a timing cache can be used across multiple ONNX files it makes sense to have a separate cache path +static const std::string kTimingCachePath = "ORT_TENSORRT_GLOBAL_CACHE_PATH"; +static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE"; +static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH"; +static const std::string kForceSequentialEngineBuild = "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD"; +static const std::string kContextMemorySharingEnable = "ORT_TENSORRT_CONTEXT_MEMORY_SHARING_ENABLE"; +static const std::string kLayerNormFP32Fallback = "ORT_TENSORRT_LAYER_NORM_FP32_FALLBACK"; +static const std::string kTimingCacheEnable = "ORT_TENSORRT_TIMING_CACHE_ENABLE"; +static const std::string kForceTimingCache = "ORT_TENSORRT_FORCE_TIMING_CACHE_ENABLE"; +static const std::string kDetailedBuildLog = "ORT_TENSORRT_DETAILED_BUILD_LOG_ENABLE"; +static const std::string kBuildHeuristics = "ORT_TENSORRT_BUILD_HEURISTICS_ENABLE"; +static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE"; +static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL"; +static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS"; +static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES"; +static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS"; +static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES"; +static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES"; +static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES"; +static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE"; +static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL"; +static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE"; +static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE"; +static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX"; +// Old env variable for backward compatibility +static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; +} // namespace tensorrt_env_vars + +using HashValue 
= uint64_t; +using AllocateFunc = void* (*)(void*, size_t, size_t); +using DestroyFunc = void (*)(void*, void*); + +class TensorrtLogger : public nvinfer1::ILogger { + nvinfer1::ILogger::Severity verbosity_; + + public: + TensorrtLogger(Severity verbosity = Severity::kWARNING) + : verbosity_(verbosity) {} + void log(Severity severity, const char* msg) noexcept override { + if (severity <= verbosity_) { + time_t rawtime = std::time(0); + struct tm stm; +#ifdef _MSC_VER + gmtime_s(&stm, &rawtime); +#else + gmtime_r(&rawtime, &stm); +#endif + char buf[256]; + strftime(&buf[0], 256, + "%Y-%m-%d %H:%M:%S", + &stm); + const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" : severity == Severity::kERROR ? " ERROR" + : severity == Severity::kWARNING ? "WARNING" + : severity == Severity::kINFO ? " INFO" + : "UNKNOWN"); + if (severity <= Severity::kERROR) { + // LOGS_DEFAULT(ERROR) << "[" << buf << " " << sevstr << "] " << msg; + } else { + // LOGS_DEFAULT(WARNING) << "[" << buf << " " << sevstr << "] " << msg; + } + } + } + void set_level(Severity verbosity) { + verbosity_ = verbosity; + } + Severity get_level() const { + return verbosity_; + } +}; + +namespace tensorrt_ptr { + +struct TensorrtInferDeleter { + template + void operator()(T* obj) const { + if (obj) { + delete obj; + } + } +}; + +template +using unique_pointer = std::unique_ptr; +}; // namespace tensorrt_ptr + +class OutputAllocator : public nvinfer1::IOutputAllocator { + public: +#if NV_TENSORRT_MAJOR >= 10 + void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override; +#else + void* reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override; +#endif + void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; + + void* getBuffer() { + return outputPtr; + } + + std::vector& getOutputShape() { + return output_shapes; + } + + uint64_t getSize() { + return allocated_size; + } + + ~OutputAllocator() override { + cudaFree(outputPtr); + } + + private: + void* outputPtr{nullptr}; + uint64_t allocated_size = 0; + std::vector output_shapes; +}; + +using ShapeRangesMap = std::unordered_map>>>; + +struct TensorrtFuncState { + AllocateFunc test_allocate_func = nullptr; + DestroyFunc test_release_func = nullptr; + void* allocator = nullptr; + std::string fused_node_name; + nvinfer1::IBuilder* builder; + tensorrt_ptr::unique_pointer* parser = nullptr; + std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; + std::unique_ptr* network = nullptr; + std::vector> input_info; + std::vector> output_info; + std::unordered_map>>> input_shape_ranges; + std::mutex* tensorrt_mu_ptr = nullptr; + bool fp16_enable = false; + bool int8_enable = false; + bool int8_calibration_cache_available = false; + bool dla_enable = false; + int dla_core = 0; + size_t* max_workspace_size_ptr = nullptr; + std::string trt_node_name_with_precision; + bool engine_cache_enable = false; + std::string engine_cache_path; + nvinfer1::IRuntime* runtime = nullptr; + std::vector profiles; + bool context_memory_sharing_enable = false; + size_t* max_context_mem_size_ptr = nullptr; + std::unordered_map dynamic_range_map; + bool engine_decryption_enable = false; + int (*engine_decryption)(const char*, char*, size_t*) = nullptr; + int (*engine_encryption)(const char*, char*, size_t) = nullptr; + bool timing_cache_enable = true; + std::string timing_cache_path; + bool force_timing_cache = 
false; + bool detailed_build_log = false; + bool build_heuristics_enable = false; + bool sparsity_enable = false; + int builder_optimization_level = 3; + int auxiliary_streams = -1; + bool filter_tactic_sources = false; + nvinfer1::TacticSources tactic_sources; + bool cuda_graph_enable = 0; + std::string cache_prefix; + std::string cache_suffix; + bool engine_hw_compatible = false; +}; + +// Minimum information to construct kernel function state for direct engine load code path +struct TensorrtShortFuncState { + AllocateFunc test_allocate_func = nullptr; + DestroyFunc test_release_func = nullptr; + void* allocator = nullptr; + std::string fused_node_name; + std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; + std::vector> input_info; + std::vector> output_info; + bool context_memory_sharing_enable = false; + size_t* max_context_mem_size_ptr = nullptr; + std::mutex* tensorrt_mu_ptr = nullptr; +}; + +using DDSOutputAllocatorMap = std::unordered_map>; +std::string GetWeightRefittedEnginePath(std::string engine_cache_path); + +static const std::string k_cc_hw_compatible = "80+"; +static const std::string k_ep_ctx_hardware_architecture = "hardware_architecture"; +static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; + +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +/// +/// +/// Plugin TensorRT EP that implements OrtEp +/// +/// +struct TensorrtExecutionProvider : OrtEp, ApiPtrs { + TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device, + const OrtSessionOptions& session_options, const OrtLogger& logger); + ~TensorrtExecutionProvider(); + + std::string name_; + const OrtHardwareDevice& hardware_device_; + const OrtSessionOptions& session_options_; + const OrtLogger& logger_; + + + /* + bool IsGraphCaptured(int graph_annotation_id) const { return false; } + + static OrtStatusPtr RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, + bool detailed_build_log); + + std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index, + const OrtGraph* graph, const HashValue& model_hash, int subgraph_index) const; + SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, + const OrtGraph* graph, bool* early_termination) const; + + bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles = true) const; + */ + + /** + Get a unique_lock object to control the concurrency behavior. + Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading) + should be protected by a lock when invoked by multiple threads concurrently. 
+ */ + std::unique_lock GetApiLock() const; + + /**Check the graph is the subgraph of control flow op*/ + //bool IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const; + + /**Check whether all the nodes of the graph are assigned to specific ep*/ + //bool AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const; + + /**Check whether all the nodes of subgraph are supported*/ + //bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; + + std::unordered_map trt_node_name_with_precision_; + std::unordered_map> dynamic_range_map_; + std::unordered_map cache_suffix_; + + //private: + mutable TensorrtExecutionProviderInfo info_; + bool external_stream_ = false; + cudaStream_t stream_ = nullptr; + int max_partition_iterations_ = 1000; + size_t min_subgraph_size_ = 1; + size_t max_workspace_size_ = 1 << 30; // 1GB + bool fp16_enable_ = false; + bool int8_enable_ = false; + bool dla_enable_ = false; + int dla_core_ = 0; + bool force_sequential_engine_build_ = false; + std::string int8_calibration_cache_name_; + bool int8_calibration_cache_available_ = false; + bool int8_use_native_tensorrt_calibration_table_ = false; + bool dump_subgraphs_ = false; + bool engine_cache_enable_ = false; + bool weight_stripped_engine_enable_ = false; + bool weight_stripped_engine_refit_ = false; + std::string onnx_model_folder_path_; + bool build_heuristics_enable_ = false; + bool sparsity_enable_ = false; + int builder_optimization_level_ = 3; + int auxiliary_streams_ = -1; + std::string tactic_sources_; + std::string global_cache_path_, cache_path_, engine_decryption_lib_path_; + std::unique_ptr runtime_ = nullptr; + std::mutex tensorrt_mu_; + int device_id_; + std::string compute_capability_; + bool context_memory_sharing_enable_ = false; + bool layer_norm_fp32_fallback_ = false; + size_t max_ctx_mem_size_ = 0; + // IAllocatorUniquePtr context_memory_ = nullptr; + mutable char model_path_[4096] = {}; // Reserved for max path length + bool engine_decryption_enable_ = false; + int (*engine_decryption_)(const char*, char*, size_t*) = nullptr; + int (*engine_encryption_)(const char*, char*, size_t) = nullptr; + bool timing_cache_enable_ = false; + bool force_timing_cache_match_ = false; + bool detailed_build_log_ = false; + bool cuda_graph_enable_ = false; + std::string cache_prefix_; + bool engine_hw_compatible_ = false; + + // The OrtAllocator object will be get during ep compute time + // and should be kept for the lifetime of TRT EP object. + OrtAllocator* alloc_ = nullptr; + + // For create/dump EP context node model + bool dump_ep_context_model_ = false; + std::string ep_context_file_path_; + int ep_context_embed_mode_ = 0; + std::string ctx_model_path_; + std::string ep_cache_context_attr_; + std::string engine_cache_relative_path_to_context_model_dir; + + OrtGraph* ep_ctx_graph_ = nullptr; + std::vector extra_attr_keys_; + std::vector extra_attr_values_; + + // std::unique_ptr model_proto_ = ONNX_NAMESPACE::ModelProto::Create(); + + std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; + // mutable std::unordered_map> subgraph_context_map_; + + mutable std::unique_ptr builder_; + + // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading. + // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. 
+ // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. + std::unordered_map> parsers_; + std::unordered_map> engines_; + std::unordered_map> contexts_; + std::unordered_map> builders_; + std::unordered_map> networks_; + std::unordered_map>> input_info_; + std::unordered_map>> output_info_; + std::unordered_map>> profile_min_shapes_; + std::unordered_map>> profile_max_shapes_; + std::unordered_map>> profile_opt_shapes_; + std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with + std::unordered_map> profiles_; + std::unordered_map dds_output_allocator_maps_; + + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture + // cudnnHandle_t external_cudnn_handle_ = nullptr; + // cublasHandle_t external_cublas_handle_ = nullptr; + + // Call cudaStreamSynchronize() after TRT enqueueV3() + mutable bool sync_stream_after_enqueue_ = true; + + // CUDAGraph cuda_graph_; + // bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + + OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, + const OrtGraph** graphs, + const OrtNode** fused_nodes, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_infos); + + OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, + const OrtGraph** graphs, + const OrtNode** fused_nodes, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_infos); + + bool IsGraphCaptureAllowed() const { return false; }; + + nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; +}; +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc new file mode 100644 index 00000000..8a34cf0c --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc @@ -0,0 +1,339 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include //#incldue "core/providers/cuda/cuda_pch.h" + +#include "tensorrt_execution_provider_info.h" +#include "provider_options_utils.h" +#include "cuda/cuda_common.h" + +namespace onnxruntime { +namespace tensorrt { +namespace provider_option_names { +constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; +constexpr const char* kMaxPartitionIterations = "trt_max_partition_iterations"; +constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size"; +constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; +constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; +constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; +constexpr const char* kDLAEnable = "trt_dla_enable"; +constexpr const char* kDLACore = "trt_dla_core"; +constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs"; +constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable"; +constexpr const char* kEngineCachePath = "trt_engine_cache_path"; +constexpr const char* kWeightStrippedEngineEnable = "trt_weight_stripped_engine_enable"; +constexpr const char* kOnnxModelFolderPath = "trt_onnx_model_folder_path"; +constexpr const char* kEngineCachePrefix = "trt_engine_cache_prefix"; +constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; +constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; +constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; +// add new provider option name here. +constexpr const char* kContextMemorySharingEnable = "trt_context_memory_sharing_enable"; +constexpr const char* kLayerNormFP32Fallback = "trt_layer_norm_fp32_fallback"; +constexpr const char* kTimingCacheEnable = "trt_timing_cache_enable"; +constexpr const char* kTimingCachePath = "trt_timing_cache_path"; +constexpr const char* kForceTimingCacheMatch = "trt_force_timing_cache"; +constexpr const char* kDetailedBuildLog = "trt_detailed_build_log"; +constexpr const char* kBuildHeuristics = "trt_build_heuristics_enable"; +constexpr const char* kSparsityEnable = "trt_sparsity_enable"; +constexpr const char* kBuilderOptimizationLevel = "trt_builder_optimization_level"; +constexpr const char* kAuxiliaryStreams = "trt_auxiliary_streams"; +constexpr const char* kTacticSources = "trt_tactic_sources"; +constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths"; +constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes"; +constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes"; +constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes"; +constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable"; +constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode"; +constexpr const char* kEpContextFilePath = "trt_ep_context_file_path"; +constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; +constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; + +} // namespace provider_option_names +} // namespace tensorrt + +TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { + TensorrtExecutionProviderInfo info{}; + + void* user_compute_stream = nullptr; + ORT_THROW_IF_ERROR( + ProviderOptionsParser{} + .AddValueParser( + 
tensorrt::provider_option_names::kDeviceId, + [&info](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + int num_devices{}; + CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&num_devices)); + ORT_RETURN_IF_NOT( + 0 <= info.device_id && info.device_id < num_devices, + "Invalid device ID: ", info.device_id, + ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); + return Status::OK(); + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations) + .AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + tensorrt::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return Status::OK(); + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) + .AddAssignmentToReference(tensorrt::provider_option_names::kDLAEnable, info.dla_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kDLACore, info.dla_core) + .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kWeightStrippedEngineEnable, info.weight_stripped_engine_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kOnnxModelFolderPath, info.onnx_model_folder_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix) + .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) + .AddAssignmentToReference(tensorrt::provider_option_names::kContextMemorySharingEnable, info.context_memory_sharing_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback) + .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCachePath, info.timing_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache) + .AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log) + 
.AddAssignmentToReference(tensorrt::provider_option_names::kBuildHeuristics, info.build_heuristics_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kSparsityEnable, info.sparsity_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level) + .AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams) + .AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources) + .AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths) + .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model) + .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible) + .Parse(options)); // add new provider option here. + + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; +} + +//ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtExecutionProviderInfo& info) { +// const ProviderOptions options{ +// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, +// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.max_partition_iterations)}, +// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, +// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)}, +// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, +// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, +// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, +// {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, +// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, +// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.dla_enable)}, +// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.dla_core)}, +// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, +// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, +// {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, +// 
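// Illustrative sketch, not part of the upstream change: how a string-keyed ProviderOptions map is
// parsed into a TensorrtExecutionProviderInfo. The option values are made-up samples and the helper
// name is hypothetical; FromProviderOptions throws if a value fails to parse or the device id is
// out of range.
inline TensorrtExecutionProviderInfo ExampleParseTrtOptions() {
  const ProviderOptions options{
      {"device_id", "0"},
      {"trt_fp16_enable", "1"},
      {"trt_engine_cache_enable", "1"},
      {"trt_engine_cache_path", "/tmp/trt_cache"}};
  return TensorrtExecutionProviderInfo::FromProviderOptions(options);
}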
{tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.weight_stripped_engine_enable)}, +// {tensorrt::provider_option_names::kOnnxModelFolderPath, MakeStringWithClassicLocale(info.onnx_model_folder_path)}, +// {tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)}, +// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, +// {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, +// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, +// // add new provider option here. +// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)}, +// {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)}, +// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)}, +// {tensorrt::provider_option_names::kTimingCachePath, MakeStringWithClassicLocale(info.timing_cache_path)}, +// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)}, +// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, +// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.build_heuristics_enable)}, +// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.sparsity_enable)}, +// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)}, +// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)}, +// {tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)}, +// {tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)}, +// {tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, +// {tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, +// {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, +// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, +// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)}, +// {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, +// {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, +// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, +// }; +// return options; +//} +// +//ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensorRTProviderOptionsV2& info) { +// auto empty_if_null = [](const char* s) { return s != nullptr ? 
std::string{s} : std::string{}; }; +// const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); +// const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); +// const std::string kEngineCachePrefix_ = empty_if_null(info.trt_engine_cache_prefix); +// const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); +// const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); +// const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); +// const std::string kExtraPluginLibPaths_ = empty_if_null(info.trt_extra_plugin_lib_paths); +// const std::string kProfilesMinShapes_ = empty_if_null(info.trt_profile_min_shapes); +// const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); +// const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); +// const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); +// const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path); +// +// const ProviderOptions options{ +// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, +// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, +// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.trt_max_partition_iterations)}, +// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)}, +// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)}, +// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.trt_fp16_enable)}, +// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.trt_int8_enable)}, +// {tensorrt::provider_option_names::kInt8CalibTable, kInt8CalibTable_}, +// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.trt_int8_use_native_calibration_table)}, +// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.trt_dla_enable)}, +// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.trt_dla_core)}, +// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, +// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, +// {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, +// {tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_}, +// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.trt_weight_stripped_engine_enable)}, +// {tensorrt::provider_option_names::kOnnxModelFolderPath, kOnnxModelFolderPath_}, +// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, +// {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, +// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, +// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.trt_context_memory_sharing_enable)}, +// 
{tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.trt_layer_norm_fp32_fallback)}, +// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.trt_timing_cache_enable)}, +// {tensorrt::provider_option_names::kTimingCachePath, kTimingCachePath_}, +// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.trt_force_timing_cache)}, +// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.trt_detailed_build_log)}, +// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.trt_build_heuristics_enable)}, +// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.trt_sparsity_enable)}, +// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.trt_builder_optimization_level)}, +// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.trt_auxiliary_streams)}, +// {tensorrt::provider_option_names::kTacticSources, kTacticSources_}, +// {tensorrt::provider_option_names::kExtraPluginLibPaths, kExtraPluginLibPaths_}, +// {tensorrt::provider_option_names::kProfilesMinShapes, kProfilesMinShapes_}, +// {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_}, +// {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_}, +// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, +// {tensorrt::provider_option_names::kEpContextFilePath, kEpContextFilePath_}, +// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)}, +// {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)}, +// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)}, +// }; +// return options; +//} +// +///** +// * Update OrtTensorRTProviderOptionsV2 instance with ProviderOptions (map of string-based key-value pairs) +// * +// * Please note that it will reset the OrtTensorRTProviderOptionsV2 instance first and then set up the provided provider options +// * See TensorrtExecutionProviderInfo::FromProviderOptions() for more details. This function will be called by the C API UpdateTensorRTProviderOptions() also. +// * +// * \param provider_options - a pointer to OrtTensorRTProviderOptionsV2 instance +// * \param options - a reference to ProviderOptions instance +// * \param string_copy - if it's true, it uses strncpy() to copy 'provider option' string from ProviderOptions instance to where the 'provider option' const char pointer in OrtTensorRTProviderOptionsV2 instance points to. +// * it it's false, it only saves the pointer and no strncpy(). +// * +// * Note: If there is strncpy involved, please remember to deallocate or simply call C API ReleaseTensorRTProviderOptions. 
+// */ +//void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) { +// if (provider_options == nullptr) { +// return; +// } +// auto copy_string_if_needed = [&](std::string& s_in) { +// if (string_copy) { +// char* dest = nullptr; +// auto str_size = s_in.size(); +// if (str_size == 0) { +// return (const char*)nullptr; +// } else { +// dest = new char[str_size + 1]; +//#ifdef _MSC_VER +// strncpy_s(dest, str_size + 1, s_in.c_str(), str_size); +//#else +// strncpy(dest, s_in.c_str(), str_size); +//#endif +// dest[str_size] = '\0'; +// return (const char*)dest; +// } +// } else { +// return s_in.c_str(); +// } +// }; +// +// TensorrtExecutionProviderInfo internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options); +// auto& trt_provider_options_v2 = *reinterpret_cast(provider_options); +// trt_provider_options_v2.device_id = internal_options.device_id; +// +// // The 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance can be set by C API UpdateTensorRTProviderOptionsWithValue() as well +// // We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options or user_compute_stream is provided +// if (options.find("has_user_compute_stream") != options.end()) { +// trt_provider_options_v2.has_user_compute_stream = internal_options.has_user_compute_stream; +// } +// if (options.find("user_compute_stream") != options.end() && internal_options.user_compute_stream != nullptr) { +// trt_provider_options_v2.user_compute_stream = internal_options.user_compute_stream; +// trt_provider_options_v2.has_user_compute_stream = true; +// } +// +// trt_provider_options_v2.trt_max_partition_iterations = internal_options.max_partition_iterations; +// trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size; +// trt_provider_options_v2.trt_max_workspace_size = internal_options.max_workspace_size; +// trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable; +// trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable; +// +// trt_provider_options_v2.trt_int8_calibration_table_name = copy_string_if_needed(internal_options.int8_calibration_table_name); +// +// trt_provider_options_v2.trt_int8_use_native_calibration_table = internal_options.int8_use_native_calibration_table; +// trt_provider_options_v2.trt_dla_enable = internal_options.dla_enable; +// trt_provider_options_v2.trt_dla_core = internal_options.dla_core; +// trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs; +// trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable; +// trt_provider_options_v2.trt_weight_stripped_engine_enable = internal_options.weight_stripped_engine_enable; +// trt_provider_options_v2.trt_onnx_model_folder_path = copy_string_if_needed(internal_options.onnx_model_folder_path); +// +// trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path); +// trt_provider_options_v2.trt_engine_cache_prefix = copy_string_if_needed(internal_options.engine_cache_prefix); +// trt_provider_options_v2.trt_timing_cache_path = copy_string_if_needed(internal_options.timing_cache_path); +// +// trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable; +// +// trt_provider_options_v2.trt_engine_decryption_lib_path = copy_string_if_needed(internal_options.engine_decryption_lib_path); +// 
+// trt_provider_options_v2.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build; +// trt_provider_options_v2.trt_context_memory_sharing_enable = internal_options.context_memory_sharing_enable; +// trt_provider_options_v2.trt_layer_norm_fp32_fallback = internal_options.layer_norm_fp32_fallback; +// trt_provider_options_v2.trt_timing_cache_enable = internal_options.timing_cache_enable; +// trt_provider_options_v2.trt_force_timing_cache = internal_options.force_timing_cache; +// trt_provider_options_v2.trt_detailed_build_log = internal_options.detailed_build_log; +// trt_provider_options_v2.trt_build_heuristics_enable = internal_options.build_heuristics_enable; +// trt_provider_options_v2.trt_sparsity_enable = internal_options.sparsity_enable; +// trt_provider_options_v2.trt_builder_optimization_level = internal_options.builder_optimization_level; +// trt_provider_options_v2.trt_auxiliary_streams = internal_options.auxiliary_streams; +// +// trt_provider_options_v2.trt_tactic_sources = copy_string_if_needed(internal_options.tactic_sources); +// trt_provider_options_v2.trt_extra_plugin_lib_paths = copy_string_if_needed(internal_options.extra_plugin_lib_paths); +// trt_provider_options_v2.trt_profile_min_shapes = copy_string_if_needed(internal_options.profile_min_shapes); +// trt_provider_options_v2.trt_profile_max_shapes = copy_string_if_needed(internal_options.profile_max_shapes); +// trt_provider_options_v2.trt_profile_opt_shapes = copy_string_if_needed(internal_options.profile_opt_shapes); +// +// trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable; +// trt_provider_options_v2.trt_dump_ep_context_model = internal_options.dump_ep_context_model; +// trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode; +// trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path); +// trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible; +//} +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h new file mode 100644 index 00000000..5ca1d6df --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "provider_options.h" +#include "common.h" + +#define TRT_DEFAULT_OPTIMIZER_LEVEL 3 + +namespace onnxruntime { +// Information needed to construct trt execution providers. 
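// Each field below corresponds to one string key in tensorrt::provider_option_names
// (e.g. "trt_max_workspace_size" populates max_workspace_size). A field keeps its default
// when the key is absent from the options map; for instance max_workspace_size defaults to
// 1 << 30 bytes (1 GiB) and builder_optimization_level to 3 (TRT_DEFAULT_OPTIMIZER_LEVEL).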
+struct TensorrtExecutionProviderInfo { + int device_id{0}; + bool has_user_compute_stream{false}; + void* user_compute_stream{nullptr}; + bool has_trt_options{false}; + int max_partition_iterations{1000}; + int min_subgraph_size{1}; + size_t max_workspace_size{1 << 30}; + bool fp16_enable{false}; + bool int8_enable{false}; + std::string int8_calibration_table_name{""}; + bool int8_use_native_calibration_table{false}; + bool dla_enable{false}; + int dla_core{0}; + bool dump_subgraphs{false}; + bool engine_cache_enable{false}; + std::string engine_cache_path{""}; + bool weight_stripped_engine_enable{false}; + std::string onnx_model_folder_path{""}; + bool engine_decryption_enable{false}; + std::string engine_decryption_lib_path{""}; + bool force_sequential_engine_build{false}; + bool context_memory_sharing_enable{false}; + bool layer_norm_fp32_fallback{false}; + bool timing_cache_enable{false}; + std::string timing_cache_path{""}; + bool force_timing_cache{false}; + bool detailed_build_log{false}; + bool build_heuristics_enable{false}; + bool sparsity_enable{false}; + int builder_optimization_level{3}; + int auxiliary_streams{-1}; + std::string tactic_sources{""}; + std::string extra_plugin_lib_paths{""}; + std::string profile_min_shapes{""}; + std::string profile_max_shapes{""}; + std::string profile_opt_shapes{""}; + bool cuda_graph_enable{false}; + bool dump_ep_context_model{false}; + std::string ep_context_file_path{""}; + int ep_context_embed_mode{0}; + std::string engine_cache_prefix{""}; + bool engine_hw_compatible{false}; + + static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); +// static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); +// static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); +// static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy); +// +// std::vector custom_op_domain_list; +}; +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h new file mode 100644 index 00000000..60ff20e7 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -0,0 +1,397 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "flatbuffers/idl.h" +#include "ort_trt_int8_cal_table.fbs.h" +#include "murmurhash3.h" +#include "path_string.h" + +namespace fs = std::filesystem; + +namespace onnxruntime { + +float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) { + int s = (input >> 31) & 0x01; + int e = ((input & 0x7f800000) >> 23) - 127; + int p = -1; + double m = 0.0; + for (int i = 0; i < 23; ++i) { + m += ((input >> (23 - i - 1)) & 0x01) * pow(2.0, p--); + } + return static_cast((s ? 
-1 : 1) * pow(2.0, e) * (m + 1.0)); +} + +bool ReadDynamicRange(const std::string file_name, const bool is_trt_calibration_table, std::unordered_map& dynamic_range_map) { + std::ifstream infile(file_name, std::ios::binary | std::ios::in); + if (!infile) { + return false; + } + + if (is_trt_calibration_table) { + // Native TensorRT generated calibration table + std::string line; + char delim = ':'; + if (std::getline(infile, line)) { + std::istringstream first_line(line); + std::string version; + std::getline(first_line, version, delim); + std::size_t found = version.find("TRT-"); + if (found != std::string::npos) { + while (std::getline(infile, line)) { + std::istringstream in_line(line); + std::string str; + std::getline(in_line, str, delim); + std::string tensor_name = str; + std::getline(in_line, str, delim); + unsigned long scale_int = std::strtoul(str.c_str(), nullptr, 16); + float scale_float = ConvertSinglePrecisionIEEE754ToFloat(scale_int); + float dynamic_range = scale_float * 127.0f; + dynamic_range_map[tensor_name] = dynamic_range; + } + } else { + throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + } + } + } else { + // ORT generated calibration table + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read((char*)data.get(), length); + infile.close(); + auto flat_table = flatbuffers::GetRoot((const uint8_t*)data.get()); + auto flat_dict = flat_table->dict(); + for (size_t i = 0, end = flat_dict->size(); i < end; ++i) { + flatbuffers::uoffset_t idx = static_cast(i); + dynamic_range_map[flat_dict->Get(idx)->key()->str()] = std::stof(flat_dict->Get(idx)->value()->str()); + } + } + return true; +} + +int GetNumProfiles(std::unordered_map>>& profile_shapes) { + int num_profile = 0; + for (auto it = profile_shapes.begin(); it != profile_shapes.end(); it++) { + num_profile = static_cast(it->second.size()); + if (num_profile > 0) { + break; + } + } + return num_profile; +} + +void SerializeProfileV2(const std::string& file_name, std::unordered_map>>>& shape_ranges) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In SerializeProfileV2()"; + // Serialize profile + flexbuffers::Builder builder; + auto tensor_map_start = builder.StartMap(); + for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << tensor_it->first.c_str() << "'"; + builder.TypedVector(tensor_it->first.c_str(), [&] { + for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { + size_t num_profiles = dim_it->second.size(); + for (size_t i = 0; i < num_profiles; i++) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] profile #" << i << ", dim is " << dim_it->first; + builder.Int(dim_it->first); + builder.Int(dim_it->second[i][0]); + builder.Int(dim_it->second[i][1]); + builder.Int(dim_it->second[i][2]); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim_it->first << ", " << dim_it->second[i][0] << ", " << dim_it->second[i][1] << ", " << dim_it->second[i][2]; + } + } + }); + } + builder.EndMap(tensor_map_start); + builder.Finish(); + + // Save flexbuffer + std::ofstream file(file_name, std::ios::binary | std::ios::out); + auto buf = builder.GetBuffer(); + size_t size = builder.GetSize(); + file.write(reinterpret_cast(&buf[0]), size); + file.close(); +} + +std::unordered_map>>> DeserializeProfileV2(std::ifstream& infile) { + 
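  // Layout written by SerializeProfileV2 and read back here: a flexbuffer map keyed by input tensor
  // name, whose value is a typed vector of repeated (dim, min, max, opt) quadruples, one quadruple
  // per dimension per optimization profile. For example, a single profile where dimension 0 ranges
  // over min=1, max=8, opt=4 is stored as the vector {0, 1, 8, 4}.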
//LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In DeserializeProfileV2()"; + // Load flexbuffer + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read((char*)data.get(), length); + infile.close(); + + // Deserialize profile + std::unordered_map>>> shape_ranges; + auto tensors_range_entries = flexbuffers::GetRoot((const uint8_t*)data.get(), length).AsMap(); + auto keys = tensors_range_entries.Keys(); + auto values = tensors_range_entries.Values(); + for (size_t i = 0, end = keys.size(); i < end; ++i) { // iterate tensors + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << keys[i].AsString().c_str() << "'"; + auto dim_range_vector = values[i].AsTypedVector(); + std::unordered_map>> inner_map; + std::vector> profile_vector; + + for (size_t k = 0; k < (dim_range_vector.size() / 4); k++) { // iterate dim, min, max, opt for all profiles + std::vector shape_vector; + auto idx = 4 * k; + auto dim = dim_range_vector[idx].AsInt64(); + shape_vector.push_back(dim_range_vector[idx + 1].AsInt64()); // min shape + shape_vector.push_back(dim_range_vector[idx + 2].AsInt64()); // max shape + shape_vector.push_back(dim_range_vector[idx + 3].AsInt64()); // opt shape + + if (inner_map.find(dim) == inner_map.end()) { + inner_map[dim] = profile_vector; + } + inner_map[dim].push_back(shape_vector); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim << ", " << shape_vector[0] << ", " << shape_vector[1] << ", " << shape_vector[2]; + } + shape_ranges[keys[i].AsString().c_str()] = inner_map; + } + return shape_ranges; +} + +bool CompareProfiles(const std::string& file_name, + std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes) { + std::ifstream profile_file(file_name, std::ios::binary | std::ios::in); + if (!profile_file) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << file_name << " doesn't exist."; + return true; + } + + std::unordered_map>>> shape_ranges; + shape_ranges = DeserializeProfileV2(profile_file); + + /* The format of the two data structures are below, for example: + * + * shape_ranges: + * { + * tensor_a: { + * dim_0: [[min_shape, max_shape, opt_shape]], + * dim_2: [[min_shape, max_shape, opt_shape]] + * }, + * tensor_b: { + * dim_1: [[min_shape, max_shape, opt_shape]] + * } + * } + * + * profile_min_shapes: + * { + * tensor_a: [[dim_0_value_0, dim_1_value_1, dim_2_value_2]], + * tensor_b: [[dim_0_value_3, dim_1_value_4, dim_2_value_5]] + * } + * + */ + + // Check number of dynamic shape inputs + if (profile_min_shapes.size() != shape_ranges.size()) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of dynamic shape inputs are not the same."; + return true; + } + + // Iterate through shape_ranges map + for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors + auto tensor_name = tensor_it->first; + if (profile_min_shapes.find(tensor_name) == profile_min_shapes.end()) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tensor name '" << tensor_name << "' doesn't exist in trt_profile_min_shapes."; + return true; + } + + for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { // iterate dimensions + auto dim = dim_it->first; + auto num_profiles = GetNumProfiles(profile_min_shapes); + + if (dim_it->second.size() != static_cast(num_profiles)) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of profiles are not the same."; + 
return true; + } + + for (size_t i = 0; i < dim_it->second.size(); i++) { // iterate (multiple) profile(s) + auto shape_values = dim_it->second[i]; + if (dim > (profile_min_shapes[tensor_name][i].size() - 1)) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dimension " << dim << " of '" << tensor_name << "' in " << file_name << " exceeds the total dimension of trt_profile_min_shapes."; + return true; + } + + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_min_shapes[tensor_name][i][dim]; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[0] << " in " << file_name; + if (profile_min_shapes[tensor_name][i][dim] != shape_values[0]) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_max_shapes[tensor_name][i][dim]; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[1] << " in " << file_name; + if (profile_max_shapes[tensor_name][i][dim] != shape_values[1]) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_opt_shapes[tensor_name][i][dim]; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[2] << " in " << file_name; + if (profile_opt_shapes[tensor_name][i][dim] != shape_values[2]) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + } + } + } + return false; +} + +std::string GetCachePath(const std::string& root, const std::string& name) { + if (root.empty()) { + return name; + } else { + fs::path path = root; + path.append(name); + return path.string(); + } +} + +std::string GetComputeCapacity(const cudaDeviceProp& prop) { + const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor); + return compute_capability; +} + +std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { + // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache + const std::string timing_cache_name = "TensorrtExecutionProvider_cache_sm" + + compute_cap + ".timing"; + return GetCachePath(root, timing_cache_name); +} + +/* +HashValue TRTGenerateId(const OrtApi& api, const OrtGraph* graph, std::string trt_version, std::string cuda_version) { + HashValue model_hash = 0; + + + //// find the top level graph + //const Graph* cur_graph = &graph_viewer.GetGraph(); + //while (cur_graph->IsSubgraph()) { + // cur_graph = cur_graph->ParentGraph(); + //} + + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + const std::filesystem::path* model_path = nullptr; + api.OrtGraph_GetModelPath(graph_viewer, reinterpret_cast(&model_path)); + + // Use the model's file name instead of the entire path to avoid cache regeneration if 
path changes + if (model_path->has_filename()) { + std::string model_name = PathToUTF8String(model_path->filename()); + + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.size(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + // const std::vector& input_names = nullptr; + const char** input_names = nullptr; // TODO(leca): release input_names + size_t input_count = 0; + api.OrtGraph_GetAllInputs(graph_viewer, &input_names, &input_count); + for (size_t i = 0; i < input_count; ++i) { + hash_str(input_names[i]); + } + + // hashing output of each node + int number_of_ort_nodes = 0; + api.OrtGraph_NumberOfNodes(graph_viewer, &number_of_ort_nodes); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + const size_t* nodes_index = nullptr; + size_t nodes_count = 0; + api.OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_index, &nodes_count); + for (const auto& index : nodes_vector) { + const OrtNode* node = nullptr; + graph_api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); + size_t output_size = 0; + graph_api->OrtNode_GetNumOutputs(node, &output_size); + for (size_t i = 0; i < output_size; ++i) { + const char* output_name = nullptr; + graph_api->OrtNode_GetIthOutputName(node, i, &output_name); + if (output_name != nullptr) { + hash_str(output_name); + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + +#ifdef CUDA_VERSION + hash_str(cuda_version); +#endif + +#if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR) + hash_str(trt_version); +#endif + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + // return the current unique id + return model_hash; +} +*/ + +std::vector split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +std::string join(const std::vector& vec, const std::string& delimiter) { + std::string result; + for (size_t i = 0; i < vec.size(); ++i) { + result += vec[i]; + if (i < vec.size() - 1) { + result += delimiter; + } + } + return result; +} + +std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) { + std::vector split_fused_node_name = split(fused_node_name, '_'); + if (split_fused_node_name.size() >= 3) { + // Get index of model hash from fused_node_name + std::string model_hash = split_fused_node_name[split_fused_node_name.size() - 3]; + size_t index = fused_node_name.find(model_hash); + // Parse suffix from trt_node_name_with_precision, as it has additional precision info + std::vector suffix_group = split(trt_node_name_with_precision.substr(index), '_'); + if (suffix_group.size() > 2) { + suffix_group.erase(suffix_group.begin() + 2); + } + return join(suffix_group, "_"); + } + return ""; +} +} diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc 
b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc new file mode 100644 index 00000000..e2f8a50e --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -0,0 +1,152 @@ +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT +#include + +#include +#include +#include +#include +#include +#include +#include + +struct TensorrtExecutionProvider; + +static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast(this_ptr); + return factory->ep_name_.c_str(); +} + +static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast(this_ptr); + return factory->vendor_.c_str(); +} + +static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + // C API + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // these can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.CreateKeyValuePairs(&ep_options); + + // random example using made up values + factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); + factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); + + // OrtEpDevice copies ep_metadata and ep_options. + auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_devices[num_ep_devices++]); + + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + } + + // C++ API equivalent. Throws on error. + //{ + // Ort::ConstHardwareDevice device(devices[i]); + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // Ort::KeyValuePairs ep_metadata; + // Ort::KeyValuePairs ep_options; + // ep_metadata.Add("version", "0.1"); + // ep_options.Add("run_really_fast", "true"); + // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; + // ep_devices[num_ep_devices++] = ep_device.release(); + // } + //} + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ OrtEp** ep) { + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + // we only registered for CPU and only expected to be selected for one CPU + // if you register for multiple devices (e.g. CPU, GPU and maybe NPU) you will get an entry for each device + // the EP has been selected for. 
+ return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "Example EP only supports selection for one device."); + } + + // Create the execution provider + RETURN_IF_ERROR(factory->ort_api.Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "Creating Example EP", ORT_FILE, __LINE__, __FUNCTION__)); + + // use properties from the device and ep_metadata if needed + // const OrtHardwareDevice* device = devices[0]; + // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; + + auto dummy_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); + + *ep = dummy_ep.release(); + return nullptr; +} + +static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { + ExampleEp* dummy_ep = static_cast(ep); + delete dummy_ep; +} + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ort_ep_api = ort_api->GetEpApi(); + + // Factory could use registration_name or define its own EP name. + std::unique_ptr factory = std::make_unique(registration_name, + ApiPtrs{*ort_api, *ort_ep_api}); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete factory; + return nullptr; +} + +} // extern "C" diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h new file mode 100644 index 00000000..63a9a40a --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -0,0 +1,58 @@ +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + return status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ + } \ + } while (0) + +static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr); +static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr); +static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices); +static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ OrtEp** ep); +static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep); + +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +/// +/// +/// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices. 
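/// ONNX Runtime drives this factory through the function pointers set in the constructor:
/// CreateEpFactories() in tensorrt_provider_factory.cc instantiates it, GetSupportedDevices
/// enumerates OrtEpDevice entries, and CreateEp/ReleaseEp manage the EP lifetime.
/// Rough sketch (illustrative only; the registration name is whatever the host passes in):
///   auto factory = std::make_unique<TensorrtExecutionProviderFactory>("TensorrtEp",
///                                                                     ApiPtrs{*ort_api, *ort_ep_api});
///   factory->GetSupportedDevices(factory.get(), devices, num_devices,
///                                ep_devices, max_ep_devices, &num_ep_devices);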
+/// +/// +struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs { + TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + } + const std::string ep_name_; // EP name + const std::string vendor_{"Nvidia"}; // EP vendor name +}; \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/utils/code_location.h b/plugin_execution_providers/tensorrt/utils/code_location.h new file mode 100644 index 00000000..dbff6909 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/code_location.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +/** + CodeLocation captures information on where in the source code a message came from. +*/ +struct CodeLocation { + /** + @param file_path Usually the value of __FILE__ + @param line Usually the value of __LINE__ + @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ + */ + CodeLocation(const char* file_path, const int line, const char* func) + : file_and_path{file_path}, line_num{line}, function{func} { + } + + /** + @param file_path Usually the value of __FILE__ + @param line Usually the value of __LINE__ + @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ + @param stacktrace Stacktrace from source of message. + */ + CodeLocation(const char* file_path, const int line, const char* func, const std::vector& stacktrace) + : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) { + } + + std::string FileNoPath() const { + // assuming we always have work to do, so not trying to avoid creating a new string if + // no path was removed. + return file_and_path.substr(file_and_path.find_last_of("/\\") + 1); + } + + enum Format { + kFilename, + kFilenameAndPath + }; + + std::string ToString(Format format = Format::kFilename) const { + std::ostringstream out; + out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function; + return out.str(); + } + // utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8. + const std::string file_and_path; + const int line_num; + // utf-8 + const std::string function; + const std::vector stacktrace; +}; + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/common.h b/plugin_execution_providers/tensorrt/utils/common.h new file mode 100644 index 00000000..eaf000a5 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/common.h @@ -0,0 +1,169 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// Portions Copyright (c) Microsoft Corporation + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "code_location.h" +#include "exceptions.h" +#include "make_string.h" +#include "status.h" + +namespace onnxruntime { + +// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER +// so we only define it as one for MSVC +#if (_MSC_VER && !defined(__PRETTY_FUNCTION__)) +#define __PRETTY_FUNCTION__ __FUNCTION__ +#endif + +// Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__ +#define ORT_WHERE ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast(__FUNCTION__)) + +#define ORT_WHERE_WITH_STACK \ + ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast(__PRETTY_FUNCTION__), ::onnxruntime::GetStackTrace()) + +// Throw an exception with optional message. +// NOTE: The arguments get streamed into a string via ostringstream::operator<< +// DO NOT use a printf format string, as that will not work as you expect. +/* +#define ORT_THROW(...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__)) +*/ +#define ORT_THROW(...) \ + throw ::onnxruntime::OnnxRuntimeException(::onnxruntime::MakeString(__VA_ARGS__)) + +// Just in order to mark things as not implemented. Do not use in final code. +#define ORT_NOT_IMPLEMENTED(...) \ + throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__)) + +// Check condition. +// NOTE: The arguments get streamed into a string via ostringstream::operator<< +// DO NOT use a printf format string, as that will not work as you expect. +#define ORT_ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + throw ::onnxruntime::OnnxRuntimeException(#condition, \ + ::onnxruntime::MakeString(__VA_ARGS__)); \ + } \ + } while (false) + +#define ORT_THROW_EX(ex, ...) \ + throw ex(__VA_ARGS__) + +#define ORT_MAKE_STATUS(category, code, ...) \ + ::onnxruntime::common::Status(::onnxruntime::common::category, \ + ::onnxruntime::common::code, \ + ::onnxruntime::MakeString(__VA_ARGS__)) + +// Check condition. if met, return status. +#define ORT_RETURN_IF(condition, ...) \ + do { \ + if (condition) { \ + return ::onnxruntime::common::Status(::onnxruntime::common::ONNXRUNTIME, \ + ::onnxruntime::common::FAIL, \ + ::onnxruntime::MakeString(ORT_WHERE.ToString(), " ", __VA_ARGS__)); \ + } \ + } while (false) + +// Check condition. if not met, return status. +#define ORT_RETURN_IF_NOT(condition, ...) \ + ORT_RETURN_IF(!(condition), __VA_ARGS__) + +// Macros to disable the copy and/or move ctor and assignment methods +// These are usually placed in the private: declarations for a class. 
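// For example (illustrative only; the class name is made up):
//   class TrtEngineCacheGuard {
//    public:
//     TrtEngineCacheGuard() = default;
//
//    private:
//     ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TrtEngineCacheGuard);
//   };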
+ +#define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete + +#define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete + +#define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \ + ORT_DISALLOW_COPY(TypeName); \ + ORT_DISALLOW_ASSIGNMENT(TypeName) + +#define ORT_DISALLOW_MOVE(TypeName) \ + TypeName(TypeName&&) = delete; \ + TypeName& operator=(TypeName&&) = delete + +#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \ + ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \ + ORT_DISALLOW_MOVE(TypeName) + +#define ORT_RETURN_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + return _status; \ + } \ + } while (0) + +#define ORT_THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + ORT_THROW(_status); \ + } \ + } while (0) + +// use this macro when cannot early return +#define ORT_CHECK_AND_SET_RETVAL(expr) \ + do { \ + if (retval.IsOK()) { \ + retval = (expr); \ + } \ + } while (0) + +struct null_type {}; +inline std::string ToUTF8String(const std::string& s) { return s; } +#ifdef _WIN32 +/** + * Convert a wide character string to a UTF-8 string + */ +std::string ToUTF8String(const std::wstring& s); + +std::wstring ToWideString(const std::string& s); +inline std::wstring ToWideString(const std::wstring& s) { return s; } +#else +inline std::string ToWideString(const std::string& s) { return s; } +#endif + +constexpr size_t kMaxStrLen = 2048; + +// Returns whether `key` is in `container`. +// Like C++20's map/set contains() member function. +template typename AssociativeContainer, + typename LookupKey> +inline bool Contains(const AssociativeContainer& container, LookupKey&& key) { + return container.find(std::forward(key)) != container.end(); +} + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h new file mode 100644 index 00000000..81d5975c --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "../common.h" + +namespace onnxruntime { + +// ----------------------------------------------------------------------- +// Error handling +// ----------------------------------------------------------------------- +// +template +const char* CudaErrString(ERRTYPE) { + ORT_NOT_IMPLEMENTED(); +} + +template +std::conditional_t CudaCall( + ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { + if (retCode != successCode) { + try { +//#ifdef _WIN32 + //std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); + //if (hostname_str.empty()) { + //hostname_str = "?"; + //} + //const char* hostname = hostname_str.c_str(); +//#else + //char hostname[HOST_NAME_MAX]; + //if (gethostname(hostname, HOST_NAME_MAX) != 0) + //strcpy(hostname, "?"); +//#endif + int currentCudaDevice = -1; + cudaGetDevice(¤tCudaDevice); + cudaGetLastError(); // clear last CUDA error + static char str[1024]; + snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=? 
; file=%s ; line=%d ; expr=%s; %s", + libName, (int)retCode, CudaErrString(retCode), currentCudaDevice, + //hostname, + file, line, exprString, msg); + if constexpr (THRW) { + // throw an exception with the error info + ORT_THROW(str); + } else { + //LOGS_DEFAULT(ERROR) << str; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + } + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error + if constexpr (THRW) { + ORT_THROW(e.what()); + } else { + //LOGS_DEFAULT(ERROR) << e.what(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); + } + } + } + if constexpr (!THRW) { + return Status::OK(); + } +} + +//template +//std::conditional_t CudaCall( + //ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + +#define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h new file mode 100644 index 00000000..b00ef3f9 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_call.h" + +namespace onnxruntime { +namespace cuda { + +#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) + +} // namespace cuda +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/endian.h b/plugin_execution_providers/tensorrt/utils/endian.h new file mode 100644 index 00000000..629fb78f --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/endian.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { + +// the semantics of this enum should match std::endian from C++20 +enum class endian { +#if defined(_WIN32) + little = 0, + big = 1, + native = little, +#elif defined(__GNUC__) || defined(__clang__) + little = __ORDER_LITTLE_ENDIAN__, + big = __ORDER_BIG_ENDIAN__, + native = __BYTE_ORDER__, +#else +#error onnxruntime::endian is not implemented in this environment. +#endif +}; + +static_assert( + endian::native == endian::little || endian::native == endian::big, + "Only little-endian or big-endian native byte orders are supported."); + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/exceptions.h b/plugin_execution_providers/tensorrt/utils/exceptions.h new file mode 100644 index 00000000..19c1586a --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/exceptions.h @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common.h" +//#include "code_location.h" + +namespace onnxruntime { + +class NotImplementedException : public std::logic_error { + public: + explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; + explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; +}; + +class TypeMismatchException : public std::logic_error { + public: + TypeMismatchException() noexcept : logic_error("Type mismatch"){}; +}; + +class OnnxRuntimeException : public std::exception { + public: + // code location is not provided for now + /* + OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept + : OnnxRuntimeException(location, nullptr, msg) { + } + */ + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. + */ + /* + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) + : location_{location} { + std::ostringstream ss; + + ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous + if (failed_condition != nullptr) { + ss << " " << failed_condition << " was false."; + } + + ss << " " << msg << "\n"; + if (!location.stacktrace.empty()) { + ss << "Stacktrace:\n"; + // skip the first entry in the stacktrace as we have that information from location.ToString() + std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); + } + + what_ = ss.str(); + } + */ + + OnnxRuntimeException(const std::string& msg) noexcept + : OnnxRuntimeException(nullptr, msg) { + } + + OnnxRuntimeException(const char* failed_condition, const std::string& msg) { + std::ostringstream ss; + + if (failed_condition != nullptr) { + ss << failed_condition << " was false."; + } + + ss << " " << msg << "\n"; + what_ = ss.str(); + } + + const char* what() const noexcept override { + return what_.c_str(); + } + + private: + //const CodeLocation location_; + const std::vector stacktrace_; + std::string what_; +}; + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/helper.cc b/plugin_execution_providers/tensorrt/utils/helper.cc new file mode 100644 index 00000000..7a889c30 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/helper.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include "common.h" + +#ifdef _WIN32 +#include +#include +#endif + +namespace onnxruntime { +#ifdef _WIN32 +std::string ToUTF8String(const std::wstring& s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + ORT_THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr); + assert(len > 0); + std::string ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} + +std::wstring ToWideString(const std::string& s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + ORT_THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0); + assert(len > 0); + std::wstring ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} +#endif // #ifdef _WIN32 + +#ifdef ORT_NO_EXCEPTIONS +void PrintFinalMessage(const char* msg) { +#if defined(__ANDROID__) + __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg); +#else + // TODO, consider changing the output of the error message from std::cerr to logging when the + // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output + // might not be easily accessible on some systems such as mobile + // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS + std::cerr << msg << std::endl; +#endif +} +#endif // #ifdef ORT_NO_EXCEPTIONS + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/make_string.h b/plugin_execution_providers/tensorrt/utils/make_string.h new file mode 100644 index 00000000..826898de --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/make_string.h @@ -0,0 +1,126 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Portions Copyright (c) Microsoft Corporation + +#pragma once + +#include +#include +#include + +namespace onnxruntime { + +namespace detail { + +inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept { +} + +template +inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept { + MakeStringImpl(ss, t); + MakeStringImpl(ss, args...); +} + +// see MakeString comments for explanation of why this is necessary +template +inline std::string MakeStringImpl(const Args&... 
args) noexcept { + std::ostringstream ss; + MakeStringImpl(ss, args...); + return ss.str(); +} + +// +// Infrastructure to convert char[n] to char* to reduce binary size +// + +// default is to leave the type as is +template +struct if_char_array_make_ptr { + using type = T; +}; + +// specialization that matches an array reference, which is what the char array from a string literal +// used in a call to MakeString will be. +// if the type is a char[n] array we 'decay' it to a char* so that the usages can be folded. +template +struct if_char_array_make_ptr { + // remove a single extent (T[x] -> T, but T[x][y] -> T[y]) so we only match char[x], + // and get the type name without the 'const' so both 'const char (&)[n]' and 'char (&)[n]' are matched. + using element_type = typename std::remove_const::type>::type; + using type = typename std::conditional::value, T*, T (&)[N]>::type; +}; + +// helper to make usage simpler in MakeString +template +using if_char_array_make_ptr_t = typename if_char_array_make_ptr::type; +} // namespace detail + +/** + * Makes a string by concatenating string representations of the arguments. + * This version uses the current locale. + */ +template +std::string MakeString(const Args&... args) { + // We need to update the types from the MakeString template instantiation to decay any char[n] to char*. + // e.g. MakeString("in", "out") goes from MakeString to MakeStringImpl + // so that MakeString("out", "in") will also match MakeStringImpl instead of requiring + // MakeStringImpl. + // + // We have to do the type processing before any actual work, so this function purely implements the type processing. + // If we do not do it this way we do not get the full binary size reduction. + // + // See https://stackoverflow.com/a/29418212/684911 for overall details of the approach, but note it does not cover + // the need to do the type processing as a separate step. + + return detail::MakeStringImpl(detail::if_char_array_make_ptr_t(args)...); +} + +/** + * Makes a string by concatenating string representations of the arguments. + * This version uses std::locale::classic(). + */ +template +std::string MakeStringWithClassicLocale(const Args&... args) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + detail::MakeStringImpl(ss, args...); + return ss.str(); +} + +// MakeString versions for already-a-string types. + +inline std::string MakeString(const std::string& str) { + return str; +} + +inline std::string MakeString(const char* cstr) { + return cstr; +} + +inline std::string MakeStringWithClassicLocale(const std::string& str) { + return str; +} + +inline std::string MakeStringWithClassicLocale(const char* cstr) { + return cstr; +} + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/murmurhash3.cc b/plugin_execution_providers/tensorrt/utils/murmurhash3.cc new file mode 100644 index 00000000..49fcb2ef --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/murmurhash3.cc @@ -0,0 +1,349 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "murmurhash3.h" + +// Original source: https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. 
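// Editor's note (annotation, not part of the original patch): a short usage sketch for the
// MakeString helpers above. String literals (char[N]) are decayed to char* so calls with
// different literal lengths fold into a single MakeStringImpl instantiation; the
// classic-locale variant is the safer choice for cache keys and log lines because its output
// does not depend on the process locale. The option names and values below are made up.
#include <string>

inline std::string DescribeBuilderConfig(int opt_level, double workspace_gb) {
  return onnxruntime::MakeStringWithClassicLocale(
      "trt_builder_optimization_level=", opt_level,
      ", trt_max_workspace_size=", workspace_gb, "GB");
}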
+ +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. + +/* Modifications Copyright (c) Microsoft. */ + +#include "endian.h" + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) + +#define FORCE_INLINE __forceinline + +#include + +#define ROTL32(x, y) _rotl(x, y) +#define ROTL64(x, y) _rotl64(x, y) + +#define BIG_CONSTANT(x) (x) + +// Other compilers + +#else // defined(_MSC_VER) + +#define FORCE_INLINE inline __attribute__((always_inline)) + +inline uint32_t rotl32(uint32_t x, int8_t r) { + return (x << r) | (x >> (32 - r)); +} + +inline uint64_t rotl64(uint64_t x, int8_t r) { + return (x << r) | (x >> (64 - r)); +} + +#define ROTL32(x, y) rotl32(x, y) +#define ROTL64(x, y) rotl64(x, y) + +#define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) +#include +//----------------------------------------------------------------------------- +// Block read - on little-endian machines this is a single load, +// while on big-endian or unknown machines the byte accesses should +// still get optimized into the most efficient instruction. +// +// Changes to support big-endian from https://github.com/explosion/murmurhash/pull/27/ +// were manually applied to original murmurhash3 source code. +FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { + if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { + return p[i]; + } else { + const uint8_t* c = (const uint8_t*)&p[i]; + return (uint32_t)c[0] | + (uint32_t)c[1] << 8 | + (uint32_t)c[2] << 16 | + (uint32_t)c[3] << 24; + } +} + +FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { + if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { + return p[i]; + } else { + const uint8_t* c = (const uint8_t*)&p[i]; + return (uint64_t)c[0] | + (uint64_t)c[1] << 8 | + (uint64_t)c[2] << 16 | + (uint64_t)c[3] << 24 | + (uint64_t)c[4] << 32 | + (uint64_t)c[5] << 40 | + (uint64_t)c[6] << 48 | + (uint64_t)c[7] << 56; + } +} + +//----------------------------------------------------------------------------- +// Finalization mix - force all bits of a hash block to avalanche + +FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; +} + +//---------- + +FORCE_INLINE constexpr uint64_t fmix64(uint64_t k) { + k ^= k >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + + return k; +} + +//----------------------------------------------------------------------------- + +namespace onnxruntime { +void MurmurHash3::x86_32(const void* key, int len, + uint32_t seed, void* out) { + const uint8_t* data = (const uint8_t*)key; + const int nblocks = len / 4; + + uint32_t h1 = seed; + + constexpr uint32_t c1 = 0xcc9e2d51; + constexpr uint32_t c2 = 0x1b873593; + + //---------- + // body + + const uint32_t* blocks = (const uint32_t*)(data + static_cast(nblocks) * 4); + + for (int i = -nblocks; i; i++) { + uint32_t k1 = getblock32(blocks, i); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + + //---------- + // tail + + const uint8_t* tail 
= (const uint8_t*)(data + static_cast(nblocks) * 4); + + uint32_t k1 = 0; + + switch (len & 3) { + case 3: + k1 ^= tail[2] << 16; + [[fallthrough]]; + case 2: + k1 ^= tail[1] << 8; + [[fallthrough]]; + case 1: + k1 ^= tail[0]; + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + + h1 = fmix32(h1); + + *(uint32_t*)out = h1; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) { + const uint8_t* data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint32_t h1 = seed; + uint32_t h2 = seed; + uint32_t h3 = seed; + uint32_t h4 = seed; + + constexpr uint32_t c1 = 0x239b961b; + constexpr uint32_t c2 = 0xab0e9789; + constexpr uint32_t c3 = 0x38b34ae5; + constexpr uint32_t c4 = 0xa1e38b93; + + //---------- + // body + + const uint32_t* blocks = (const uint32_t*)(data + static_cast(nblocks) * 16); + + for (int i = -nblocks; i; i++) { + uint32_t k1 = getblock32(blocks, i * 4 + 0); + uint32_t k2 = getblock32(blocks, i * 4 + 1); + uint32_t k3 = getblock32(blocks, i * 4 + 2); + uint32_t k4 = getblock32(blocks, i * 4 + 3); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + h1 ^= k1; + + h1 = ROTL32(h1, 19); + h1 += h2; + h1 = h1 * 5 + 0x561ccd1b; + + k2 *= c2; + k2 = ROTL32(k2, 16); + k2 *= c3; + h2 ^= k2; + + h2 = ROTL32(h2, 17); + h2 += h3; + h2 = h2 * 5 + 0x0bcaa747; + + k3 *= c3; + k3 = ROTL32(k3, 17); + k3 *= c4; + h3 ^= k3; + + h3 = ROTL32(h3, 15); + h3 += h4; + h3 = h3 * 5 + 0x96cd1c35; + + k4 *= c4; + k4 = ROTL32(k4, 18); + k4 *= c1; + h4 ^= k4; + + h4 = ROTL32(h4, 13); + h4 += h1; + h4 = h4 * 5 + 0x32ac3b17; + } + + //---------- + // tail + + const uint8_t* tail = (const uint8_t*)(data + static_cast(nblocks) * 16); + + uint32_t k1 = 0; + uint32_t k2 = 0; + uint32_t k3 = 0; + uint32_t k4 = 0; + + switch (len & 15) { + case 15: + k4 ^= tail[14] << 16; + [[fallthrough]]; + case 14: + k4 ^= tail[13] << 8; + [[fallthrough]]; + case 13: + k4 ^= tail[12] << 0; + k4 *= c4; + k4 = ROTL32(k4, 18); + k4 *= c1; + h4 ^= k4; + [[fallthrough]]; + case 12: + k3 ^= tail[11] << 24; + [[fallthrough]]; + case 11: + k3 ^= tail[10] << 16; + [[fallthrough]]; + case 10: + k3 ^= tail[9] << 8; + [[fallthrough]]; + case 9: + k3 ^= tail[8] << 0; + k3 *= c3; + k3 = ROTL32(k3, 17); + k3 *= c4; + h3 ^= k3; + [[fallthrough]]; + case 8: + k2 ^= tail[7] << 24; + [[fallthrough]]; + case 7: + k2 ^= tail[6] << 16; + [[fallthrough]]; + case 6: + k2 ^= tail[5] << 8; + [[fallthrough]]; + case 5: + k2 ^= tail[4] << 0; + k2 *= c2; + k2 = ROTL32(k2, 16); + k2 *= c3; + h2 ^= k2; + [[fallthrough]]; + case 4: + k1 ^= tail[3] << 24; + [[fallthrough]]; + case 3: + k1 ^= tail[2] << 16; + [[fallthrough]]; + case 2: + k1 ^= tail[1] << 8; + [[fallthrough]]; + case 1: + k1 ^= tail[0] << 0; + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + h2 ^= len; + h3 ^= len; + h4 ^= len; + + h1 += h2; + h1 += h3; + h1 += h4; + h2 += h1; + h3 += h1; + h4 += h1; + + h1 = fmix32(h1); + h2 = fmix32(h2); + h3 = fmix32(h3); + h4 = fmix32(h4); + + h1 += h2; + h1 += h3; + h1 += h4; + h2 += h1; + h3 += h1; + h4 += h1; + + ((uint32_t*)out)[0] = h1; + ((uint32_t*)out)[1] = h2; + ((uint32_t*)out)[2] = h3; + ((uint32_t*)out)[3] = h4; +} + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/murmurhash3.h b/plugin_execution_providers/tensorrt/utils/murmurhash3.h new file mode 100644 index 
00000000..ab86a3e5 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/murmurhash3.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +struct MurmurHash3 { + // generate 32-bit hash from input and write to 'out' + static void x86_32(const void* key, int len, uint32_t seed, void* out); + + // generate 128-bit hash from input and write to 'out'. + static void x86_128(const void* key, int len, uint32_t seed, void* out); +}; +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/parse_string.h b/plugin_execution_providers/tensorrt/utils/parse_string.h new file mode 100644 index 00000000..ce404607 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/parse_string.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "common.h" + +namespace onnxruntime { + +/** + * Tries to parse a value from an entire string. + */ +template +bool TryParseStringWithClassicLocale(std::string_view str, T& value) { + if constexpr (std::is_integral::value && std::is_unsigned::value) { + // if T is unsigned integral type, reject negative values which will wrap + if (!str.empty() && str[0] == '-') { + return false; + } + } + + // don't allow leading whitespace + if (!str.empty() && std::isspace(str[0], std::locale::classic())) { + return false; + } + + std::istringstream is{std::string{str}}; + is.imbue(std::locale::classic()); + T parsed_value{}; + + const bool parse_successful = + is >> parsed_value && + is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters + if (!parse_successful) { + return false; + } + + value = std::move(parsed_value); + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) { + value = str; + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) { + if (str == "0" || str == "False" || str == "false") { + value = false; + return true; + } + + if (str == "1" || str == "True" || str == "true") { + value = true; + return true; + } + + return false; +} + +/** + * Parses a value from an entire string. + */ +template +Status ParseStringWithClassicLocale(std::string_view s, T& value) { + ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\""); + return Status::OK(); +} + +/** + * Parses a value from an entire string. + */ +template +T ParseStringWithClassicLocale(std::string_view s) { + T value{}; + ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value)); + return value; +} + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/path_string.h b/plugin_execution_providers/tensorrt/utils/path_string.h new file mode 100644 index 00000000..fd638aa5 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/path_string.h @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
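// Editor's note (annotation, not part of the original patch): a minimal sketch of the
// parse_string.h helpers above applied to provider-option strings. The option values are
// examples only; real option handling is wired through ProviderOptionsParser later in this
// patch.
#include <string_view>

inline bool ReadExampleOptions(std::string_view device_id_str, std::string_view fp16_str,
                               int& device_id, bool& fp16_enable) {
  // the integral overload rejects leading whitespace, trailing characters and (for unsigned
  // targets) negative values; the bool overload accepts "0"/"1"/"true"/"false"/"True"/"False"
  return onnxruntime::TryParseStringWithClassicLocale(device_id_str, device_id) &&
         onnxruntime::TryParseStringWithClassicLocale(fp16_str, fp16_enable);
}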
+ +#pragma once + +#include +#include + +// for std::tolower or std::towlower +#ifdef _WIN32 +#include +#else +#include +#endif + +// for converting / printing ORT_TSTR path strings to std::string +#ifdef _WIN32 +#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) std::wstring_convert>().to_bytes(X) +#define ORT_TSTR_CONVERT_FROM_STRING(X) std::wstring_convert>().from_bytes(X); +#else +#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) X +#define ORT_TSTR_CONVERT_FROM_STRING(X) X +#endif + +//#include "core/common/common.h" +//#include "core/session/onnxruntime_c_api.h" + +//#include "common.h" + +namespace onnxruntime { +// char type for filesystem paths +using PathChar = ORTCHAR_T; +// string type for filesystem paths +using PathString = std::basic_string; + +inline PathString ToPathString(const PathString& s) { + return s; +} + +#ifdef _WIN32 + +static_assert(std::is_same::value, "PathString is not std::wstring!"); + +inline PathString ToPathString(const std::string& s) { + return ToWideString(s); +} + +inline PathChar ToLowerPathChar(PathChar c) { + return std::towlower(c); +} + +inline std::string PathToUTF8String(const PathString& s) { + return ToUTF8String(s); +} + +#else + +static_assert(std::is_same::value, "PathString is not std::string!"); + +inline PathChar ToLowerPathChar(PathChar c) { + return std::tolower(c); +} + +inline std::string PathToUTF8String(const PathString& s) { + return s; +} + +#endif + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/provider_options.h b/plugin_execution_providers/tensorrt/utils/provider_options.h new file mode 100644 index 00000000..aab13e80 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/provider_options.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { + +// data types for execution provider options + +using ProviderOptions = std::unordered_map; +using ProviderOptionsVector = std::vector; +using ProviderOptionsMap = std::unordered_map; + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h new file mode 100644 index 00000000..c7380b36 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "common.h" +#include "parse_string.h" +#include "provider_options.h" + +namespace onnxruntime { + +template +using EnumNameMapping = std::vector>; + +/** + * Given a mapping and an enumeration value, gets the corresponding name. + */ +template +Status EnumToName(const EnumNameMapping& mapping, TEnum value, std::string& name) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&value](const std::pair& entry) { + return entry.first == value; + }); + ORT_RETURN_IF( + it == mapping.end(), + "Failed to map enum value to name: ", static_cast::type>(value)); + name = it->second; + return Status::OK(); +} + +template +std::string EnumToName(const EnumNameMapping& mapping, TEnum value) { + std::string name; + ORT_THROW_IF_ERROR(EnumToName(mapping, value, name)); + return name; +} + +/** + * Given a mapping and a name, gets the corresponding enumeration value. 
+ */ +template +Status NameToEnum( + const EnumNameMapping& mapping, const std::string& name, TEnum& value) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&name](const std::pair& entry) { + return entry.second == name; + }); + ORT_RETURN_IF( + it == mapping.end(), + "Failed to map enum name to value: ", name); + value = it->first; + return Status::OK(); +} + +template +TEnum NameToEnum(const EnumNameMapping& mapping, const std::string& name) { + TEnum value; + ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value)); + return value; +} + +class ProviderOptionsParser { + public: + /** + * Adds a parser for a particular provider option value. + * + * @param name The provider option name. + * @param value_parser An object that parses the option value. + * It should be callable with the following signature and return + * whether the parsing was successful: + * Status value_parser(const std::string&) + * + * @return The current ProviderOptionsParser instance. + */ + template + ProviderOptionsParser& AddValueParser( + const std::string& name, ValueParserType value_parser) { + ORT_ENFORCE( + value_parsers_.emplace(name, ValueParser{value_parser}).second, + "Provider option \"", name, "\" already has a value parser."); + return *this; + } + + /** + * Adds a parser for a particular provider option value which converts a + * value to the right type and assigns it to the given reference. + * + * IMPORTANT: This function stores a reference to the destination variable. + * The caller must ensure that the reference is valid when Parse() is called! + * + * @param name The provider option name. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. + */ + template + ProviderOptionsParser& AddAssignmentToReference( + const std::string& name, ValueType& dest) { + return AddValueParser( + name, + [&dest](const std::string& value_str) -> Status { + return ParseStringWithClassicLocale(value_str, dest); + }); + } + + /** + * Adds a parser for a particular provider option value which maps an + * enumeration name to a value and assigns it to the given reference. + * + * IMPORTANT: This function stores references to the mapping and destination + * variables. The caller must ensure that the references are valid when + * Parse() is called! + * + * @param name The provider option name. + * @param mapping The enumeration value to name mapping. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. + */ + template + ProviderOptionsParser& AddAssignmentToEnumReference( + const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { + return AddValueParser( + name, + [&mapping, &dest](const std::string& value_str) -> Status { + return NameToEnum(mapping, value_str, dest); + }); + } + + /** + * Parses the given provider options. 
+ */ + Status Parse(const ProviderOptions& options) const { + for (const auto& option : options) { + const auto& name = option.first; + const auto& value_str = option.second; + const auto value_parser_it = value_parsers_.find(name); + ORT_RETURN_IF( + value_parser_it == value_parsers_.end(), + "Unknown provider option: \"", name, "\"."); + + const auto parse_status = value_parser_it->second(value_str); + ORT_RETURN_IF_NOT( + parse_status.IsOK(), + "Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); + } + + return Status::OK(); + } + + private: + using ValueParser = std::function; + std::unordered_map value_parsers_; +}; + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/status.cc b/plugin_execution_providers/tensorrt/utils/status.cc new file mode 100644 index 00000000..b3a89c8c --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/status.cc @@ -0,0 +1,91 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Modifications Copyright (c) Microsoft. + +#include "status.h" +#include "common.h" + +namespace onnxruntime { +namespace common { +Status::Status(StatusCategory category, int code, const std::string& msg) { + // state_ will be allocated here causing the status to be treated as a failure + ORT_ENFORCE(code != static_cast(common::OK)); + + state_ = std::make_unique(category, code, msg); +} + +Status::Status(StatusCategory category, int code, const char* msg) { + // state_ will be allocated here causing the status to be treated as a failure + ORT_ENFORCE(code != static_cast(common::OK)); + + state_ = std::make_unique(category, code, msg); +} + +Status::Status(StatusCategory category, int code) + : Status(category, code, "") { +} + +StatusCategory Status::Category() const noexcept { + return IsOK() ? common::NONE : state_->category; +} + +int Status::Code() const noexcept { + return IsOK() ? static_cast(common::OK) : state_->code; +} + +const std::string& Status::ErrorMessage() const noexcept { + return IsOK() ? EmptyString() : state_->msg; +} + +std::string Status::ToString() const { + if (state_ == nullptr) { + return std::string("OK"); + } + + std::string result; + + if (common::SYSTEM == state_->category) { + result += "SystemError"; + result += " : "; + result += std::to_string(errno); + } else if (common::ONNXRUNTIME == state_->category) { + result += "[ONNXRuntimeEPError]"; + result += " : "; + result += std::to_string(Code()); + result += " : "; + result += StatusCodeToString(static_cast(Code())); + result += " : "; + result += state_->msg; + } + + return result; +} + +// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial +// and should not have any destruction order issues via pragmas instead. 
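// Editor's note (annotation, not part of the original patch): a usage sketch for the
// ProviderOptionsParser defined above. The struct and option names are illustrative; the real
// TensorRT option table is built in tensorrt_execution_provider_info.cc.
struct ExampleTrtOptions {
  int device_id = 0;
  bool fp16_enable = false;
};

inline onnxruntime::Status ParseExampleOptions(const onnxruntime::ProviderOptions& options,
                                               ExampleTrtOptions& out) {
  // AddAssignmentToReference stores references to the destination fields, so `out` must
  // outlive the Parse() call (it does here). Unknown option names cause Parse() to fail.
  return onnxruntime::ProviderOptionsParser{}
      .AddAssignmentToReference("device_id", out.device_id)
      .AddAssignmentToReference("trt_fp16_enable", out.fp16_enable)
      .Parse(options);
}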
+// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 26426) +#endif + +const std::string& Status::EmptyString() noexcept { + static std::string s_empty; + return s_empty; +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +} // namespace common +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/status.h b/plugin_execution_providers/tensorrt/utils/status.h new file mode 100644 index 00000000..80bf7caf --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/status.h @@ -0,0 +1,192 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Modifications Copyright (c) Microsoft. + +#pragma once + +#include +#include +#include +#ifdef _WIN32 +#include +#endif + +namespace onnxruntime { +namespace common { + +enum StatusCategory { + NONE = 0, + SYSTEM = 1, + ONNXRUNTIME = 2, +}; + +/** + Error code for ONNXRuntime. +*/ +enum StatusCode { + OK = 0, + FAIL = 1, + INVALID_ARGUMENT = 2, + NO_SUCHFILE = 3, + NO_MODEL = 4, + ENGINE_ERROR = 5, + RUNTIME_EXCEPTION = 6, + INVALID_PROTOBUF = 7, + MODEL_LOADED = 8, + NOT_IMPLEMENTED = 9, + INVALID_GRAPH = 10, + EP_FAIL = 11 +}; + +constexpr const char* StatusCodeToString(StatusCode status) noexcept { + switch (status) { + case StatusCode::OK: + return "SUCCESS"; + case StatusCode::FAIL: + return "FAIL"; + case StatusCode::INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case StatusCode::NO_SUCHFILE: + return "NO_SUCHFILE"; + case StatusCode::NO_MODEL: + return "NO_MODEL"; + case StatusCode::ENGINE_ERROR: + return "ENGINE_ERROR"; + case StatusCode::RUNTIME_EXCEPTION: + return "RUNTIME_EXCEPTION"; + case StatusCode::INVALID_PROTOBUF: + return "INVALID_PROTOBUF"; + case StatusCode::MODEL_LOADED: + return "MODEL_LOADED"; + case StatusCode::NOT_IMPLEMENTED: + return "NOT_IMPLEMENTED"; + case StatusCode::INVALID_GRAPH: + return "INVALID_GRAPH"; + case StatusCode::EP_FAIL: + return "EP_FAIL"; + default: + return "GENERAL ERROR"; + } +} + +#ifdef _WIN32 +constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { + switch (status) { + case StatusCode::OK: + return S_OK; + case StatusCode::FAIL: + return E_FAIL; + case StatusCode::INVALID_ARGUMENT: + return E_INVALIDARG; + case StatusCode::NO_SUCHFILE: + return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::NO_MODEL: + return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::ENGINE_ERROR: + return E_FAIL; + case StatusCode::RUNTIME_EXCEPTION: + return E_FAIL; + case StatusCode::INVALID_PROTOBUF: + return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::MODEL_LOADED: + return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::NOT_IMPLEMENTED: + return E_NOTIMPL; + case StatusCode::INVALID_GRAPH: + return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::EP_FAIL: + 
return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + default: + return E_FAIL; + } +} +#endif + +class [[nodiscard]] Status { + public: + Status() noexcept = default; + + Status(StatusCategory category, int code, const std::string& msg); + + Status(StatusCategory category, int code, const char* msg); + + Status(StatusCategory category, int code); + + Status(const Status& other) + : state_((other.state_ == nullptr) ? nullptr : new State(*other.state_)) {} + Status& operator=(const Status& other) { + if (state_ != other.state_) { + if (other.state_ == nullptr) { + state_.reset(); + } else { + state_.reset(new State(*other.state_)); + } + } + return *this; + } + + Status(Status&&) = default; + Status& operator=(Status&&) = default; + ~Status() = default; + + bool IsOK() const { + return (state_ == nullptr); + } + + int Code() const noexcept; + + StatusCategory Category() const noexcept; + + const std::string& ErrorMessage() const noexcept; + + std::string ToString() const; + + bool operator==(const Status& other) const { + return (this->state_ == other.state_) || (ToString() == other.ToString()); + } + + bool operator!=(const Status& other) const { + return !(*this == other); + } + + static Status OK() { + return Status(); + } + + private: + static const std::string& EmptyString() noexcept; + + struct State { + State(StatusCategory cat0, int code0, const std::string& msg0) + : category(cat0), code(code0), msg(msg0) {} + + State(StatusCategory cat0, int code0, const char* msg0) + : category(cat0), code(code0), msg(msg0) {} + + const StatusCategory category; + const int code; + const std::string msg; + }; + + // As long as Code() is OK, state_ == nullptr. + std::unique_ptr state_; +}; + +inline std::ostream& operator<<(std::ostream& out, const Status& status) { + return out << status.ToString(); +} +} // namespace common + +// make Status directly available in the onnxruntime namespace as it is widely used +using common::Status; + +} // namespace onnxruntime From ed65a9fc98059b280b96998148f79dfd55fed4ec Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 22 Jun 2025 22:41:23 -0700 Subject: [PATCH 02/60] clean up GetCapabilityImpl and make it pass compiler for now --- .../tensorrt/tensorrt_execution_provider.cc | 46 ++++++++++++------- .../tensorrt/tensorrt_execution_provider.h | 5 +- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 879115f7..cd72f838 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -1330,8 +1330,9 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph OrtEpGraphSupportInfo* graph_support_info) { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; - /* + // Get ModelPath + /* const std::filesystem::path* model_path = nullptr; graph_api_->OrtGraph_GetModelPath(graph, reinterpret_cast(&model_path)); const auto& path_string = model_path->string(); @@ -1387,6 +1388,8 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; bool new_subgraph = true; + std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; + /* Iterate all the nodes and exclude the node if: * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. * 2. 
Its op type is in the exclusion list. @@ -1407,7 +1410,7 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph const char* op_type = nullptr; RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); - if (ep->control_flow_op_set_.find(op_type) != ep->control_flow_op_set_.end()) { + if (control_flow_op_set.find(op_type) != control_flow_op_set.end()) { auto supported_control_flow_op = [&](const OrtNode* node) { OrtStatus* status = nullptr; size_t num_subgraphs = 0; @@ -1467,8 +1470,14 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph } } + + // Use this local definitions for now + // TODO: Use provider option + int max_partition_iterations = 1000; + int min_subgraph_size = 1; + bool early_termination = false; - supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, p->max_partition_iterations_, graph, &early_termination); + supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, max_partition_iterations, graph, &early_termination); if (early_termination) { supported_nodes_vector.clear(); } @@ -1476,15 +1485,16 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph // Remove subgraphs if its size is less than the predefined minimal size for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { const size_t subgraph_size = it->first.size(); - if (subgraph_size < p->min_subgraph_size_) { + if (subgraph_size < min_subgraph_size) { supported_nodes_vector.erase(it--); } } // Detect and remove cycles from supported node list - //p->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); + /* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */ // Consolidate supported node list + /* if (supported_nodes_vector.size() > 1) { nodes_vector.clear(); for (const auto& group : supported_nodes_vector) { @@ -1500,11 +1510,12 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph supported_nodes_vector = consolidated_supported_nodes_vector; } } + */ - std::vector cache; // Handle the case where the graph is subgraph of control flow op. // The purpose is to make control flow op as well as its subgraphs run on TRT. // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. 
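// Editor's note (annotation, not part of the original patch): the minimal-size filter above
// removes elements with supported_nodes_vector.erase(it--), which is undefined behaviour when
// the very first subgraph is erased (it decrements begin()). An equivalent filter using the
// erase-remove idiom is sketched below, assuming SubGraphCollection_t is a std::vector whose
// elements expose the node-index list as `.first`, as it is used throughout this file.
#include <algorithm>
#include <cstddef>

inline void DropSmallSubgraphs(SubGraphCollection_t& supported, size_t min_subgraph_size) {
  supported.erase(std::remove_if(supported.begin(), supported.end(),
                                 [min_subgraph_size](const auto& group) {
                                   return group.first.size() < min_subgraph_size;
                                 }),
                  supported.end());
}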
+ /* if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { bool all_subgraphs_are_supported = true; @@ -1580,32 +1591,33 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph return; } } + */ - int number_of_trt_nodes = 0, subgraph_index = 0; + int number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { - std::unique_ptr sub_graph = p->GetSubGraph(group, graph, model_hash, subgraph_index); - cache.push_back(sub_graph.release()); + std::vector supported_nodes; + for (const auto& index : group.first) { + const OrtNode* supported_node = nullptr; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(nodes_container, index, + reinterpret_cast(&supported_node))); + supported_nodes.push_back(supported_node); + } + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), + supported_nodes.size())); number_of_trt_nodes += static_cast(group.first.size()); - subgraph_index++; } } const size_t number_of_subgraphs = supported_nodes_vector.size(); if (number_of_trt_nodes == 0) { // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider"; - } else if (number_of_trt_nodes == number_of_ort_nodes) { + } else if (number_of_trt_nodes == nodes.size()) { // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } - *cnt = cache.size(); - *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt]; - for (size_t i = 0; i < *cnt; i++) { - (*indexed_sub_graph)[i] = cache[i]; - } - return nullptr; } diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 0e1c24c3..360b35b9 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -243,6 +243,8 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { const OrtSessionOptions& session_options_; const OrtLogger& logger_; + SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, + const OrtGraph* graph, bool* early_termination) const; /* bool IsGraphCaptured(int graph_annotation_id) const { return false; } @@ -283,7 +285,7 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { std::unordered_map> dynamic_range_map_; std::unordered_map cache_suffix_; - //private: + private: mutable TensorrtExecutionProviderInfo info_; bool external_stream_ = false; cudaStream_t stream_ = nullptr; @@ -346,7 +348,6 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { // std::unique_ptr model_proto_ = ONNX_NAMESPACE::ModelProto::Create(); - std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; // mutable std::unordered_map> subgraph_context_map_; mutable std::unique_ptr builder_; From 3269f73a0029313e58454f2c9f8c9334a2b3e402 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 22 Jun 2025 23:42:10 -0700 Subject: [PATCH 03/60] Clean up CompileImpl --- .../tensorrt/tensorrt_execution_provider.cc | 59 +++++++++---------- .../tensorrt/tensorrt_execution_provider.h | 25 ++++---- 2 files changed, 38 insertions(+), 46 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc 
b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index cd72f838..bbf93059 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -1642,52 +1642,46 @@ static OrtStatus* ORT_API_CALL CompileImpl(OrtEp* this_ptr, const OrtGraph** gra gsl::span node_inputs{}; gsl::span node_outputs{}; - RETURN_IF_ERROR(GetSpanFromConstPointerArray(inputs_array, node_inputs)); - RETURN_IF_ERROR(GetSpanFromConstPointerArray(outputs_array, node_outputs)); + GetSpanFromArrayOfConstObjects(inputs_array, node_inputs); + GetSpanFromArrayOfConstObjects(outputs_array, node_outputs); // Gets number of node's inputs and outputs size_t num_node_inputs = 0; size_t num_node_outputs = 0; - RETURN_IF_ERROR(ep->ort_api.ConstPointerArray_GetSize(inputs_array, &num_node_inputs)); - RETURN_IF_ERROR(ep->ort_api.ConstPointerArray_GetSize(outputs_array, &num_node_outputs)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_node_inputs)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_node_outputs)); // Builds map from input name to its index in input list std::unordered_map input_map; input_map.reserve(num_node_inputs); - for (size_t i = 0, i < num_node_inputs; i++) { - std::string& name = node_inputs[i]->GetName(); - input_map[name] = i; + for (size_t i = 0; i < num_node_inputs; i++) { + // TODO: Add ValueInfo_GetName() c api + //std::string& name = node_inputs[i]->GetName(); + //input_map[name] = i; } // Builds map from output name to its index in output list - std::unordered_map out_map; + std::unordered_map output_map; input_map.reserve(num_node_outputs); - for (size_t i = 0, i < num_node_outputs; i++) { - std::string& name = node_outputs[i]->GetName(); - out_map[name] = i; - } - - Status status; - if (GraphHasCtxNode(graph_body_viewer)) { - status = ep->CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer, - fused_node, - input_map, - output_map, - node_compute_funcs); + for (size_t i = 0; i < num_node_outputs; i++) { + // TODO: Add ValueInfo_GetName() c api + //std::string& name = node_outputs[i]->GetName(); + //output_map[name] = i; + } + + OrtStatus* status; + //if (GraphHasCtxNode(graph_body_viewer)) { + if (false) { + status = ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[graph_idx], fused_node, + input_map, + output_map, node_compute_infos[graph_idx]); } else { - status = ep->CreateNodeComputeInfoFromGraph(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs); + status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[graph_idx], fused_node, input_map, + output_map, node_compute_infos[graph_idx]); } - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); - } - - /* - OrtArrayOfConstObjects* nodes_array = nullptr; - DeferOrtRelease release_nodes(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); - size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[graph_idx], &nodes_array)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); - */ + //if (status != Status::OK()) { + // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); + //} } return nullptr; @@ -1724,6 +1718,7 @@ struct TensorrtExecutionProvider : TensorrtExecutionProvider(ApiPtrs apis, const // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to // the session option 
configurations with the key prefix "ep..". const std::string key_prefix = OrtSessionOptions::GetProviderOptionPrefix(name_.c_str()); + const ConfigOptions& config_options = session_options.GetConfigOptions(); const std::unordered_map& config_options_map = config_options.GetConfigOptionsMap(); // Get provider options as key-value pair strings diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 360b35b9..eace15b4 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -246,6 +246,17 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, const OrtGraph* graph, bool* early_termination) const; + OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graphs, + const OrtNode* fused_nodes, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo* node_compute_infos); + + OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graphs, const OrtNode* fused_nodes, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo* node_compute_infos); + /* bool IsGraphCaptured(int graph_annotation_id) const { return false; } @@ -386,20 +397,6 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { // to allocate enough memory in Arena before graph capturing. const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. - OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, - const OrtGraph** graphs, - const OrtNode** fused_nodes, - std::unordered_map& input_map, - std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_infos); - - OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, - const OrtGraph** graphs, - const OrtNode** fused_nodes, - std::unordered_map& input_map, - std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_infos); - bool IsGraphCaptureAllowed() const { return false; }; nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; From 4da9f90d8cb8e25a0a836929a7ca8ee4b98b317e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 24 Jun 2025 22:56:49 -0700 Subject: [PATCH 04/60] update ep factory --- .../tensorrt/tensorrt_provider_factory.cc | 26 ++++++++++--------- .../tensorrt/tensorrt_provider_factory.h | 4 +-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index e2f8a50e..3e224248 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -1,7 +1,8 @@ #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT -#include +#include "tensorrt_provider_factory.h" +#include "tensorrt_execution_provider.h" #include #include @@ -11,7 +12,7 @@ #include #include -struct TensorrtExecutionProvider; +//struct TensorrtExecutionProvider; static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) { const auto* factory = static_cast(this_ptr); @@ -30,21 +31,22 @@ static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, size_t max_ep_devices, size_t* 
p_num_ep_devices) { size_t& num_ep_devices = *p_num_ep_devices; - auto* factory = static_cast(this_ptr); + auto* factory = static_cast(this_ptr); for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { // C API const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - // these can be returned as nullptr if you have nothing to add. + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // These can be returned as nullptr if you have nothing to add. OrtKeyValuePairs* ep_metadata = nullptr; OrtKeyValuePairs* ep_options = nullptr; factory->ort_api.CreateKeyValuePairs(&ep_metadata); factory->ort_api.CreateKeyValuePairs(&ep_options); - // random example using made up values - factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); - factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); + // The ep options can be provided here as default values. + // Users can also call SessionOptionsAppendExecutionProvider_V2 C API with provided ep options to override. + factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); // random example using made up values + factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3"); // OrtEpDevice copies ep_metadata and ep_options. auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, @@ -61,11 +63,11 @@ static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, // C++ API equivalent. Throws on error. //{ // Ort::ConstHardwareDevice device(devices[i]); - // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { // Ort::KeyValuePairs ep_metadata; // Ort::KeyValuePairs ep_options; // ep_metadata.Add("version", "0.1"); - // ep_options.Add("run_really_fast", "true"); + // ep_options.Add("trt_builder_optimization_level", "3"); // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; // ep_devices[num_ep_devices++] = ep_device.release(); // } @@ -102,14 +104,14 @@ static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, // const OrtHardwareDevice* device = devices[0]; // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; - auto dummy_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); + auto dummy_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); *ep = dummy_ep.release(); return nullptr; } static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { - ExampleEp* dummy_ep = static_cast(ep); + onnxruntime::TensorrtExecutionProvider* dummy_ep = static_cast(ep); delete dummy_ep; } diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 63a9a40a..94283548 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -39,11 +39,9 @@ struct ApiPtrs { const OrtEpApi& ep_api; }; -/// /// /// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices. 
/// -/// struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs { TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. @@ -53,6 +51,6 @@ struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs { CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; } - const std::string ep_name_; // EP name + const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name }; \ No newline at end of file From 1928767feb377b28aafa5f681ab1d97455b04918 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 24 Jun 2025 23:06:32 -0700 Subject: [PATCH 05/60] update ep factory --- .../tensorrt/tensorrt_provider_factory.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 3e224248..03ee8423 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -88,17 +88,17 @@ static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, *ep = nullptr; if (num_devices != 1) { - // we only registered for CPU and only expected to be selected for one CPU + // we only registered for GPU and only expected to be selected for one GPU // if you register for multiple devices (e.g. CPU, GPU and maybe NPU) you will get an entry for each device // the EP has been selected for. return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, - "Example EP only supports selection for one device."); + "TensorRT EP only supports selection for one device."); } // Create the execution provider RETURN_IF_ERROR(factory->ort_api.Logger_LogMessage(logger, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "Creating Example EP", ORT_FILE, __LINE__, __FUNCTION__)); + "Creating TensorRT EP", ORT_FILE, __LINE__, __FUNCTION__)); // use properties from the device and ep_metadata if needed // const OrtHardwareDevice* device = devices[0]; From 4f5ffcb59088a6adbcc45330aa8c72f04bac9178 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 24 Jun 2025 23:08:29 -0700 Subject: [PATCH 06/60] update ep factory --- .../tensorrt/tensorrt_provider_factory.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 03ee8423..ebd36131 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -104,15 +104,15 @@ static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, // const OrtHardwareDevice* device = devices[0]; // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; - auto dummy_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); + auto trt_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); - *ep = dummy_ep.release(); + *ep = trt_ep.release(); return nullptr; } static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { - onnxruntime::TensorrtExecutionProvider* dummy_ep = static_cast(ep); - delete dummy_ep; + onnxruntime::TensorrtExecutionProvider* trt_ep = static_cast(ep); + delete trt_ep; } // To make symbols visible on macOS/iOS From bc64bdca9078b19b6f60978a9ead14fe310cd3f4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: 
Wed, 25 Jun 2025 10:01:20 -0700 Subject: [PATCH 07/60] clean up and add back onnx_ctx_model_helper.cc --- ...el_helper.ccc => onnx_ctx_model_helper.cc} | 0 .../tensorrt/tensorrt_execution_provider.cc | 26 ++++++++----------- .../tensorrt/tensorrt_execution_provider.h | 2 -- 3 files changed, 11 insertions(+), 17 deletions(-) rename plugin_execution_providers/tensorrt/{onnx_ctx_model_helper.ccc => onnx_ctx_model_helper.cc} (100%) diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.ccc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc similarity index 100% rename from plugin_execution_providers/tensorrt/onnx_ctx_model_helper.ccc rename to plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index bbf93059..df779e90 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -5,11 +5,15 @@ #include #include +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + #include "ep_abi_utils.h" #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" #include "tensorrt_cuda_allocator.h" -//#include "onnx_ctx_model_helper.h" +#include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" #include "cuda/unary_elementwise_ops_impl.h" @@ -1697,9 +1701,10 @@ static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) { /// Constructor of Plugin TensorRT EP /// /// -struct TensorrtExecutionProvider : TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device, - const OrtSessionOptions& session_options, const OrtLogger& logger) - : ApiPtrs(apis), name_{name}, hardware_device_{device}, session_options_{session_options}, logger_{logger} { +TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, + const OrtHardwareDevice& device, + const OrtSessionOptions& session_options, const OrtLogger& logger) + : ApiPtrs(apis), name_{name}, hardware_device_{device}, session_options_{session_options}, logger_{logger} { // Initialize the execution provider. 
auto status = ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -1731,9 +1736,9 @@ struct TensorrtExecutionProvider : TensorrtExecutionProvider(ApiPtrs apis, const // Provider options to TensorrtExecutionProviderInfo info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options); - if (ep_info.size() > 0) info_.has_trt_options = true; + info_.has_trt_options = true; device_id_ = info_.device_id; - api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); + //api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -2167,15 +2172,6 @@ struct TensorrtExecutionProvider : TensorrtExecutionProvider(ApiPtrs apis, const } } -TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { - OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* { - ProviderOptions options; - for (size_t i = 0; i < option_size; i++) options[ep_option_keys[i]] = ep_option_values[i]; - std::unique_ptr ret = std::make_unique(tensorrtEp.c_str(), std::move(options)); - return ret.release(); - }; -} - nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const { if (!builder_) { { diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index eace15b4..3566d3dc 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -228,11 +228,9 @@ struct ApiPtrs { const OrtEpApi& ep_api; }; -/// /// /// Plugin TensorRT EP that implements OrtEp /// -/// struct TensorrtExecutionProvider : OrtEp, ApiPtrs { TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device, const OrtSessionOptions& session_options, const OrtLogger& logger); From c4437a2d19a78ce6c2de86a7efd0588348008691 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 25 Jun 2025 11:10:51 -0700 Subject: [PATCH 08/60] clean up --- .../tensorrt/tensorrt_execution_provider.cc | 2237 +---------------- .../tensorrt/tensorrt_execution_provider.h | 4 +- .../tensorrt_execution_provider_utils.h | 704 +++++- 3 files changed, 695 insertions(+), 2250 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index df779e90..cff6dfde 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -33,11 +33,12 @@ void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } -namespace onnxruntime { +//namespace onnxruntime { static const std::string tensorrtEp = "tensorrtEp"; const OrtApi& ort_api = Ort::GetApi(); +/* struct MemcpyFromHost : OrtCustomOp { MemcpyFromHost() { OrtCustomOp::version = ORT_API_VERSION; @@ -81,29 +82,11 @@ struct MemcpyFromHost : OrtCustomOp { OrtCustomOp::GetStartVersion = [](const struct OrtCustomOp* op) { return 1; }; } }; +*/ template using IAllocatorUniquePtr = std::unique_ptr>; -// Check if cycle exists in the graph after partitioning -bool 
FindCycleHelper(size_t i, const std::list* adjacency_map, bool visited[], bool* st, std::vector& cycles) { - if (!visited[i]) { - visited[i] = true; - st[i] = true; - for (auto iter = adjacency_map[i].begin(); iter != adjacency_map[i].end(); ++iter) { - if (!visited[*iter] && FindCycleHelper(*iter, adjacency_map, visited, st, cycles)) { - cycles.push_back(*iter); - return true; - } else if (st[*iter]) { - cycles.push_back(*iter); - return true; - } - } - } - st[i] = false; - return false; -} - bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { size_t alloc_size = size; if (alignment == 0) { @@ -128,195 +111,7 @@ IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); - return IAllocatorUniquePtr{p, - [ort_allocator](T* p) { - ort_allocator->Free(ort_allocator, p); - }}; -} - -bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map& dynamic_range_map) { - // Set dynamic range for input tensors - for (int i = 0; i < network.getNbInputs(); ++i) { - const std::string tensor_name = network.getInput(i)->getName(); - auto dynamic_range_iter = dynamic_range_map.find(tensor_name); - if (dynamic_range_iter != dynamic_range_map.end()) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - if (!network.getInput(i)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for network input " << tensor_name; - return false; - } - } - } - - // Set dynamic range for activations and weights - for (int i = 0; i < network.getNbLayers(); ++i) { - auto trt_layer = network.getLayer(i); - for (int j = 0, e = trt_layer->getNbOutputs(); j < e; ++j) { - const std::string tensor_name = trt_layer->getOutput(j)->getName(); - auto dynamic_range_iter = dynamic_range_map.find(tensor_name); - if (dynamic_range_iter != dynamic_range_map.end()) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - if (!trt_layer->getOutput(j)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for tensor " << tensor_name; - return false; - } - } else if (trt_layer->getType() == nvinfer1::LayerType::kCONSTANT) { - nvinfer1::IConstantLayer* const_layer = static_cast(trt_layer); - const std::string const_layer_name = const_layer->getName(); - auto trt_weights = const_layer->getWeights(); - double max_weight = std::numeric_limits::min(); - for (int64_t k = 0, end = trt_weights.count; k < end; ++k) { - double weight{}; - switch (trt_weights.type) { - case nvinfer1::DataType::kFLOAT: - weight = static_cast(trt_weights.values)[k]; - break; - case nvinfer1::DataType::kBOOL: - weight = static_cast(trt_weights.values)[k]; - break; - case nvinfer1::DataType::kINT8: - weight = static_cast(trt_weights.values)[k]; - break; - case nvinfer1::DataType::kHALF: - weight = static_cast(trt_weights.values)[k]; - break; - case nvinfer1::DataType::kINT32: - weight = static_cast(trt_weights.values)[k]; - break; -#if NV_TENSORRT_MAJOR >= 10 - case nvinfer1::DataType::kINT64: - weight = static_cast(static_cast(trt_weights.values)[k]); - break; -#endif // NV_TENSORRT_MAJOR >= 10 - default: - // LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << 
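CalcMemSizeForArrayWithAlignment above guards the nmemb * size computation before an allocation is made. A compact sketch of that guard, assuming the same round-up-to-alignment behavior (the round-up itself is left unguarded in this sketch):

    #include <cstddef>

    bool CalcAlignedSize(size_t nmemb, size_t size, size_t alignment, size_t* out) {
      if (size != 0 && nmemb > static_cast<size_t>(-1) / size) return false;  // nmemb * size would overflow
      size_t bytes = nmemb * size;
      *out = (alignment == 0) ? bytes
                              : (bytes + alignment - 1) / alignment * alignment;  // round up to alignment
      return true;
    }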
const_layer_name; - return false; - } - max_weight = std::max(max_weight, std::abs(weight)); - } -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - if (!trt_layer->getOutput(j)->setDynamicRange(static_cast(-max_weight), static_cast(max_weight))) { -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for layer " << const_layer_name; - return false; - } - } - } - } - return true; -} - -std::vector SplitToStringVec(std::string const& s, char separator) { - std::vector splitted; - - for (size_t start = 0; start < s.length();) { - size_t separatorIndex = s.find(separator, start); - if (separatorIndex == std::string::npos) { - separatorIndex = s.length(); - } - splitted.emplace_back(s.substr(start, separatorIndex - start)); - start = separatorIndex + 1; - } - - return splitted; -} - -nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { - nvinfer1::TacticSources disabledTactics = 0; - nvinfer1::TacticSources enabledTactics = 0; - std::vector tacticList = SplitToStringVec(tactic_string, ','); - for (auto& t : tacticList) { - bool enable{false}; - if (t.front() == '+') { - enable = true; - } else if (t.front() != '-') { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source must be prefixed with + or - skipping: " << t; - } - t.erase(0, 1); - - const auto toUpper = [](std::string& sourceName) { - std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), - [](char c) { return static_cast(std::toupper(c)); }); - return sourceName; - }; - - nvinfer1::TacticSource source{}; - t = toUpper(t); - if (t == "CUBLAS") { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0"; -#if NV_TENSORRT_MAJOR < 10 - source = nvinfer1::TacticSource::kCUBLAS; -#endif - } else if (t == "CUBLASLT" || t == "CUBLAS_LT") { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0"; -#if NV_TENSORRT_MAJOR < 9 - source = nvinfer1::TacticSource::kCUBLAS_LT; -#endif - } else if (t == "CUDNN") { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0"; -#if NV_TENSORRT_MAJOR < 10 - source = nvinfer1::TacticSource::kCUDNN; -#endif - } else if (t == "EDGE_MASK_CONVOLUTIONS") { - source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS; - } else if (t == "JIT_CONVOLUTIONS") { - source = nvinfer1::TacticSource::kJIT_CONVOLUTIONS; - } else { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source was not found with name: " << t; - } - - uint32_t sourceBit = 1U << static_cast(source); - - if (enable) { - enabledTactics |= sourceBit; - } else { - disabledTactics |= sourceBit; - } - } - return enabledTactics & ~disabledTactics; -} - -inline std::vector loadTimingCacheFile(const std::string inFileName) { - std::ifstream iFile(inFileName, std::ios::in | std::ios::binary); - if (!iFile) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName - // << ". 
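GetTacticSourceFromString above accepts entries such as "+CUBLAS,-CUDNN", builds separate enable/disable bitmasks, and returns enabled & ~disabled so that an explicit disable always wins. A tiny illustration of that combination (the bit positions are made up for the example):

    #include <cstdint>
    #include <iostream>

    int main() {
      uint32_t enabled  = (1u << 0) | (1u << 4);  // e.g. "+CUBLAS,+JIT_CONVOLUTIONS"
      uint32_t disabled = (1u << 0);              // e.g. "-CUBLAS"
      uint32_t tactics  = enabled & ~disabled;    // CUBLAS drops out, JIT_CONVOLUTIONS stays
      std::cout << std::hex << tactics << "\n";   // prints 10
    }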
A new timing cache will be generated and written."; - return std::vector(); - } - iFile.seekg(0, std::ifstream::end); - size_t fsize = iFile.tellg(); - iFile.seekg(0, std::ifstream::beg); - std::vector content(fsize); - iFile.read(content.data(), fsize); - iFile.close(); - return content; -} - -inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::IHostMemory* blob) { - std::ofstream oFile(outFileName, std::ios::out | std::ios::binary); - if (!oFile) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName; - return; - } - oFile.write((char*)blob->data(), blob->size()); - oFile.close(); + return IAllocatorUniquePtr{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }}; } #if NV_TENSORRT_MAJOR >= 10 @@ -983,353 +778,6 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, return nullptr; } -/* -// Detect and remove cycles from supported node list -bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles) const { - const size_t* nodes_index = nullptr; - size_t node_count = 0; - graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index, &node_count); - bool trt_cycle = true, cycle_detected = false; - while (trt_cycle) { - trt_cycle = false; - std::unordered_map node_to_index_map; - std::unordered_map index_to_node_map; - std::unordered_map> input_to_nodes_map, node_to_outputs_map; - std::unordered_set non_trt_node_index; - for (size_t i = 0; i < node_count; ++i) { - non_trt_node_index.insert(nodes_index[i]); - } - size_t id = 0; - int subgraph_index = 0; - for (const auto& group : supported_nodes_vector) { - if (!group.first.empty()) { - // Construct subgraph from node list - std::unique_ptr subgraph = GetSubGraph(group, graph, model_hash, subgraph_index); - - // Create node to inputs/outputs/index maps - const std::string node_name = subgraph->meta_def->name; - if (node_to_index_map.find(node_name) == node_to_index_map.end()) { - index_to_node_map[id] = node_name; - node_to_index_map[node_name] = id++; - } - - if (subgraph->meta_def != nullptr) { - for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { - input_to_nodes_map[std::string(subgraph->meta_def->inputs[j])].insert(node_name); - } - for (size_t j = 0; j < subgraph->meta_def->output_len; j++) { - node_to_outputs_map[node_name].insert(std::string(subgraph->meta_def->outputs[j])); - } - } - - // Remove TensorRT nodes from node index list - for (const auto& index : group.first) { - non_trt_node_index.erase(nodes_index[index]); - } - subgraph_index++; - } - } - - // Add non TensorRT nodes to the maps - for (const auto& index : non_trt_node_index) { - const OrtNode* node = nullptr; - graph_api_->OrtGraph_GetOrtNode(graph, index, &node); - const char* node_name_char = nullptr; - graph_api_->OrtNode_GetName(node, &node_name_char); - const std::string node_name(node_name_char); - if (node_to_index_map.find(node_name) == node_to_index_map.end()) { - index_to_node_map[id] = node_name; - node_to_index_map[node_name] = id++; - } - - size_t input_count = 0; - graph_api_->OrtNode_GetNumInputs(node, &input_count); - for (size_t i = 0; i < input_count; ++i) { - const char* input_name_char = nullptr; - graph_api_->OrtNode_GetIthInputName(node, i, &input_name_char); - input_to_nodes_map[std::string(input_name_char)].insert(node_name); - } - - size_t implicit_input_count = 0; - graph_api_->OrtNode_GetImplicitInputSize(node, 
&implicit_input_count); - for (size_t i = 0; i < implicit_input_count; ++i) { - const char* input_name_char = nullptr; - graph_api_->OrtNode_GetIthImplicitInputName(node, i, &input_name_char); - input_to_nodes_map[std::string(input_name_char)].insert(node_name); - } - - size_t output_count = 0; - graph_api_->OrtNode_GetNumOutputs(node, &output_count); - for (size_t i = 0; i < output_count; ++i) { - const char* output_name_char = nullptr; - graph_api_->OrtNode_GetIthOutputName(node, i, &output_name_char); - node_to_outputs_map[node_name].insert(std::string(output_name_char)); - } - } - - // Create adjacency list - size_t graph_size = node_to_index_map.size(); - std::list* adjacency_map = new std::list[graph_size]; - for (const auto& node : node_to_outputs_map) { - for (auto iter = node.second.begin(); iter != node.second.end(); ++iter) { - const auto& loc = input_to_nodes_map.find(*iter); - if (loc != input_to_nodes_map.end()) { - size_t parent_node_index = node_to_index_map.find(node.first)->second; - for (auto child_node : loc->second) { - size_t child_node_index = node_to_index_map.find(child_node)->second; - adjacency_map[parent_node_index].push_back(child_node_index); - } - } - } - } - - // Check cycle in the graph - bool* visited = new bool[graph_size]; - bool* st = new bool[graph_size]; - for (size_t i = 0; i < graph_size; ++i) { - visited[i] = false; - st[i] = false; - } - - std::vector cycles; - bool has_cycle = false; - for (size_t i = 0; i < graph_size; ++i) { - if (FindCycleHelper(i, adjacency_map, visited, st, cycles)) { - has_cycle = true; - cycle_detected = true; - break; - } - } - - // Remove TensorRT subgraph from the supported node list if it's part of the cycle - if (has_cycle && remove_cycles) { - for (size_t i = 0; i < cycles.size(); ++i) { - auto loc = index_to_node_map.find(cycles[i]); - if (loc != index_to_node_map.end() && loc->second.find("TRTKernel") != std::string::npos) { - supported_nodes_vector.erase(supported_nodes_vector.begin() + cycles[i]); - trt_cycle = true; - break; - } - } - } - - delete[] adjacency_map; - delete[] visited; - delete[] st; - } - return cycle_detected; -} - -// Check the graph is the subgraph of control flow op -bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const { - bool is_subgraph = false; - graph_api_->OrtGraph_IsSubgraph(graph, &is_subgraph); - if (is_subgraph) { - const OrtNode* node = nullptr; - graph_api_->OrtGraph_GetParenNode(graph, &node); - const char* node_op_type = nullptr; - graph_api_->OrtNode_GetOpType(node, &node_op_type); - if (control_flow_op_set_.find(std::string(node_op_type)) != control_flow_op_set_.end()) { - return true; - } - } - return false; -} - -// Check whether all the nodes of the graph are assigned to specific ep -bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const { - size_t num_nodes = ort_api_.Graph_NumNodes(graph); - std::vector nodes(num_nodes, nullptr); - RETURN_IF_ERROR(ort_api_.Graph_GetNodes(graph, 1, nodes.data(), nodes.size())); - - for (const OrtNode* node : nodes) { - const char* node_ep_type = ort_api_.Node_GetExecutionProviderType(node); - if (strcmp(node_ep_type, provider_type.c_str())) { - return false; - } - } - return num_nodes != 0; -} - -// Check whether all the nodes of subgraph are supported -bool TensorrtExecutionProvider::IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const { - int number_of_trt_nodes = 0; - 
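The removed DetectTensorRTGraphCycles/FindCycleHelper code above is a standard DFS cycle check over the partitioned graph: reaching a node that is still on the current recursion path closes a cycle. A self-contained sketch over a plain adjacency list:

    #include <vector>

    bool HasCycleFrom(size_t i, const std::vector<std::vector<size_t>>& adj,
                      std::vector<bool>& visited, std::vector<bool>& on_path) {
      if (!visited[i]) {
        visited[i] = true;
        on_path[i] = true;
        for (size_t next : adj[i]) {
          if (!visited[next] && HasCycleFrom(next, adj, visited, on_path)) return true;
          if (on_path[next]) return true;  // back edge -> cycle
        }
      }
      on_path[i] = false;
      return false;
    }

    bool HasCycle(const std::vector<std::vector<size_t>>& adj) {
      std::vector<bool> visited(adj.size(), false), on_path(adj.size(), false);
      for (size_t i = 0; i < adj.size(); ++i)
        if (HasCycleFrom(i, adj, visited, on_path)) return true;
      return false;
    }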
for (const auto& group : supported_nodes_vector) { - if (!group.first.empty()) { - number_of_trt_nodes += static_cast(group.first.size()); - } - } - - return number_of_trt_nodes == number_of_ort_nodes; -} - -std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const OrtGraphViewer* graph, const HashValue& model_hash, int subgraph_index) const { - const size_t* node_index = nullptr; - size_t nodes_count = 0; - graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index, &nodes_count); - std::unordered_set node_set; - node_set.reserve(graph_nodes_index.first.size()); - for (const auto& index : graph_nodes_index.first) { - node_set.insert(node_index[index]); - } - - // Get parent graph output names - std::unordered_set graph_output_names; - size_t graph_output_size = 0; - graph_api_->OrtGraph_GetOutputSize(graph, &graph_output_size); - for (size_t i = 0; i < graph_output_size; i++) { - char const* output_name = nullptr; - graph_api_->OrtGraph_GetIthOutputName(graph, i, &output_name); - graph_output_names.insert(output_name); - } - - // Find inputs and outputs of the subgraph - std::unique_ptr sub_graph = std::make_unique(); - sub_graph->node_index_len = graph_nodes_index.first.size(); - sub_graph->node_index = new size_t[sub_graph->node_index_len]; - sub_graph->meta_def = new OrtMetaDef(); - std::unordered_set erased; - std::unordered_map input_to_order; - std::unordered_map output_to_order; - int input_order = 0; - int output_order = 0; - - std::vector initializers; - int i = 0; - for (const auto& index : graph_nodes_index.first) { - sub_graph->node_index[i++] = node_index[index]; - const OrtNode* node = nullptr; - graph_api_->OrtGraph_GetOrtNode(graph, node_index[index], &node); - size_t input_size = 0; - graph_api_->OrtNode_GetNumInputs(node, &input_size); - for (size_t j = 0; j < input_size; j++) { - const char* input_name = nullptr; - graph_api_->OrtNode_GetIthInputName(node, j, &input_name); - bool is_initializer = false; - graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_initializer); - if (is_initializer) { - initializers.push_back(input_name); - continue; - } - const OrtNode* producer = nullptr; - graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); - // If the input is not produced by any node, it is a graph input - if (producer == nullptr) { - input_to_order[input_name] = input_order++; - continue; - } - size_t producer_index = -1; - graph_api_->OrtNode_GetIndex(producer, &producer_index); - // If the producer node is not in the subgraph, the input is a graph input - if (node_set.find(producer_index) == node_set.end()) { - input_to_order[input_name] = input_order++; - } - } - - size_t implicit_input_size = 0; - graph_api_->OrtNode_GetImplicitInputSize(node, &implicit_input_size); - for (size_t j = 0; j < implicit_input_size; j++) { - const char* input_name = nullptr; - graph_api_->OrtNode_GetIthImplicitInputName(node, j, &input_name); - bool is_initializer = false; - graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_initializer); - if (is_initializer) { - initializers.push_back(input_name); - continue; - } - const OrtNode* producer = nullptr; - graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); - // If the input is not produced by any node, it is a graph input - if (producer == nullptr) { - input_to_order[input_name] = input_order++; - continue; - } - size_t producer_index = -1; - graph_api_->OrtNode_GetIndex(producer, &producer_index); - // If 
the producer node is not in the subgraph, the input is a graph input - if (node_set.find(producer_index) == node_set.end()) { - input_to_order[input_name] = input_order++; - } - } - - size_t output_size = 0; - graph_api_->OrtNode_GetNumOutputs(node, &output_size); - for (size_t j = 0; j < output_size; j++) { - const char* output_name = nullptr; - graph_api_->OrtNode_GetIthOutputName(node, j, &output_name); - // If the output is the graph output, it is a subgraph output - if (graph_output_names.find(output_name) != graph_output_names.end()) { - output_to_order[output_name] = output_order++; - continue; - } - const OrtNode** consumers = nullptr; - size_t consumer_count = 0; - graph_api_->OrtGraph_GetNodesConsumingInput(graph, output_name, &consumers, &consumer_count); - for (size_t k = 0; k < consumer_count; k++) { - size_t consumer_index = -1; - graph_api_->OrtNode_GetIndex(consumers[k], &consumer_index); - // If the consumer node is not in the subgraph, the output is a subgraph output - if (node_set.find(consumer_index) == node_set.end()) { - output_to_order[output_name] = output_order++; - break; - } - } - graph_api_->ReleaseOrtNodeArray(consumers); - } - } - - // Sort inputs and outputs based on their order - std::multimap ordered_inputs, ordered_outputs; - for (const auto& input : input_to_order) { - ordered_inputs.insert(std::pair(input.second, input.first)); - } - for (const auto& output : output_to_order) { - ordered_outputs.insert(std::pair(output.second, output.first)); - } - - // Generate unique kernel name for TRT subgraph - std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); - bool is_subgraph = false; - graph_api_->OrtGraph_IsSubgraph(graph, &is_subgraph); - const std::string graph_type = is_subgraph ? 
"subgraph" : "graph"; - const char* graph_name = nullptr; - graph_api_->OrtGraph_GetName(graph, &graph_name); - std::string meta_def_name = "TRTKernel_" + graph_type + "_" + std::string(graph_name) + subgraph_id; - sub_graph->meta_def->name = new char[meta_def_name.length() + 1]; - strcpy(sub_graph->meta_def->name, meta_def_name.c_str()); - - // Assign inputs and outputs to subgraph's meta_def - sub_graph->meta_def->input_len = ordered_inputs.size(); - sub_graph->meta_def->inputs = new char*[sub_graph->meta_def->input_len]; - i = 0; - for (const auto& input : ordered_inputs) { - sub_graph->meta_def->inputs[i] = new char[input.second.length() + 1]; - strcpy(sub_graph->meta_def->inputs[i++], input.second.c_str()); - } - - sub_graph->meta_def->initializer_len = initializers.size(); - sub_graph->meta_def->constant_initializers = new char*[sub_graph->meta_def->initializer_len]; - i = 0; - for (const auto& initializer : initializers) { - sub_graph->meta_def->constant_initializers[i] = new char[initializer.length() + 1]; - strcpy(sub_graph->meta_def->constant_initializers[i++], initializer.c_str()); - } - - sub_graph->meta_def->output_len = ordered_outputs.size(); - sub_graph->meta_def->outputs = new char*[sub_graph->meta_def->output_len]; - i = 0; - for (const auto& output : ordered_outputs) { - sub_graph->meta_def->outputs[i] = new char[output.second.length() + 1]; - strcpy(sub_graph->meta_def->outputs[i++], output.second.c_str()); - } - - sub_graph->meta_def->domain = "com.microsoft"; - sub_graph->meta_def->since_version = 1; - - return sub_graph; -} -*/ - static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) { TensorrtExecutionProvider* ep = static_cast(this_ptr); @@ -2182,1644 +1630,67 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_lo return builder_.get(); } -OrtStatusPtr TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename, - std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, - bool path_check, - nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, - bool detailed_build_log) { -#if NV_TENSORRT_MAJOR >= 10 - std::filesystem::path onnx_model_path{onnx_model_folder_path}; - onnx_model_path.append(onnx_model_filename); - if (path_check && IsAbsolutePath(onnx_model_path.string())) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("For security purpose, the ONNX model path should be set with " - "a relative path, but it is an absolute path: " + - onnx_model_path.string()) - .c_str()); - } - if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - "The ONNX model path has '..'. 
For security purpose, it's not " - "allowed to point outside the directory."); - } - - if (!std::filesystem::exists(onnx_model_path)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("The ONNX model " + onnx_model_path.string() + - " does not exist.") - .c_str()); - } - - // weight-stripped engine refit logic - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log); - auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); - auto parser_refitter = std::unique_ptr( - nvonnxparser::createParserRefitter(*refitter, trt_logger)); - if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()).c_str()); - } - if (refitter->refitCudaEngine()) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine."; - } else { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()).c_str()); +SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, + const OrtGraph* graph, bool* early_termination) const { + // Return if iterations are exceeding predefined number + SubGraphCollection_t nodes_list_output; + if (iterations > max_iterations) { + *early_termination = true; + return nodes_list_output; } - // serialize the refitted engine to disk - if (serialize_refitted_engine) { - std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); - nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); - std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); - engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialize the refitted engine to " << refitted_engine_cache; - } - return nullptr; -#else - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards."); -#endif -} + iterations++; + for (const auto& group : nodes_vector_input) { + // Construct subgraph + if (!group.first.empty()) { + if (group.second) { + nodes_list_output.push_back(group); + } else { + //const OrtGraphViewer* sub_graph_viewer = nullptr; + //graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); + void* buf_data = nullptr; + size_t buf_size = 0; + graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data, &buf_size); -OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, - const OrtGraph** graphs, - const OrtNode** fused_nodes, - std::unordered_map& input_map, - std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_infos) { - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); - auto trt_builder = GetBuilder(trt_logger); - auto network_flags = 0; + // Get supported node list recursively + SubGraphCollection_t parser_nodes_list; + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ ? 
0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); + network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #endif - network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - void* buf_data = nullptr; - size_t buf_size = 0; - graph_api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data, &buf_size); - trt_parser->parse(buf_data, buf_size, model_path_); - trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); - graph_api_->OrtFreeMem(buf_data); - - // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow - if (fp16_enable_ && layer_norm_fp32_fallback_) { - for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { - auto layer = trt_network->getLayer(idx); - auto next_layer = trt_network->getLayer(idx + 1); - if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; - layer->setPrecision(nvinfer1::DataType::kFLOAT); - next_layer->setPrecision(nvinfer1::DataType::kFLOAT); - layer->setOutputType(0, nvinfer1::DataType::kFLOAT); - next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); - } - } - } - - int num_inputs = trt_network->getNbInputs(); - int num_outputs = trt_network->getNbOutputs(); - std::unordered_map input_indexes(num_inputs); - std::unordered_map output_indexes(num_outputs); - std::unordered_map output_types(num_outputs); - - /* - * Initialize shape range for each dynamic shape input tensor: - * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles during EP compile time. - * It won't make adjustment for profile values during EP compute time. - * - * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, INT_MIN, INT_MIN]. - * Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, max_input_value] based on input tensor value. - * - * - * Once the TRT profiles are created: - * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles will be applied to TRT builder config - * and the engine will be built at EP compile time. - * - * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create default shape as described above, - * and all the profiles won't be applied and engine won't be built until EP compute time. - */ - bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false. - bool has_explicit_profile = false; - bool apply_explicit_profile = false; - int num_profiles = 0; - std::vector trt_profiles; - - // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape dimension(s) and min/max/opt values for dynamic shape input tensor. 
- // - // (1) Single profile case: - // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b - // has one dynamic shape dimension: dim_1. The data will be: - // { - // tensor_a: { - // dim_0: [[min_shape, max_shape, opt_shape]], - // dim_2: [[min_shape, max_shape, opt_shape]] - // }, - // tensor_b: { - // dim_1: [[min_shape, max_shape, opt_shape]] - // } - // } - // - // (2) Multiple profiles case: - // For example, assume tensor_a has one dynamic shap dimension: dim 0, and tensor_b has one dynamic shape dimension: dim_1, - // and both of the tensors have two profiles. The data will be: - // { - // tensor_a: { - // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] - // }, - // tensor_b: { - // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] - // } - // } - ShapeRangesMap input_explicit_shape_ranges; - ShapeRangesMap input_implicit_shape_ranges; - - if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { - has_explicit_profile = true; - num_profiles = GetNumProfiles(profile_min_shapes_); - for (int i = 0; i < num_profiles; i++) { - trt_profiles.push_back(trt_builder->createOptimizationProfile()); - } - } - - // Iterate all input tensors to check dynamic shape - for (unsigned int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_network->getInput(i); - const std::string& input_name = input->getName(); - nvinfer1::Dims dims = input->getDimensions(); - int nb_dims = dims.nbDims; + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); - // Apply explicit optimization profiles provided by user - if (has_explicit_profile) { - apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); - } + auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); + graph_api_->OrtFreeMem(buf_data); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif - // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on input tensor values at EP compute time - if (!apply_explicit_profile) { - if (input->isShapeTensor()) { - // Shape tensor - std::vector> profile_vector; - std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; - profile_vector.push_back(shape_vector); // only one profile needed - input_implicit_shape_ranges[input_name][0] = profile_vector; - has_dynamic_shape = true; - } else { - // Execution tensor - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == -1) { - std::vector> profile_vector; - std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; - profile_vector.push_back(shape_vector); // only one profile needed - input_implicit_shape_ranges[input_name][j] = profile_vector; - has_dynamic_shape = true; + SubGraphCollection_t next_nodes_list; + const size_t* subgraph_node_index = nullptr; + size_t subgraph_node_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index, &subgraph_node_count); + next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, 
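The long comment above describes the per-tensor, per-dimension profile map. A sketch of that nested layout, assuming each dynamic dimension carries one [min, max, opt] triple per optimization profile (the alias mirrors the ShapeRangesMap name used in this file):

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // tensor name -> dynamic dimension index -> per-profile [min, max, opt]
    using ShapeRangesMap =
        std::unordered_map<std::string,
                           std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;

    void Example() {
      ShapeRangesMap ranges;
      ranges["tensor_a"][0].push_back({1, 32, 8});      // tensor_a, dim 0, profile 0
      ranges["tensor_b"][1].push_back({16, 256, 64});   // tensor_b, dim 1, profile 0
      ranges["tensor_b"][1].push_back({16, 512, 128});  // tensor_b, dim 1, profile 1
    }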
sub_graph_viewer, early_termination); + for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; } + nodes_list_output.push_back(next_nodes_list[i]); } + graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); } - apply_explicit_profile = false; } } + return nodes_list_output; +} - // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user - if (has_explicit_profile) { - // TRT EP has a constraint here. - // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify profiles through provider options. - if (has_dynamic_shape) { - std::ostringstream msg; - msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options.\n"; - msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n"; - msg << "Following input(s) has no associated shape profiles provided: "; - auto begin = input_implicit_shape_ranges.begin(); - auto end = input_implicit_shape_ranges.end(); - auto it = begin; - if (it != end) { - msg << it->first; - ++it; - } - for (; it != end; ++it) { - msg << "," << it->first; - } - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, msg.str().c_str()); - } else { - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); - } - } - } - // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default. - // It will later set proper min/max/opt shape values duing EP compute time. 
- else if (!has_explicit_profile && has_dynamic_shape) { - trt_profiles.push_back(trt_builder->createOptimizationProfile()); - } - - // Check platform availability for low precision - if (fp16_enable_) { - if (!trt_builder->platformHasFastFp16()) { - fp16_enable_ = false; - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16"; - } - } - - if (int8_enable_) { - if (!trt_builder->platformHasFastInt8()) { - int8_enable_ = false; - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; - } - } - - const char* node_name = nullptr; - graph_api_->OrtNode_GetName(fused_node, &node_name); - - // Load INT8 calibration table - std::unordered_map dynamic_range_map; - if (int8_enable_ && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); - } - } - dynamic_range_map_[node_name] = dynamic_range_map; - - // Set precision flags - std::string trt_node_name_with_precision(node_name); - if (fp16_enable_ && int8_enable_) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - trt_node_name_with_precision += "_fp16_int8"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; - } else if (fp16_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - trt_node_name_with_precision += "_fp16"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; - } else if (int8_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - trt_node_name_with_precision += "_int8"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; - } - - // Set DLA - if (fp16_enable_ || int8_enable_) { - if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 - int number_of_dla_core = trt_builder->getNbDLACores(); - if (number_of_dla_core == 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; - dla_enable_ = false; - } else { - if (dla_core_ >= number_of_dla_core) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead."; - dla_core_ = 0; - } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(dla_core_); - trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); - } - } - } - trt_node_name_with_precision_[node_name] = trt_node_name_with_precision; - - // enable sparse weights - if (sparsity_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; - } -#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 - if (build_heuristics_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled." 
- // << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder optimization level as 2 to enable builder heuristics."; - } -#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 - // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 - if (build_heuristics_enable_) { - if (builder_optimization_level_ == 2) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."; - } else { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set builder optimization level as 2 to enable builder heuristics."; - } - } -#endif - -#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 - // switch optimizaion level - if (builder_optimization_level_ != 3) { - trt_config->setBuilderOptimizationLevel(builder_optimization_level_); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; - } - - // limit auxiliary streams - if (auxiliary_streams_ >= 0) { - trt_config->setMaxAuxStreams(auxiliary_streams_); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; - } -#else - if (builder_optimization_level_ != 3) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; - } - if (auxiliary_streams_ >= 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; - } -#endif - - if (weight_stripped_engine_enable_) { -#if NV_TENSORRT_MAJOR >= 10 - trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; - trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; -#else - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; -#endif - } - - // limit used tactic sources - if (!tactic_sources_.empty()) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= GetTacticSourceFromString(tactic_sources_); - trt_config->setTacticSources(tactics); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; - } - - // Build TRT engine (if needed) and load TRT engine if: - // (1) Graph has no dynamic shape input - // (2) All the dynamic shape inputs have associated explicit profiles specified by user - // - // Otherwise engine will be handled at inference time. 
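The precision and DLA handling above boils down to a few IBuilderConfig calls, each gated on what the platform actually supports. A condensed sketch using the TensorRT API; the helper name and the dla_core value are illustrative:

    #include <NvInfer.h>

    void ConfigurePrecision(nvinfer1::IBuilder& builder, nvinfer1::IBuilderConfig& config,
                            bool fp16, bool int8, int dla_core) {
      if (fp16 && builder.platformHasFastFp16()) config.setFlag(nvinfer1::BuilderFlag::kFP16);
      if (int8 && builder.platformHasFastInt8()) config.setFlag(nvinfer1::BuilderFlag::kINT8);
      if ((fp16 || int8) && dla_core >= 0 && builder.getNbDLACores() > 0) {
        config.setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);  // unsupported layers fall back to GPU
        config.setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
        config.setDLACore(dla_core);
      }
    }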
- std::unique_ptr trt_engine; - std::unique_ptr trt_context; - - std::string cache_path = ""; - std::string cache_suffix = ""; - // Customize cache prefix if assigned - if (!cache_prefix_.empty()) { - // Generate cache suffix in case user would like to customize cache prefix - cache_suffix = "_" + GetCacheSuffix(node_name, trt_node_name_with_precision); - cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; - } else { - cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); - } - cache_suffix_[node_name] = cache_suffix; - - std::string cache_hw_compat = "_sm" + compute_capability_; - // Enable hardware compatility mode if assigned - if (engine_cache_enable_ && engine_hw_compatible_) { - trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); - cache_hw_compat = "_sm80+"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; - } - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - const std::string cache_path_prefix = cache_path + cache_hw_compat; - std::string engine_cache_path = cache_path_prefix + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path_prefix + ".profile"; - - // If weight-stripped engine is enabled and refitted engine cache is not present, - // TRT EP will use the engine cache with ".stripped.engine" appended to the end. - const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { - engine_cache_path = cache_path_prefix + ".stripped.engine"; - weight_stripped_engine_refit_ = true; - } - - auto create_ep_context_model = [this](const OrtGraphViewer* graph_body_viewer, - std::string& engine_cache_path, - std::string& engine_cache_relative_path_to_context_model_dir, - const char* ep_context_node_name, - char* serialized_engine, - size_t serialized_engine_size) { - // if ep context model name is not given, create a model name based on original model name - if (ctx_model_path_.empty()) { - ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); - } - - // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); - } - - graph_api_->OrtGraph_CreateOrUpdateEpCtxGraph(graph_body_viewer, - ep_context_node_name, - 1, // main_context - ep_context_embed_mode_, - ep_cache_context_attr_.c_str(), - serialized_engine, - serialized_engine_size, - extra_attr_keys_.data(), - extra_attr_values_.data(), - extra_attr_keys_.size(), - &ep_ctx_graph_); - }; - - if (!has_dynamic_shape) { - std::string timing_cache_path = ""; - bool engine_update = false; - if (timing_cache_enable_) { - timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); - } - { - // ifstream file check, engine serialization/deserialization and engine build are in critical section. 
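The cache naming above keys every artifact off one prefix plus a compute-capability suffix, so an engine built for one SM version is not silently loaded on another. An illustrative rendering of the derived file names (the prefix shown is made up):

    #include <iostream>
    #include <string>

    int main() {
      std::string prefix = "TRTKernel_graph_main_123_0_fp16";  // illustrative fused-node cache name
      std::string sm_suffix = "_sm86";                          // or "_sm80+" when hardware compatibility is on
      std::string base = prefix + sm_suffix;
      std::cout << base + ".engine" << "\n"             // serialized engine
                << base + ".engine.encrypted" << "\n"   // encrypted engine cache
                << base + ".profile" << "\n"            // serialized optimization profiles
                << base + ".stripped.engine" << "\n";   // weight-stripped engine fallback
    }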
It needs lock protection to prevent race condition when inferencing with multithreading. - auto lock = GetApiLock(); - - // If explicit profile flag is on and engine cache enable flag is on, - // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine. - if (has_explicit_profile && engine_cache_enable_) { - engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); - if (engine_update) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; - } else { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; - } - } - - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - if (trt_engine == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path).c_str()); - } - - } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { - // Decrypt engine - size_t engine_size = 0; - if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; - if (trt_engine == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str()); - } - } else { - // Set INT8 per tensor dynamic range - if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - trt_config->setInt8Calibrator(nullptr); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (!SetDynamicRange(*trt_network, dynamic_range_map)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not set INT8 dynamic range for fused node: " + std::string(node_name)).c_str()); - } - } - - // Load timing cache from file. 
Create a fresh cache if the file doesn't exist - std::unique_ptr timing_cache = nullptr; - if (timing_cache_enable_) { - std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); - timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); - if (timing_cache == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str()); - } - trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); - if (detailed_build_log_) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; - } - } - - // Build engine - std::chrono::steady_clock::time_point engine_build_start; - if (detailed_build_log_) { - engine_build_start = std::chrono::steady_clock::now(); - } - std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; - if (serialized_engine == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP failed to create engine from network for fused node: " + std::string(node_name)).c_str()); - } - trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); - if (trt_engine == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP failed to deserialize engine for fused node: " + std::string(node_name)).c_str()); - } - if (detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - // LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; - } - if (engine_cache_enable_) { - // Serialize engine profile if it has explicit profiles - if (has_explicit_profile) { - SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; - } - - if (engine_decryption_enable_) { - // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. - if (engine_encryption_ != nullptr) { - if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP call to engine encryption library failed"); - } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; - } else { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
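On a cache hit the code above reads the serialized engine back from disk and hands it to the TensorRT runtime. A minimal sketch of that load path, with error handling reduced to returning nullptr:

    #include <fstream>
    #include <memory>
    #include <string>
    #include <vector>
    #include <NvInfer.h>

    std::unique_ptr<nvinfer1::ICudaEngine> LoadEngineCache(nvinfer1::IRuntime& runtime,
                                                           const std::string& path) {
      std::ifstream file(path, std::ios::binary | std::ios::ate);
      if (!file) return nullptr;
      std::streamsize size = file.tellg();
      file.seekg(0, std::ios::beg);
      std::vector<char> blob(static_cast<size_t>(size));
      if (!file.read(blob.data(), size)) return nullptr;
      // Deserializing from a TensorRT runtime is thread safe per the TRT docs.
      return std::unique_ptr<nvinfer1::ICudaEngine>(
          runtime.deserializeCudaEngine(blob.data(), blob.size()));
    }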
No cache is written to disk"; - } - } else { - std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); - file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; - } - } - // serialize and save timing cache - if (timing_cache_enable_) { - auto timing_cache = trt_config->getTimingCache(); - std::unique_ptr timingCacheHostData{timing_cache->serialize()}; - if (timingCacheHostData == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str()); - } - saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (detailed_build_log_) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; - } - } - - // create and dump ep context model - if (dump_ep_context_model_) { - create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); - } - } - } - - if (weight_stripped_engine_refit_) { - auto status = RefitEngine(model_path_, - onnx_model_folder_path_, - engine_cache_path, - false /* path check for security */, - trt_engine.get(), - true /* serialize refitted engine to disk */, - detailed_build_log_); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } - - // Build context - // Note: Creating an execution context from an engine is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_memory_sharing_enable_) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySize(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } -#if NV_TENSORRT_MAJOR < 10 - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#else - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif - } else { - trt_context = std::unique_ptr(trt_engine->createExecutionContext()); - } - if (!trt_context) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not build execution context for fused node: " + std::string(node_name)).c_str()); - } - } - - // Create input to index map - for (int i = 0; i < num_inputs; ++i) { - auto input = trt_network->getInput(i); - const std::string& input_name = input->getName(); - const auto& iter = input_map.find(input_name); - if (iter != input_map.end()) { - input_indexes[input_name] = iter->second; - } - } - - // Create output to index and type maps - for (int i = 0; i < num_outputs; ++i) { - const std::string& output_name = trt_network->getOutput(i)->getName(); - const auto& iter = output_map.find(output_name); - if (iter != output_map.end()) { - output_indexes[output_name] = iter->second; - } - int32_t output_type = 0; - graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i, &output_type); - output_types[output_name] = output_type; - } - - // Save TRT engine, other TRT objects and input/output info to map - 
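The context-memory-sharing branch above avoids per-context scratch allocations: with sharing enabled the execution context is created without its own device memory (user-managed in TensorRT 10) and the EP later attaches one buffer, sized from getDeviceMemorySize(), that all contexts share. A sketch of the two creation paths:

    #include <memory>
    #include <NvInfer.h>

    std::unique_ptr<nvinfer1::IExecutionContext> MakeContext(nvinfer1::ICudaEngine& engine,
                                                             bool share_memory) {
      if (share_memory) {
        // Caller must later supply a scratch buffer (setDeviceMemory) before enqueueing work.
        return std::unique_ptr<nvinfer1::IExecutionContext>(
            engine.createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
      }
      return std::unique_ptr<nvinfer1::IExecutionContext>(engine.createExecutionContext());
    }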
parsers_.emplace(node_name, std::move(trt_parser)); - engines_.emplace(node_name, std::move(trt_engine)); - contexts_.emplace(node_name, std::move(trt_context)); - networks_.emplace(node_name, std::move(trt_network)); - input_info_[node_name].push_back(input_indexes); - output_info_[node_name].push_back(output_indexes); - output_info_[node_name].push_back(output_types); - input_shape_ranges_[node_name] = input_implicit_shape_ranges; - profiles_.emplace(node_name, std::move(trt_profiles)); - - // Create ep context model if the model has dynamic shape, - // dump the model is embed mode is 0, otherwise update and dump the model at runtime. - if (has_dynamic_shape && dump_ep_context_model_) { - create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, nullptr, 0); - if (ep_context_embed_mode_ == 0) { - graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); - } - } - - // Create function state - node_compute_funcs->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { - TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); - std::unique_ptr p = std::make_unique(); - - // translate tactic sources string to nvinfer1::TacticSources - nvinfer1::TacticSources tactics = 0; - if (!this_->tactic_sources_.empty()) { - tactics = GetTacticSourceFromString(this_->tactic_sources_); - } - *p = {context->AllocateFunc, context->DestroyFunc, context->allocator_handle, context->node_name, this_->builder_.get(), - &(this_->parsers_[context->node_name]), &(this_->engines_[context->node_name]), &(this_->contexts_[context->node_name]), - &(this_->networks_[context->node_name]), this_->input_info_[context->node_name], this_->output_info_[context->node_name], - this_->input_shape_ranges_[context->node_name], &this_->tensorrt_mu_, this_->fp16_enable_, this_->int8_enable_, this_->int8_calibration_cache_available_, - this_->dla_enable_, this_->dla_core_, &(this_->max_workspace_size_), this_->trt_node_name_with_precision_[context->node_name], - this_->engine_cache_enable_, this_->cache_path_, this_->runtime_.get(), this_->profiles_[context->node_name], - this_->context_memory_sharing_enable_, &(this_->max_ctx_mem_size_), this_->dynamic_range_map_[context->node_name], this_->engine_decryption_enable_, - this_->engine_decryption_, this_->engine_encryption_, this_->timing_cache_enable_, this_->global_cache_path_, this_->force_timing_cache_match_, - this_->detailed_build_log_, this_->build_heuristics_enable_, this_->sparsity_enable_, this_->builder_optimization_level_, - this_->auxiliary_streams_, !(this_->tactic_sources_.empty()), tactics, this_->cuda_graph_enable_, this_->cache_prefix_, this_->cache_suffix_[context->node_name], this_->engine_hw_compatible_}; - *state = p.release(); - return 0; - }; - - // Release function state - node_compute_funcs->DestroyFunctionStateFunc = [](void* state) { - delete static_cast(state); - }; - - // Create compute function - node_compute_funcs->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { - Ort::KernelContext ctx(context); - TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); - TensorrtFuncState* trt_state = reinterpret_cast(state); - - // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, - // 
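CreateFunctionStateFunc/DestroyFunctionStateFunc above follow the usual opaque-state pattern: create heap-allocates a per-fused-node state and releases ownership through a void*, and destroy casts it back and deletes it. A stripped-down sketch; TrtState stands in for the much larger TensorrtFuncState:

    #include <memory>
    #include <string>

    struct TrtState {             // stand-in for TensorrtFuncState
      std::string fused_node_name;
    };

    int CreateState(const char* node_name, void** state) {
      auto p = std::make_unique<TrtState>();
      p->fused_node_name = node_name;
      *state = p.release();       // ownership passes to the runtime until DestroyState is called
      return 0;
    }

    void DestroyState(void* state) {
      delete static_cast<TrtState*>(state);
    }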
save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. - // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); - const std::unordered_map& input_indexes = (trt_state->input_info)[0]; - const std::unordered_map& output_indexes = (trt_state->output_info)[0]; - const std::unordered_map& output_types = (trt_state->output_info)[1]; - auto fused_node_name = trt_state->fused_node_name; - // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. - // The info is used for both shape tensor and execution tensor: - // tensor name->(dimension->[min, max, opt]) - auto& shape_ranges = trt_state->input_shape_ranges; - std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run - std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input - auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name]; - auto trt_builder = trt_state->builder; - auto trt_engine = trt_state->engine->get(); - auto trt_context = trt_state->context->get(); - auto trt_profiles = trt_state->profiles; - auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; - int num_inputs = static_cast(input_indexes.size()); - int num_outputs = static_cast(output_indexes.size()); - bool engine_update = false; - bool context_update = false; - std::unordered_set input_names; - - OrtMemoryInfo* mem_info = nullptr; - api_->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); - if (this_->alloc_ == nullptr) { - Ort::ThrowOnError(api_->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); - } - OrtAllocator* alloc = this_->alloc_; - - void* cuda_stream; - Ort::ThrowOnError(api_->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - // Prepare cache name - std::string cache_path = ""; - // Customize cache prefix if assigned - if (!this_->cache_prefix_.empty()) { - cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; - } else { - cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); - } - - // Enable hardware compatility mode if assigned - std::string cache_hw_compat = "_sm" + this_->compute_capability_; - if (this_->engine_cache_enable_ && this_->engine_hw_compatible_) { - cache_hw_compat = "_sm80+"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; - } - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - const std::string cache_path_prefix = cache_path + cache_hw_compat; - std::string engine_cache_path = cache_path_prefix + ".engine"; - const std::string 
encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path_prefix + ".profile"; - std::string timing_cache_path = ""; - if (this_->timing_cache_enable_) { - timing_cache_path = GetTimingCachePath(this_->global_cache_path_, this_->compute_capability_); - } - - // If weight-stripped engine is enabled and refitted engine cache is not present, - // TRT EP will use the engine cache with ".stripped.engine" appended to the end. - const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (this_->weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { - engine_cache_path = cache_path_prefix + ".stripped.engine"; - this_->weight_stripped_engine_refit_ = true; - } - - // Load serialized engine - if (trt_state->engine_cache_enable && trt_engine == nullptr) { - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && !trt_state->engine_decryption_enable && profile_file) { - // Deserialize profile - shape_ranges = DeserializeProfileV2(profile_file); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - - // Prepare buffer - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); - if (!(*(trt_state->engine))) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - - } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { - shape_ranges = DeserializeProfileV2(profile_file); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Decrypt engine - size_t engine_size = 0; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); - if (!(*(trt_state->engine))) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str()); - } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and 
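All cache artifacts share one prefix: either the user-provided cache prefix plus a per-node suffix, or the node name with precision, followed by an "_sm" compute-capability tag ("_sm80+" when hardware-compatible engines are enabled). A small illustrative sketch of the resulting file names; BuildCacheNames and the sample values are hypothetical, only the suffix conventions come from the code above.

#include <string>

// Sketch (illustrative helper, not part of the EP): file names derived from one
// cache prefix, following the suffix conventions used in the surrounding code.
struct CacheNames {
  std::string engine;     // "<prefix>_sm86.engine"
  std::string stripped;   // "<prefix>_sm86.stripped.engine" (weight-stripped fallback)
  std::string encrypted;  // "<prefix>_sm86.engine.encrypted"
  std::string profile;    // "<prefix>_sm86.profile"
};

CacheNames BuildCacheNames(const std::string& prefix, const std::string& compute_cap,
                           bool hw_compatible) {
  const std::string base = prefix + (hw_compatible ? "_sm80+" : "_sm" + compute_cap);
  return {base + ".engine", base + ".stripped.engine", base + ".engine.encrypted",
          base + ".profile"};
}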
DeSerialized " + encrypted_engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - } - } - - // Check and update shape ranges for dynamic shape inputs. - for (int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_state->network->get()->getInput(i); - const std::string& input_name = input->getName(); - input_names.insert(input_name); - - // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. - // TRT EP will help determine the min/max/opt profile values based on current input tensor value. - if (shape_ranges.find(input_name) != shape_ranges.end()) { - auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); - } - } - } - - // Regenerate engine - if (engine_update) { - // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. - trt_state->context->reset(); - trt_state->engine->reset(); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr)); - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); - } - - // Set INT8 Per Tensor Dynamic range - if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - trt_config->setInt8Calibrator(nullptr); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); - } - } - - // Set precision - if (trt_state->fp16_enable && trt_state->int8_enable) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (trt_state->int8_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - } - - // Set DLA (DLA can only run with FP16 or INT8) - if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(trt_state->dla_core); - } - - // enable sparse weights - if (trt_state->sparsity_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; - } -#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 - // enable builder heuristics - if (trt_state->build_heuristics_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; - } -#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 - // switch optimizaion level - if (trt_state->builder_optimization_level != 3) { - 
trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; - } - - // limit auxiliary streams - if (trt_state->auxiliary_streams >= 0) { - trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; - } -#else - if (trt_state->builder_optimization_level != 3) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; - } - if (trt_state->auxiliary_streams >= 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; - } -#endif - if (this_->weight_stripped_engine_enable_) { -#if NV_TENSORRT_MAJOR >= 10 - trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; - trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; -#else - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; -#endif - } - // limit used tactic sources - if (trt_state->filter_tactic_sources) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= trt_state->tactic_sources; - trt_config->setTacticSources(tactics); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; - } - - // Load timing cache from file. Create a fresh cache if the file doesn't exist - std::unique_ptr timing_cache = nullptr; - if (trt_state->timing_cache_enable) { - std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); - timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); - if (timing_cache == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str()); - } - trt_config->setTimingCache(*timing_cache, this_->force_timing_cache_match_); - if (this_->detailed_build_log_) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; - } - } - - // Enable hardware compatility mode if assigned - if (trt_state->engine_hw_compatible) { - trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; - } - - // Build engine - std::unique_ptr serialized_engine; - { - auto lock = this_->GetApiLock(); - std::chrono::steady_clock::time_point engine_build_start; - if (this_->detailed_build_log_) { - engine_build_start = std::chrono::steady_clock::now(); - } - serialized_engine = std::unique_ptr( - trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); - if (!serialized_engine) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create engine from network."); - } - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); - if (!(*(trt_state->engine))) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to deserialize engine."); - } - if (this_->detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - // LOGS_DEFAULT(INFO) << 
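The timing cache round trip is: read the file into a byte vector, create an ITimingCache from those bytes (size 0 yields a fresh cache), attach it to the builder config, build, then serialize the config's cache back to disk. A minimal sketch of the attach step, assuming the bytes were loaded as in loadTimingCacheFile further down; the returned cache must stay alive for the duration of the build.

#include <memory>
#include <vector>
#include "NvInfer.h"

// Sketch: attach a (possibly empty) timing cache to a builder config. The second
// argument of setTimingCache mirrors force_timing_cache_match_ in the code above.
std::unique_ptr<nvinfer1::ITimingCache> AttachTimingCache(
    nvinfer1::IBuilderConfig& config, const std::vector<char>& loaded_bytes, bool force_match) {
  std::unique_ptr<nvinfer1::ITimingCache> cache{
      config.createTimingCache(static_cast<const void*>(loaded_bytes.data()), loaded_bytes.size())};
  if (cache != nullptr) {
    config.setTimingCache(*cache, force_match);
  }
  return cache;
}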
"TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; - } - } - if (!(*(trt_state->engine))) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - trt_engine = trt_state->engine->get(); - if (trt_state->engine_cache_enable) { - // Serialize engine profile - SerializeProfileV2(profile_cache_path, shape_ranges); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; - - // Serialize engine - if (trt_state->engine_decryption_enable) { - // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. - if (trt_state->engine_encryption != nullptr) { - if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine encryption function encrypt"); - } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; - } else { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; - } - } else { - std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); - file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; - } - } - - // serialize and save timing cache - if (trt_state->timing_cache_enable) { - auto timing_cache = trt_config->getTimingCache(); - std::unique_ptr timingCacheHostData{timing_cache->serialize()}; - if (timingCacheHostData == nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str()); - } - saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (this_->detailed_build_log_) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; - } - } - - // dump ep context model - if (this_->dump_ep_context_model_ && this_->ep_context_embed_mode_) { - graph_api_->OrtGraph_CreateOrUpdateEpCtxGraph(nullptr, - fused_node_name.c_str(), - 1, // main_context - this_->ep_context_embed_mode_, - this_->ep_cache_context_attr_.c_str(), - reinterpret_cast(serialized_engine->data()), - serialized_engine->size(), - this_->extra_attr_keys_.data(), - this_->extra_attr_values_.data(), - this_->extra_attr_keys_.size(), - &this_->ep_ctx_graph_); - graph_api_->OrtGraph_DumpOnnxModel(this_->ep_ctx_graph_, this_->ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraph(this_->ep_ctx_graph_); - } - context_update = true; - - if (this_->weight_stripped_engine_refit_) { - auto status = RefitEngine(this_->model_path_, - this_->onnx_model_folder_path_, - engine_cache_path, - false /* path check for security */, - trt_engine, - true /* serialize refitted engine to disk */, - this_->detailed_build_log_); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } - } - - if (context_update) { - if (trt_state->context_memory_sharing_enable) { -#if NV_TENSORRT_MAJOR < 10 - *(trt_state->context) = std::unique_ptr( - trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); -#else - *(trt_state->context) = std::unique_ptr( - 
trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif - } else { - *(trt_state->context) = std::unique_ptr( - trt_state->engine->get()->createExecutionContext()); - } - if (!(*(trt_state->context))) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create context."); - } - trt_context = trt_state->context->get(); - } - - // Get input and output binding names - int total_bindings = trt_engine->getNbIOTensors(); - std::vector input_binding_names, output_binding_names; - for (int i = 0, end = total_bindings; i < end; ++i) { - auto const& name = trt_engine->getIOTensorName(i); - auto const& mode = trt_engine->getTensorIOMode(name); - if (mode == nvinfer1::TensorIOMode::kINPUT) { - input_binding_names.push_back(name); - } else { - output_binding_names.push_back(name); - } - } - - /* - * Set input shapes and bind input buffers - */ - std::vector> scratch_buffers; - for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { - char const* input_name = input_binding_names[i]; - - size_t input_index = 0; - const auto iter = input_indexes.find(input_name); - if (iter != input_indexes.end()) { - input_index = iter->second; - } - auto input_tensor = ctx.GetInput(input_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shapes = tensor_info.GetShape(); - - auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } - - /* - * Set output shapes and bind output buffers - */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; - - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } - - size_t output_type = 0; - const auto type_iter = output_types.find(output_name); - if (type_iter != output_types.end()) { - output_type = type_iter->second; - } - - OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } - - // Set execution context memory - if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySize(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > *max_context_mem_size_ptr) { - *max_context_mem_size_ptr = mem_size; - } - trt_context->setDeviceMemory(MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); - } - - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
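With the explicit I/O tensor API (TensorRT 8.5+), bindings are addressed by name rather than by index, so the EP first splits the engine's I/O tensors into input and output name lists before binding buffers. A compact sketch of that enumeration.

#include <vector>
#include "NvInfer.h"

// Sketch: split engine I/O tensors into input and output name lists.
void CollectBindingNames(const nvinfer1::ICudaEngine& engine,
                         std::vector<const char*>& inputs,
                         std::vector<const char*>& outputs) {
  for (int i = 0, n = engine.getNbIOTensors(); i < n; ++i) {
    const char* name = engine.getIOTensorName(i);
    if (engine.getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT) {
      inputs.push_back(name);
    } else {
      outputs.push_back(name);
    }
  }
}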
- // if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - // cuda_graph_.SetStream(stream); - // CaptureBegin(0); - // } - - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP execution context enqueue failed."); - } - - /* - * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, - * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: - * - * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. - * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, - * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. - * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. - * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream - * is guaranteed. - * - * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. - * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. - */ - if (this_->sync_stream_after_enqueue_) { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); - } - - // Assign TRT output back to ORT output - // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) - // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; - - size_t output_type = 0; - const auto& iter = output_types.find(output_name); - if (iter != output_types.end()) { - output_type = iter->second; - } - - if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } - auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } else { - auto& output_tensor = output_tensors[i]; -#if NV_TENSORRT_MAJOR < 10 - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } -#endif - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } - } - } - - // // End CUDA graph capture. 
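Because a single IExecutionContext may be driven from several ORT threads, each arriving on its own stream, the EP serializes the whole compute function with a mutex and then waits on the stream right after enqueueV3 (skipped while a CUDA graph is being captured, where synchronization is not allowed). A minimal sketch of that enqueue-and-sync step; locking and error reporting are left out.

#include <cuda_runtime_api.h>
#include "NvInfer.h"

// Sketch: enqueue inference and optionally wait for completion on the same stream.
bool RunAndMaybeSync(nvinfer1::IExecutionContext& context, cudaStream_t stream,
                     bool sync_stream_after_enqueue) {
  if (!context.enqueueV3(stream)) {
    return false;  // surfaced as ORT_EP_FAIL in the EP
  }
  if (sync_stream_after_enqueue) {
    return cudaStreamSynchronize(stream) == cudaSuccess;
  }
  return true;
}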
- // // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - // if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - // if (IsGraphCaptureAllowed()) { - // CaptureEnd(0); - // // CUDA work issued to a capturing stream doesn’t actually run on the GPU, - // // so run the captured graph here to actually execute the work. - // ORT_RETURN_IF_ERROR(ReplayGraph(0)); - // } else { - // IncrementRegularRunCountBeforeGraphCapture(); - // } - // } - // std::cout << "end of ComputeFunc in TRTEp's CreateNodeComputeInfoFromGraph()\n"; - return nullptr; - }; - - return nullptr; -} - -OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, - std::unordered_map& input_map, - std::unordered_map& output_map, - OrtNodeComputeInfo* node_compute_funcs) { - std::unique_ptr trt_engine; - std::unique_ptr trt_context; - std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index - std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index - std::unordered_map output_types; // TRT engine output name -> ORT output tensor type - - // Get engine binary data and deserialize it - auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, - runtime_.get(), - model_path_, - compute_capability_, - weight_stripped_engine_enable_, - onnx_model_folder_path_, - detailed_build_log_); - auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - - // Build context - // - // Note: Creating an execution context from an engine is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_memory_sharing_enable_) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySize(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } -#if NV_TENSORRT_MAJOR < 10 - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#else - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif - } else { - trt_context = std::unique_ptr(trt_engine->createExecutionContext()); - } - - const char* fused_node_name = nullptr; - graph_api_->OrtNode_GetName(fused_node, &fused_node_name); - if (!trt_context) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("TensorRT EP could not build execution context for fused node: " + std::string(fused_node_name)).c_str()); - } - - // Create input/output to index maps - for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { - auto const& name = trt_engine->getIOTensorName(i); - auto const& mode = trt_engine->getTensorIOMode(name); - if (mode == nvinfer1::TensorIOMode::kINPUT) { - const auto& iter = input_map.find(name); - if (iter != input_map.end()) { - input_indexes[name] = iter->second; - } - } else { - const 
auto& iter = output_map.find(name); - if (iter != output_map.end()) { - output_indexes[name] = iter->second; - } - } - } - - // Create output to type map - size_t graph_output_size = 0; - graph_api_->OrtGraph_GetOutputSize(graph_body_viewer, &graph_output_size); - for (size_t i = 0; i < graph_output_size; i++) { - char const* output_name = nullptr; - graph_api_->OrtGraph_GetIthOutputName(graph_body_viewer, i, &output_name); - int32_t output_type = 0; - graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i, &output_type); - output_types[output_name] = output_type; - } - - // Save TRT engine, TRT context and input/output info to map - engines_.emplace(fused_node_name, std::move(trt_engine)); - contexts_.emplace(fused_node_name, std::move(trt_context)); - input_info_[fused_node_name].push_back(input_indexes); - output_info_[fused_node_name].push_back(output_indexes); - output_info_[fused_node_name].push_back(output_types); - - // Create function state - node_compute_funcs->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { - TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); - std::unique_ptr p = std::make_unique(); - *p = {context->AllocateFunc, - context->DestroyFunc, - context->allocator_handle, - context->node_name, - &(this_->engines_[context->node_name]), - &(this_->contexts_[context->node_name]), - this_->input_info_[context->node_name], - this_->output_info_[context->node_name], - this_->context_memory_sharing_enable_, - &this_->max_ctx_mem_size_, - &this_->tensorrt_mu_}; - *state = p.release(); - return 0; - }; - - // Release function state - node_compute_funcs->DestroyFunctionStateFunc = [](void* state) { - delete reinterpret_cast(state); - }; - - // Create compute function - node_compute_funcs->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { - TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); - TensorrtShortFuncState* trt_state = reinterpret_cast(state); - Ort::KernelContext ctx(context); - - // The whole compute_function should be considered the critical section. 
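For precompiled (EP-context) engines there is nothing left to build, so the per-node state only has to carry the deserialized engine, its execution context, the I/O index maps, and the memory-sharing knobs. A rough sketch of such a state with illustrative field names, using std::mutex as a stand-in for whatever mutex type the EP actually uses; the real struct is TensorrtShortFuncState.

#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "NvInfer.h"

// Sketch (illustrative): the reduced per-node state used when the engine comes
// from an EP-context node instead of being built from the ONNX subgraph.
struct ShortFuncStateSketch {
  std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;         // owned by the EP's maps
  std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
  std::vector<std::unordered_map<std::string, size_t>> input_info;   // name -> ORT input index
  std::vector<std::unordered_map<std::string, size_t>> output_info;  // name -> index, name -> type
  bool context_memory_sharing_enable = false;
  size_t* max_context_mem_size_ptr = nullptr;
  std::mutex* tensorrt_mu_ptr = nullptr;  // guards the whole compute call
};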
- // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); - const std::unordered_map& input_indexes = (trt_state->input_info)[0]; - const std::unordered_map& output_indexes = (trt_state->output_info)[0]; - const std::unordered_map& output_types = (trt_state->output_info)[1]; - auto fused_node_name = trt_state->fused_node_name; - std::cout << fused_node_name << std::endl; - auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name]; - auto trt_engine = trt_state->engine->get(); - auto trt_context = trt_state->context->get(); - auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; - int num_outputs = static_cast(output_indexes.size()); - std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run - std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input - - OrtMemoryInfo* mem_info = nullptr; - api_->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); - if (this_->alloc_ == nullptr) { - Ort::ThrowOnError(api_->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); - } - OrtAllocator* alloc = this_->alloc_; - - void* cuda_stream; - Ort::ThrowOnError(api_->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); - - // Get input and output binding names - int total_bindings = trt_engine->getNbIOTensors(); - std::vector input_binding_names, output_binding_names; - for (int i = 0, end = total_bindings; i < end; ++i) { - auto const& name = trt_engine->getIOTensorName(i); - auto const& mode = trt_engine->getTensorIOMode(name); - if (mode == nvinfer1::TensorIOMode::kINPUT) { - input_binding_names.push_back(name); - } else { - output_binding_names.push_back(name); - } - } - - /* - * Set input shapes and bind input buffers - */ - std::vector> scratch_buffers; - for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { - char const* input_name = input_binding_names[i]; - - size_t input_index = 0; - const auto iter = input_indexes.find(input_name); - if (iter != input_indexes.end()) { - input_index = iter->second; - } - - OrtStatusPtr status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } - - /* - * Set output shapes and bind output buffers - */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; - - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } - - size_t output_type = 0; - const auto type_iter = output_types.find(output_name); - if (type_iter != output_types.end()) { - output_type = type_iter->second; - } - - OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, 
output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } - - // Set execution context memory - if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySize(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > *max_context_mem_size_ptr) { - *max_context_mem_size_ptr = mem_size; - } - trt_context->setDeviceMemory(MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); - } - - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (this_->cuda_graph_enable_ && this_->IsGraphCaptureAllowed() && !this_->IsGraphCaptured(0)) { - // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - // cuda_graph_.SetStream(stream); - // CaptureBegin(0); - } - - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return api_->CreateStatus(OrtErrorCode::ORT_FAIL, "TensorRT EP execution context enqueue failed."); - } - - /* - * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, - * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: - * - * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. - * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, - * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. - * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. - * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream - * is guaranteed. - * - * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. - * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. - */ - if (this_->sync_stream_after_enqueue_) { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); - } - - // Assign TRT output back to ORT output - // (1) Bind TRT DDS output to ORT kernel context output. 
(It needs to wait until enqueueV3 is finished) - // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; - - size_t output_type = 0; - const auto& iter = output_types.find(output_name); - if (iter != output_types.end()) { - output_type = iter->second; - } - - if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } - OrtStatusPtr status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_FAIL, api_->GetErrorMessage(status)); - } - } else { - auto& output_tensor = output_tensors[i]; -#if NV_TENSORRT_MAJOR < 10 - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } -#endif - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } - } - } - - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (this_->cuda_graph_enable_ && !this_->IsGraphCaptured(0)) { - // if (IsGraphCaptureAllowed()) { - // CaptureEnd(0); - // // CUDA work issued to a capturing stream doesn’t actually run on the GPU, - // // so run the captured graph here to actually execute the work. 
- // ORT_RETURN_IF_ERROR(ReplayGraph(0)); - // } else { - // IncrementRegularRunCountBeforeGraphCapture(); - // } - } - - return nullptr; - }; - - return nullptr; -} - -SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, - const OrtGraph* graph, bool* early_termination) const { - // Return if iterations are exceeding predefined number - SubGraphCollection_t nodes_list_output; - if (iterations > max_iterations) { - *early_termination = true; - return nodes_list_output; - } - - iterations++; - for (const auto& group : nodes_vector_input) { - // Construct subgraph - if (!group.first.empty()) { - if (group.second) { - nodes_list_output.push_back(group); - } else { - //const OrtGraphViewer* sub_graph_viewer = nullptr; - //graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); - - void* buf_data = nullptr; - size_t buf_size = 0; - graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data, &buf_size); - - // Get supported node list recursively - SubGraphCollection_t parser_nodes_list; - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); - auto trt_builder = GetBuilder(trt_logger); - auto network_flags = 0; -#if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); -#endif - network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); - - auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); - graph_api_->OrtFreeMem(buf_data); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - - SubGraphCollection_t next_nodes_list; - const size_t* subgraph_node_index = nullptr; - size_t subgraph_node_count = 0; - graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index, &subgraph_node_count); - next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); - for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { - for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { - next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; - } - nodes_list_output.push_back(next_nodes_list[i]); - } - graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); - } - } - } - return nodes_list_output; -} - -} // namespace onnxruntime - -#ifdef __cplusplus -extern "C" { -#endif -OrtExecutionProviderFactory* RegisterCustomEp() { - std::unique_ptr ret = std::make_unique(); - return ret.release(); -} -#ifdef __cplusplus -} -#endif +//} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 3566d3dc..294f5d03 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -19,7 +19,7 @@ #define EXPORT_API #endif -namespace onnxruntime { +//namespace onnxruntime { namespace tensorrt_env_vars { static const std::string kMaxPartitionIterations = 
"ORT_TENSORRT_MAX_PARTITION_ITERATIONS"; @@ -399,4 +399,4 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; }; -} // namespace onnxruntime +//} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index 60ff20e7..2ff91908 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -1,18 +1,228 @@ -#pragma once -#include -#include -#include +#include #include +#include #include -#include +#include +#include +#include #include "flatbuffers/idl.h" #include "ort_trt_int8_cal_table.fbs.h" -#include "murmurhash3.h" -#include "path_string.h" +#include +//#include "core/providers/cuda/cuda_pch.h" +//#include "core/common/path_string.h" +//#include "core/framework/murmurhash3.h" namespace fs = std::filesystem; -namespace onnxruntime { +//namespace onnxruntime { + +// Check if cycle exists in the graph after partitioning +/* +bool FindCycleHelper(size_t i, gsl::span> adjacency_map, gsl::span visited, + gsl::span st, InlinedVector& cycles) { + if (!visited[i]) { + visited[i] = true; + st[i] = true; + for (auto iter = adjacency_map[i].begin(); iter != adjacency_map[i].end(); ++iter) { + if (!visited[*iter] && FindCycleHelper(*iter, adjacency_map, visited, st, cycles)) { + cycles.push_back(*iter); + return true; + } else if (st[*iter]) { + cycles.push_back(*iter); + return true; + } + } + } + st[i] = false; + return false; +} +*/ + +bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map& dynamic_range_map) { + // Set dynamic range for input tensors + for (int i = 0; i < network.getNbInputs(); ++i) { + const std::string tensor_name = network.getInput(i)->getName(); + auto dynamic_range_iter = dynamic_range_map.find(tensor_name); + if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!network.getInput(i)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + //LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for network input " << tensor_name; + return false; + } + } + } + + // Set dynamic range for activations and weights + for (int i = 0; i < network.getNbLayers(); ++i) { + auto trt_layer = network.getLayer(i); + for (int j = 0, e = trt_layer->getNbOutputs(); j < e; ++j) { + const std::string tensor_name = trt_layer->getOutput(j)->getName(); + auto dynamic_range_iter = dynamic_range_map.find(tensor_name); + if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_layer->getOutput(j)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + //LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for tensor " << tensor_name; + return false; + } + } else if (trt_layer->getType() == nvinfer1::LayerType::kCONSTANT) { + nvinfer1::IConstantLayer* const_layer = static_cast(trt_layer); + const std::string const_layer_name = const_layer->getName(); + auto trt_weights = const_layer->getWeights(); + double max_weight = std::numeric_limits::min(); + for (int64_t k = 0, end = trt_weights.count; k < end; ++k) { + double weight{}; + switch (trt_weights.type) { 
+ case nvinfer1::DataType::kFLOAT: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kBOOL: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kINT8: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kHALF: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kINT32: + weight = static_cast(trt_weights.values)[k]; + break; +#if NV_TENSORRT_MAJOR >= 10 + case nvinfer1::DataType::kINT64: + weight = static_cast(static_cast(trt_weights.values)[k]); + break; +#endif // NV_TENSORRT_MAJOR >= 10 + default: + //LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name; + return false; + } + max_weight = std::max(max_weight, std::abs(weight)); + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_layer->getOutput(j)->setDynamicRange(static_cast(-max_weight), + static_cast(max_weight))) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + //LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for layer " << const_layer_name; + return false; + } + } + } + } + return true; +} + +std::vector SplitToStringVec(std::string const& s, char separator) { + std::vector splitted; + + for (size_t start = 0; start < s.length();) { + size_t separatorIndex = s.find(separator, start); + if (separatorIndex == std::string::npos) { + separatorIndex = s.length(); + } + splitted.emplace_back(s.substr(start, separatorIndex - start)); + start = separatorIndex + 1; + } + + return splitted; +} + +nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { + nvinfer1::TacticSources disabledTactics = 0; + nvinfer1::TacticSources enabledTactics = 0; + std::vector tacticList = SplitToStringVec(tactic_string, ','); + for (auto& t : tacticList) { + bool enable{false}; + if (t.front() == '+') { + enable = true; + } else if (t.front() != '-') { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source must be prefixed with + or - skipping: " << t; + } + t.erase(0, 1); + + const auto toUpper = [](std::string& sourceName) { + std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), + [](char c) { return onnxruntime::narrow(std::toupper(c)); }); + return sourceName; + }; + + nvinfer1::TacticSource source{}; + t = toUpper(t); + if (t == "CUBLAS") { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 + source = nvinfer1::TacticSource::kCUBLAS; +#endif + } else if (t == "CUBLASLT" || t == "CUBLAS_LT") { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0"; +#if NV_TENSORRT_MAJOR < 9 + source = nvinfer1::TacticSource::kCUBLAS_LT; +#endif + } else if (t == "CUDNN") { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 + source = nvinfer1::TacticSource::kCUDNN; +#endif + } else if (t == "EDGE_MASK_CONVOLUTIONS") { + source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS; + } else if (t == "JIT_CONVOLUTIONS") { + source = nvinfer1::TacticSource::kJIT_CONVOLUTIONS; + } else { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source was not found with name: " << t; + } + + uint32_t sourceBit = 1U << static_cast(source); + + if (enable) { + enabledTactics |= sourceBit; + } else { + disabledTactics |= sourceBit; + } + } + return enabledTactics & ~disabledTactics; +} + +inline std::vector loadTimingCacheFile(const std::string inFileName) { 
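GetTacticSourceFromString turns a comma-separated, "+/-"-prefixed list into a TacticSources bitmask. A short usage sketch showing how such an option string is combined with the config's current sources, mirroring how the EP ORs the parsed mask in; the literal option string is only an example.

#include <string>
#include "NvInfer.h"

// Sketch: apply a user-supplied tactic list to a builder config.
void ApplyTacticString(nvinfer1::IBuilderConfig& config, std::string tactic_string) {
  nvinfer1::TacticSources mask = GetTacticSourceFromString(tactic_string);  // defined above
  config.setTacticSources(config.getTacticSources() | mask);
}

// Example: ApplyTacticString(*trt_config, "-CUDNN,+EDGE_MASK_CONVOLUTIONS");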
+ std::ifstream iFile(inFileName, std::ios::in | std::ios::binary); + if (!iFile) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName + // << ". A new timing cache will be generated and written."; + return std::vector(); + } + iFile.seekg(0, std::ifstream::end); + size_t fsize = iFile.tellg(); + iFile.seekg(0, std::ifstream::beg); + std::vector content(fsize); + iFile.read(content.data(), fsize); + iFile.close(); + return content; +} + +inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::IHostMemory* blob) { + std::ofstream oFile(outFileName, std::ios::out | std::ios::binary); + if (!oFile) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName; + return; + } + oFile.write((char*)blob->data(), blob->size()); + oFile.close(); +} float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) { int s = (input >> 31) & 0x01; @@ -25,6 +235,24 @@ float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) { return static_cast((s ? -1 : 1) * pow(2.0, e) * (m + 1.0)); } +/* + * Read calibration table for INT8 quantization + * Two kind of calibration tables are supported, + * 1. ORT generated calibration table + * The table is pre-serialized by flatbuffers. + * Each entry in the table is a key-value pair, + * key: tensor name, value: maximum absolute value in floating point + * For example, + * data_0 2.008338 + * ... + * 2. Native TensorRT generated calibration table + * Data format is defined by TensorRT as, + * tensor name : scale in 32-bit single precision IEEE754 format + * For example, + * TRT-7103-EntropyCalibration2 + * data_0: 4000889d + * ... + */ bool ReadDynamicRange(const std::string file_name, const bool is_trt_calibration_table, std::unordered_map& dynamic_range_map) { std::ifstream infile(file_name, std::ios::binary | std::ios::in); if (!infile) { @@ -74,6 +302,18 @@ bool ReadDynamicRange(const std::string file_name, const bool is_trt_calibration return true; } +/* + * Get number of profile setting. + * + * profile_min_shapes/profile_max_shapes/profile_opt_shapes may contain multiple profile settings. + * Note: TRT EP currently only supports one profile setting. 
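A native TensorRT calibration table stores one "tensor_name: hex_scale" pair per line, with the scale encoded as the hexadecimal form of a 32-bit IEEE 754 float. A small sketch of parsing one such line with the ConvertSinglePrecisionIEEE754ToFloat helper defined nearby; the 127.0f rescaling from per-tensor scale to symmetric dynamic range is an assumption about how ReadDynamicRange consumes the value.

#include <sstream>
#include <string>

// Sketch: parse one "tensor_name: hex_scale" line from a native TRT calibration table.
bool ParseTrtCalibrationLine(const std::string& line, std::string& name, float& dynamic_range) {
  std::istringstream iss(line);
  std::string hex_scale;
  if (!(iss >> name >> hex_scale)) return false;
  if (!name.empty() && name.back() == ':') name.pop_back();  // strip trailing colon
  const unsigned long raw = std::stoul(hex_scale, nullptr, 16);
  // Assumption: dynamic range = scale * 127 for symmetric INT8 quantization.
  dynamic_range = ConvertSinglePrecisionIEEE754ToFloat(raw) * 127.0f;
  return true;
}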
+ * + * { + * tensor_a: [[dim_0_value_0, dim_1_value_1, dim_2_value_2]], + * tensor_b: [[dim_0_value_3, dim_1_value_4, dim_2_value_5]] + * } + * + */ int GetNumProfiles(std::unordered_map>>& profile_shapes) { int num_profile = 0; for (auto it = profile_shapes.begin(); it != profile_shapes.end(); it++) { @@ -85,23 +325,134 @@ int GetNumProfiles(std::unordered_map>>& shape_ranges) { + // Serialize profile + flexbuffers::Builder builder; + auto profile_start = builder.StartMap(); + for (auto outer_it = shape_ranges.begin(); outer_it != shape_ranges.end(); ++outer_it) { + builder.TypedVector(outer_it->first.c_str(), [&] { + for (auto inner_it = outer_it->second.begin(); inner_it != outer_it->second.end(); ++inner_it) { + builder.Int(inner_it->first); + builder.Int(inner_it->second.first); + builder.Int(inner_it->second.second); + } + }); + } + builder.EndMap(profile_start); + builder.Finish(); + + // Save flexbuffer + std::ofstream file(file_name, std::ios::binary | std::ios::out); + auto buf = builder.GetBuffer(); + size_t size = builder.GetSize(); + file.write(reinterpret_cast(&buf[0]), size); + file.close(); +} + +// Deserialize engine profile +// [Deprecated] Use DeserializeProfileV2 +std::unordered_map>> DeserializeProfile(std::ifstream& infile) { + // Load flexbuffer + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read((char*)data.get(), length); + infile.close(); + + // Deserialize profile + std::unordered_map>> shape_ranges; + auto tensors_range_entries = flexbuffers::GetRoot((const uint8_t*)data.get(), length).AsMap(); + auto keys = tensors_range_entries.Keys(); + auto values = tensors_range_entries.Values(); + for (size_t i = 0, i_end = keys.size(); i < i_end; ++i) { + auto dim_range_vectors = values[i].AsTypedVector(); + std::unordered_map> inner_map; + for (size_t j = 0, j_end = dim_range_vectors.size() / 3; j < j_end; ++j) { + size_t idx = 3 * j; + inner_map[dim_range_vectors[idx].AsInt64()] = std::make_pair(dim_range_vectors[idx + 1].AsInt64(), dim_range_vectors[idx + 2].AsInt64()); + } + shape_ranges[keys[i].AsString().c_str()] = inner_map; + } + return shape_ranges; +} + +/* + * Seralize engine profile. (This function starts from ORT 1.15) + * + * + * (1) Single profile case: + * Assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, + * and tensor_b has one dynamic shape dimension: dim_1. 
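The SerializeProfileV2 format flattens each tensor's per-dimension, per-profile ranges into a single typed vector of integers, four entries per (dimension, profile) pair. A plain-vector sketch of that encoding, with flexbuffers replaced by std::vector purely for illustration; decoding walks the vector in strides of four, as DeserializeProfileV2 does.

#include <array>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Sketch: flatten {dim -> [min, max, opt] per profile} into the
// [dim, min, max, opt, dim, min, max, opt, ...] layout used by SerializeProfileV2.
std::vector<int64_t> FlattenTensorProfiles(
    const std::unordered_map<size_t, std::vector<std::array<int64_t, 3>>>& dims) {
  std::vector<int64_t> flat;
  for (const auto& [dim, per_profile] : dims) {
    for (const auto& range : per_profile) {
      flat.insert(flat.end(),
                  {static_cast<int64_t>(dim), range[0], range[1], range[2]});
    }
  }
  return flat;
}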
+ * + * The data before serialization will be: + * { + * tensor_a: { + * dim_0: [[min_shape_0, max_shape_0, opt_shape_0]], + * dim_2: [[min_shape_2, max_shape_2, opt_shape_2]] + * }, + * tensor_b: { + * dim_1: [[min_shape_1, max_shape_1, opt_shape_1]] + * } + * } + * + * The data after serialization will be: + * { + * tensor_a: [dim_0, min_shape_0, max_shape_0, opt_shape_0, dim_2, min_shape_2, max_shape_2, opt_shape_2] + * tensor_b: [dim_1, min_shape_1, max_shape_1, opt_shape_1] + * } + * + * + * (2) Multiple profiles case: + * For example, if the data before serialization is: + * { + * tensor_a: { + * dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + * }, + * tensor_b: { + * dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + * } + * } + * + * The data after serialization will be: + * { + * tensor_a: [dim_0, min_shape_0, max_shape_0, opt_shape_0, dim_0, min_shape_1, max_shape_1, opt_shape_1] + * | | | | + * ---------------- profile 0 ----------------- ---------------- profile 1 ----------------- + * + * tensor_b: [dim_1, min_shape_2, max_shape_2, opt_shape_2, dim_1, min_shape_3, max_shape_3, opt_shape_3] + * | | | | + * ---------------- profile 0 ----------------- ---------------- profile 1 ----------------- + * } + * + */ void SerializeProfileV2(const std::string& file_name, std::unordered_map>>>& shape_ranges) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In SerializeProfileV2()"; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In SerializeProfileV2()"; // Serialize profile flexbuffers::Builder builder; auto tensor_map_start = builder.StartMap(); for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << tensor_it->first.c_str() << "'"; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << tensor_it->first.c_str() << "'"; builder.TypedVector(tensor_it->first.c_str(), [&] { for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { size_t num_profiles = dim_it->second.size(); for (size_t i = 0; i < num_profiles; i++) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] profile #" << i << ", dim is " << dim_it->first; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] profile #" << i << ", dim is " << dim_it->first; builder.Int(dim_it->first); builder.Int(dim_it->second[i][0]); builder.Int(dim_it->second[i][1]); builder.Int(dim_it->second[i][2]); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim_it->first << ", " << dim_it->second[i][0] << ", " << dim_it->second[i][1] << ", " << dim_it->second[i][2]; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim_it->first << ", " << dim_it->second[i][0] << ", " << dim_it->second[i][1] << ", " << dim_it->second[i][2]; } } }); @@ -117,8 +468,56 @@ void SerializeProfileV2(const std::string& file_name, std::unordered_map>>> DeserializeProfileV2(std::ifstream& infile) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In DeserializeProfileV2()"; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In DeserializeProfileV2()"; // Load flexbuffer infile.seekg(0, std::ios::end); size_t length = infile.tellg(); @@ -133,7 +532,7 @@ std::unordered_map>> inner_map; std::vector> profile_vector; @@ -150,20 +549,25 @@ std::unordered_map>>& profile_min_shapes, std::unordered_map>>& profile_max_shapes, std::unordered_map>>& profile_opt_shapes) { std::ifstream profile_file(file_name, std::ios::binary | std::ios::in); if (!profile_file) { - 
//LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << file_name << " doesn't exist."; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << file_name << " doesn't exist."; return true; } @@ -193,7 +597,7 @@ bool CompareProfiles(const std::string& file_name, // Check number of dynamic shape inputs if (profile_min_shapes.size() != shape_ranges.size()) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of dynamic shape inputs are not the same."; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of dynamic shape inputs are not the same."; return true; } @@ -201,7 +605,7 @@ bool CompareProfiles(const std::string& file_name, for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors auto tensor_name = tensor_it->first; if (profile_min_shapes.find(tensor_name) == profile_min_shapes.end()) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tensor name '" << tensor_name << "' doesn't exist in trt_profile_min_shapes."; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tensor name '" << tensor_name << "' doesn't exist in trt_profile_min_shapes."; return true; } @@ -210,35 +614,35 @@ bool CompareProfiles(const std::string& file_name, auto num_profiles = GetNumProfiles(profile_min_shapes); if (dim_it->second.size() != static_cast(num_profiles)) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of profiles are not the same."; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of profiles are not the same."; return true; } for (size_t i = 0; i < dim_it->second.size(); i++) { // iterate (multiple) profile(s) auto shape_values = dim_it->second[i]; if (dim > (profile_min_shapes[tensor_name][i].size() - 1)) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dimension " << dim << " of '" << tensor_name << "' in " << file_name << " exceeds the total dimension of trt_profile_min_shapes."; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dimension " << dim << " of '" << tensor_name << "' in " << file_name << " exceeds the total dimension of trt_profile_min_shapes."; return true; } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_min_shapes[tensor_name][i][dim]; - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[0] << " in " << file_name; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_min_shapes[tensor_name][i][dim]; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[0] << " in " << file_name; if (profile_min_shapes[tensor_name][i][dim] != shape_values[0]) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; return true; } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_max_shapes[tensor_name][i][dim]; - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[1] << " in " << file_name; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_max_shapes[tensor_name][i][dim]; + ////LOGS_DEFAULT(VERBOSE) 
<< "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[1] << " in " << file_name; if (profile_max_shapes[tensor_name][i][dim] != shape_values[1]) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; return true; } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_opt_shapes[tensor_name][i][dim]; - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[2] << " in " << file_name; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_opt_shapes[tensor_name][i][dim]; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[2] << " in " << file_name; if (profile_opt_shapes[tensor_name][i][dim] != shape_values[2]) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; return true; } } @@ -247,6 +651,10 @@ bool CompareProfiles(const std::string& file_name, return false; } +/* + * Get cache by name + * + */ std::string GetCachePath(const std::string& root, const std::string& name) { if (root.empty()) { return name; @@ -257,11 +665,19 @@ std::string GetCachePath(const std::string& root, const std::string& name) { } } +/* + * Get compute capability + * + */ std::string GetComputeCapacity(const cudaDeviceProp& prop) { const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor); return compute_capability; } +/* + * Get Timing by compute capability + * + */ std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache const std::string timing_cache_name = "TensorrtExecutionProvider_cache_sm" + @@ -270,30 +686,67 @@ std::string GetTimingCachePath(const std::string& root, std::string& compute_cap } /* -HashValue TRTGenerateId(const OrtApi& api, const OrtGraph* graph, std::string trt_version, std::string cuda_version) { + * Get cache by type + * + * \param root root path of the cache + * \param file_extension It could be ".engine", ".profile" or ".timing" + */ +std::vector GetCachesByType(const std::string& root, std::string file_extension) { + std::vector cache_files; + for (const auto& entry : fs::directory_iterator(root)) { + if (fs::path(file_extension) == fs::path(entry).extension()) { + cache_files.push_back(fs::path(entry)); + } + } + return cache_files; +} + +bool IsCacheExistedByType(const std::string& root, std::string file_extension) { + auto cache_files = GetCachesByType(root, file_extension); + if (cache_files.size() == 0) { + return false; + } + return true; +} + +void RemoveCachesByType(const std::string& root, std::string file_extension) { + auto cache_files = GetCachesByType(root, file_extension); + for (const auto& entry : cache_files) { + fs::remove(entry); + } +} + +/** + * + * Helper class to generate engine id via model name/model 
content/env metadata + * + * + * The TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches + * compiled kernels, so the name must be unique and deterministic across models and sessions. + * + */ +/* +HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version, std::string cuda_version) { HashValue model_hash = 0; - - //// find the top level graph - //const Graph* cur_graph = &graph_viewer.GetGraph(); - //while (cur_graph->IsSubgraph()) { - // cur_graph = cur_graph->ParentGraph(); - //} + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + const Graph& main_graph = *cur_graph; uint32_t hash[4] = {0, 0, 0, 0}; auto hash_str = [&hash](const std::string& str) { - MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + MurmurHash3::x86_128(str.data(), str.size(), hash[0], &hash); }; - const std::filesystem::path* model_path = nullptr; - api.OrtGraph_GetModelPath(graph_viewer, reinterpret_cast(&model_path)); - // Use the model's file name instead of the entire path to avoid cache regeneration if path changes - if (model_path->has_filename()) { - std::string model_name = PathToUTF8String(model_path->filename()); + if (main_graph.ModelPath().has_filename()) { + std::string model_name = PathToUTF8String(main_graph.ModelPath().filename()); - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name; + //LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name; // Ensure enough characters are hashed in case model names are too short const size_t model_name_length = model_name.size(); constexpr size_t hash_string_length = 500; @@ -303,36 +756,24 @@ HashValue TRTGenerateId(const OrtApi& api, const OrtGraph* graph, std::string tr } hash_str(repeat_model_name); } else { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty"; + //LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty"; } // fingerprint current graph by hashing graph inputs - // const std::vector& input_names = nullptr; - const char** input_names = nullptr; // TODO(leca): release input_names - size_t input_count = 0; - api.OrtGraph_GetAllInputs(graph_viewer, &input_names, &input_count); - for (size_t i = 0; i < input_count; ++i) { - hash_str(input_names[i]); + for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); } // hashing output of each node - int number_of_ort_nodes = 0; - api.OrtGraph_NumberOfNodes(graph_viewer, &number_of_ort_nodes); + const int number_of_ort_nodes = graph_viewer.NumberOfNodes(); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - const size_t* nodes_index = nullptr; - size_t nodes_count = 0; - api.OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_index, &nodes_count); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(); for (const auto& index : nodes_vector) { - const OrtNode* node = nullptr; - graph_api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); - size_t output_size = 0; - graph_api->OrtNode_GetNumOutputs(node, &output_size); - for (size_t i = 0; i < output_size; ++i) { - const char* output_name = nullptr; - graph_api->OrtNode_GetIthOutputName(node, i, &output_name); - if (output_name != nullptr) { - hash_str(output_name); + const auto& node = graph_viewer.GetNode(node_index[index]); + for (const auto* node_arg : 
node->OutputDefs()) { + if (node_arg->Exists()) { + hash_str(node_arg->Name()); } } } @@ -343,6 +784,10 @@ HashValue TRTGenerateId(const OrtApi& api, const OrtGraph* graph, std::string tr hash_str("WINDOWS"); #endif +#ifdef ORT_VERSION + hash_str(ORT_VERSION); +#endif + #ifdef CUDA_VERSION hash_str(cuda_version); #endif @@ -358,6 +803,127 @@ HashValue TRTGenerateId(const OrtApi& api, const OrtGraph* graph, std::string tr } */ +bool ValidateProfileShapes(std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes) { + if (profile_min_shapes.empty() && profile_max_shapes.empty() && profile_opt_shapes.empty()) { + return true; + } + + if ((profile_min_shapes.size() != profile_max_shapes.size()) && + (profile_min_shapes.size() != profile_opt_shapes.size()) && + (profile_max_shapes.size() != profile_opt_shapes.size())) { + return false; + } + + std::unordered_map>>::iterator it; + for (it = profile_min_shapes.begin(); it != profile_min_shapes.end(); it++) { + auto input_name = it->first; + auto num_profile = it->second.size(); + + // input_name must also be in max/opt profile + if ((profile_max_shapes.find(input_name) == profile_max_shapes.end()) || + (profile_opt_shapes.find(input_name) == profile_opt_shapes.end())) { + return false; + } + + // number of profiles should be the same + if ((num_profile != profile_max_shapes[input_name].size()) || + (num_profile != profile_opt_shapes[input_name].size())) { + return false; + } + } + + return true; +} + +/* + * Make input-name and shape as a pair. + * This helper function is being used by ParseProfileShapes(). + * + * For example: + * The input string is "input_id:32x1", + * after the string is being parsed, the pair object is returned as below. + * pair("input_id", [32, 1]) + * + * Return true if string can be successfully parsed or false if string has wrong format. + */ +bool MakeInputNameShapePair(std::string pair_string, std::pair>& pair) { + if (pair_string.empty()) { + return true; + } + + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << pair_string; + + std::stringstream input_string_stream(pair_string); + char first_delim = ':'; + char second_delim = 'x'; + std::string input_name; + std::string shape; + std::getline(input_string_stream, input_name, first_delim); + std::getline(input_string_stream, shape, first_delim); + + std::vector shapes; + std::stringstream shape_string_stream(shape); + std::string value; + while (std::getline(shape_string_stream, value, second_delim)) { + shapes.push_back(std::stoi(value)); + } + + // wrong input string + if (input_name.empty() || shapes.empty()) { + return false; + } + + pair.first = input_name; + pair.second = shapes; + + return true; +} + +/* + * Parse explicit profile min/max/opt shapes from TensorRT EP provider options. + * + * For example: + * The provider option is --trt_profile_min_shapes="input_id:32x1,attention_mask:32x1,input_id:32x41,attention_mask:32x41", + * after string is being parsed, the profile shapes has two profiles and is being represented as below. + * {"input_id": [[32, 1], [32, 41]], "attention_mask": [[32, 1], [32, 41]]} + * + * Return true if string can be successfully parsed or false if string has wrong format. 
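 *
 * A minimal wiring sketch (illustrative only; the option-string variable names are assumptions, and the
 * element type follows the int32_t shapes used elsewhere in this file):
 *
 *   std::unordered_map<std::string, std::vector<std::vector<int32_t>>> min_shapes, max_shapes, opt_shapes;
 *   bool ok = ParseProfileShapes(trt_profile_min_shapes_option, min_shapes) &&
 *             ParseProfileShapes(trt_profile_max_shapes_option, max_shapes) &&
 *             ParseProfileShapes(trt_profile_opt_shapes_option, opt_shapes) &&
 *             ValidateProfileShapes(min_shapes, max_shapes, opt_shapes);
 *   if (!ok) {
 *     // Malformed strings or mismatched min/max/opt profiles; a caller would typically ignore the
 *     // explicit profiles and fall back to implicit shape ranges derived at EP compute time.
 *   }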
+ */ +bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map>>& profile_shapes) { + if (profile_shapes_string.empty()) { + return true; + } + + std::stringstream input_string_stream(profile_shapes_string); + char delim = ','; + std::string input_name_with_shape; // input_name:shape, ex: "input_id:32x1" + while (std::getline(input_string_stream, input_name_with_shape, delim)) { + std::pair> pair; + if (!MakeInputNameShapePair(input_name_with_shape, pair)) { + return false; + } + + std::string input_name = pair.first; + if (profile_shapes.find(input_name) == profile_shapes.end()) { + std::vector> profile_shape_vector; + profile_shapes[input_name] = profile_shape_vector; + } + profile_shapes[input_name].push_back(pair.second); + + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << input_name; + std::string shape_string = ""; + for (auto v : pair.second) { + shape_string += std::to_string(v); + shape_string += ", "; + } + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << shape_string; + } + + return true; +} + std::vector split(const std::string& str, char delimiter) { std::vector tokens; std::string token; @@ -379,6 +945,14 @@ std::string join(const std::vector& vec, const std::string& delimit return result; } +/* + * Parse engine cache name suffix when user customizes prefix for engine cache name + * + * For example: + * When default subgraph name is "TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_2068723788287043730_189_189_fp16" + * This func will generate the suffix "2068723788287043730_189_fp16" + * + */ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) { std::vector split_fused_node_name = split(fused_node_name, '_'); if (split_fused_node_name.size() >= 3) { @@ -394,4 +968,4 @@ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string } return ""; } -} +//} // namespace onnxruntime From a5a294e05241b90a61bc8c173a6a4b5854eeca14 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 25 Jun 2025 13:36:13 -0700 Subject: [PATCH 09/60] remove onnxruntime namespace --- .../tensorrt/tensorrt_execution_provider.cc | 4 ---- .../tensorrt/tensorrt_execution_provider.h | 3 --- 2 files changed, 7 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index cff6dfde..9d3883db 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -33,8 +33,6 @@ void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } -//namespace onnxruntime { - static const std::string tensorrtEp = "tensorrtEp"; const OrtApi& ort_api = Ort::GetApi(); @@ -1692,5 +1690,3 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } return nodes_list_output; } - -//} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 294f5d03..d4017837 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -19,8 +19,6 @@ #define EXPORT_API #endif -//namespace onnxruntime { - namespace tensorrt_env_vars { static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_ITERATIONS"; static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; 
@@ -399,4 +397,3 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; }; -//} // namespace onnxruntime From f990a7b0208735a05fc1eb8584370876bf1ddea1 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 25 Jun 2025 13:38:01 -0700 Subject: [PATCH 10/60] update --- .../tensorrt/tensorrt_provider_factory.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index ebd36131..4937e1c4 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -104,14 +104,14 @@ static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, // const OrtHardwareDevice* device = devices[0]; // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; - auto trt_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); + auto trt_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); *ep = trt_ep.release(); return nullptr; } static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { - onnxruntime::TensorrtExecutionProvider* trt_ep = static_cast(ep); + TensorrtExecutionProvider* trt_ep = static_cast(ep); delete trt_ep; } From 7851a1c26ebf0949edb40b82ec383c6a78bf4840 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sat, 28 Jun 2025 11:53:04 -0700 Subject: [PATCH 11/60] Add TRTEpNodeComputeInfo --- .../tensorrt/tensorrt_execution_provider.cc | 2240 ++++++++++++++++- .../tensorrt/tensorrt_execution_provider.h | 52 +- 2 files changed, 2194 insertions(+), 98 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 9d3883db..da5e29de 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -13,7 +13,7 @@ #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" #include "tensorrt_cuda_allocator.h" -#include "onnx_ctx_model_helper.h" +//#include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" #include "cuda/unary_elementwise_ops_impl.h" @@ -33,7 +33,6 @@ void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } -static const std::string tensorrtEp = "tensorrtEp"; const OrtApi& ort_api = Ort::GetApi(); /* @@ -171,6 +170,16 @@ std::unique_lock TensorrtExecutionProvider::GetApiLock() const { return std::unique_lock(singleton); } +nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const { + if (!builder_) { + { + auto lock = GetApiLock(); + builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + } + } + return builder_.get(); +} + template void GetShapeOfShapeTensor(Ort::ConstValue& input_tensor, void* shape_values, @@ -639,7 +648,14 @@ OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) #endif // Cast double input to float because TensorRT doesn't support double - CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + auto input_tensor_ptr = input_tensor.GetTensorData(); if (input_tensor_ptr != nullptr && elem_cnt > 0) { + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * 
sizeof(float))); data = scratch_buffers.back().get(); cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); + } + else { + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); data = scratch_buffers.back().get(); + } break; +} default: { return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); } @@ -743,39 +759,1529 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, */ auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); - /* - * Copy output data from allocation buffer to ORT kernel context output location or - * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. - * - * Note: - * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, - * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. - * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we - * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, - * and within the same stream, operations are guaranteed to be executed in order. - */ - switch (output_type) { - CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) - CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) - CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) - CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) - CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) - CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) -#if NV_TENSORRT_MAJOR >= 10 - CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) -#else - // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. -// CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) + /* + * Copy output data from allocation buffer to ORT kernel context output location or + * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. + * + * Note: + * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, + * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. + * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we + * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, + * and within the same stream, operations are guaranteed to be executed in order. + */ + switch (output_type) { + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. 
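  // As an illustrative sketch only (the `allocation` buffer name is an assumption; `stream`, `elem_cnt`
  // and `output_tensor` are the locals used in this function), the cast described in the comment above
  // and in the commented-out CASE_CAST_TENSOR below amounts to roughly:
  //   cuda::Impl_Cast<int32_t, int64_t>(stream,
  //                                     reinterpret_cast<int32_t*>(allocation),
  //                                     output_tensor.GetTensorMutableData<int64_t>(),
  //                                     elem_cnt);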
+// CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) +#endif + // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. + // CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) + default: { + return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + } + } + return nullptr; +} + +SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, + int iterations, const int max_iterations, + const OrtGraph* graph, bool* early_termination) const { + // Return if iterations are exceeding predefined number + SubGraphCollection_t nodes_list_output; + if (iterations > max_iterations) { + *early_termination = true; + return nodes_list_output; + } + + iterations++; + for (const auto& group : nodes_vector_input) { + // Construct subgraph + if (!group.first.empty()) { + if (group.second) { + nodes_list_output.push_back(group); + } else { + // const OrtGraphViewer* sub_graph_viewer = nullptr; + // graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); + + void* buf_data = nullptr; + size_t buf_size = 0; + graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data, &buf_size); + + // Get supported node list recursively + SubGraphCollection_t parser_nodes_list; + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= fp16_enable_ || int8_enable_ + ? 0 + : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#endif + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + + auto trt_parser = + tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); + graph_api_->OrtFreeMem(buf_data); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + SubGraphCollection_t next_nodes_list; + const size_t* subgraph_node_index = nullptr; + size_t subgraph_node_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index, + &subgraph_node_count); + next_nodes_list = + GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); + for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; + } + nodes_list_output.push_back(next_nodes_list[i]); + } + graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); + } + } + } + return nodes_list_output; +} + +OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, + const OrtGraph* graph, + const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo* node_compute_info) { + TensorrtExecutionProvider* ep = static_cast(this_ptr); + + /* + // Reconstruct graph proto from fused node's function body + auto model = 
graph_body_viewer.CreateModel(*GetLogger()); + auto model_proto = model->ToProto(); + + // ORT's default topological sort is using reversed DFS. + // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. + // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating + // the model proto that has different node ordering compared to original onnx model. + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + std::string string_buf; + model_proto->SerializeToString(string_buf); + + if (dump_subgraphs_) { + // Dump TensorRT subgraphs + std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); + model_proto->SerializeToOstream(dump); + } + */ + std::string string_buf; + + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) + ? 0 + : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#else + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + auto trt_parser = + tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + if (max_workspace_size_ > 0) { + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); + } + + // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if ((fp16_enable_ || bf16_enable_) && layer_norm_fp32_fallback_) { + for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { + auto layer = trt_network->getLayer(idx); + auto next_layer = trt_network->getLayer(idx + 1); + if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && + next_layer->getType() == nvinfer1::LayerType::kREDUCE && + (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + layer->setPrecision(nvinfer1::DataType::kFLOAT); + next_layer->setPrecision(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + } + } + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + int num_inputs = trt_network->getNbInputs(); + int num_outputs = trt_network->getNbOutputs(); + std::unordered_map input_indexes(num_inputs); + std::unordered_map output_indexes(num_outputs); + std::unordered_map output_types(num_outputs); + + /* + * Initialize shape range for each dynamic shape input tensor: + * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles + * during EP compile time. It won't make adjustment for profile values during EP compute time. + * + * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, + * INT_MIN, INT_MIN]. 
Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, + * max_input_value] based on input tensor value. + * + * + * Once the TRT profiles are created: + * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles + * will be applied to TRT builder config and the engine will be built at EP compile time. + * + * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create + * default shape as described above, and all the profiles won't be applied and engine won't be built until EP compute + * time. + */ + bool has_dynamic_shape = + false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false. + bool has_explicit_profile = false; + bool apply_explicit_profile = false; + int num_profiles = 0; + std::vector trt_profiles; + + // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape + // dimension(s) and min/max/opt values for dynamic shape input tensor. + // + // (1) Single profile case: + // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b + // has one dynamic shape dimension: dim_1. The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape, max_shape, opt_shape]], + // dim_2: [[min_shape, max_shape, opt_shape]] + // }, + // tensor_b: { + // dim_1: [[min_shape, max_shape, opt_shape]] + // } + // } + // + // (2) Multiple profiles case: + // For example, assume tensor_a has one dynamic shap dimension: dim 0, and tensor_b has one dynamic shape dimension: + // dim_1, and both of the tensors have two profiles. The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + // }, + // tensor_b: { + // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + // } + // } + ShapeRangesMap input_explicit_shape_ranges; + ShapeRangesMap input_implicit_shape_ranges; + + if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { + has_explicit_profile = true; + num_profiles = GetNumProfiles(profile_min_shapes_); + for (int i = 0; i < num_profiles; i++) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); + } + } + + // Iterate all input tensors to check dynamic shape + for (unsigned int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + // Apply explicit optimization profiles provided by user + if (has_explicit_profile) { + apply_explicit_profile = + ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, + profile_opt_shapes_, input_explicit_shape_ranges); + } + + // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on + // input tensor values at EP compute time + if (!apply_explicit_profile) { + if (input->isShapeTensor()) { + // Shape tensor + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][0] = profile_vector; + has_dynamic_shape = true; + } else { + // Execution tensor + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dims.d[j] == -1) { + 
std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][j] = profile_vector; + has_dynamic_shape = true; + } + } + } + apply_explicit_profile = false; + } + } + + // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user + if (has_explicit_profile) { + // TRT EP has a constraint here. + // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify + // profiles through provider options. + if (has_dynamic_shape) { + std::ostringstream msg; + msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly " + "set profiles through provider options.\n"; + msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also " + "needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n"; + msg << "Following input(s) has no associated shape profiles provided: "; + auto begin = input_implicit_shape_ranges.begin(); + auto end = input_implicit_shape_ranges.end(); + auto it = begin; + if (it != end) { + msg << it->first; + ++it; + } + for (; it != end; ++it) { + msg << "," << it->first; + } + //return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str()); + } else { + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } + } + } + // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default. + // It will later set proper min/max/opt shape values duing EP compute time. + else if (!has_explicit_profile && has_dynamic_shape) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); + } + + // Check platform availability for low precision + if (fp16_enable_ || bf16_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_builder->platformHasFastFp16()) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + fp16_enable_ = false; + bf16_enable_ = false; + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but " + // "platform doesn't support fast native fp16/bf16"; + } + } + + if (int8_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_builder->platformHasFastInt8()) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + int8_enable_ = false; + //LOGS_DEFAULT(WARNING) + // << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; + } + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if (int8_enable_ && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + } + } + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &name)); + std::string fused_node_name = name; + + // Set precision flags + std::string trt_node_name_with_precision = fused_node_name; + if (fp16_enable_) { + 
trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + trt_node_name_with_precision += "_fp16"; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } + if (bf16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + trt_node_name_with_precision += "_bf16"; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; + } + if (int8_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + trt_node_name_with_precision += "_int8"; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // Set DLA + if (fp16_enable_ || int8_enable_) { + if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 + int number_of_dla_core = trt_builder->getNbDLACores(); + if (number_of_dla_core == 0) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + dla_enable_ = false; + } else { + if (dla_core_ >= number_of_dla_core) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ + // << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core + // << ". Use DLA core 0 instead."; + dla_core_ = 0; + } + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(dla_core_); + trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); + } + } + } + + // enable sparse weights + if (sparsity_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + if (build_heuristics_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled." + << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder " + "optimization level as 2 to enable builder heuristics."; + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 + if (build_heuristics_enable_) { + if (builder_optimization_level_ == 2) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " + // "level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."; + } else { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. 
Please set " + // "builder optimization level as 2 to enable builder heuristics."; + } + } +#endif + +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // switch optimizaion level + if (builder_optimization_level_ != 3) { + trt_config->setBuilderOptimizationLevel(builder_optimization_level_); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } + + // limit auxiliary streams + if (auxiliary_streams_ >= 0) { + trt_config->setMaxAuxStreams(auxiliary_streams_); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; + } +#else + if (builder_optimization_level_ != 3) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (auxiliary_streams_ >= 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } +#endif + + if (weight_stripped_engine_enable_) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; +#else + LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; +#endif + } + + // limit used tactic sources + if (!tactic_sources_.empty()) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= GetTacticSourceFromString(tactic_sources_); + trt_config->setTacticSources(tactics); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; + } + + // Set preview feature flags + for (auto feature : preview_features_) { + trt_config->setPreviewFeature(feature, true); + } + + // Build TRT engine (if needed) and load TRT engine if: + // (1) Graph has no dynamic shape input + // (2) All the dynamic shape inputs have associated explicit profiles specified by user + // + // Otherwise engine will be handled at inference time. 
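  // Descriptive note: stated as a predicate, the engine is built right here iff has_dynamic_shape is
  // false, and has_dynamic_shape can only stay false when the graph either has no dynamic-shape inputs
  // or every dynamic-shape input received an explicit profile above; otherwise the build is deferred
  // to the compute function.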
+ std::unique_ptr trt_engine; + std::unique_ptr trt_context; + + std::string cache_path = ""; + std::string cache_suffix = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + // Generate cache suffix in case user would like to customize cache prefix + cache_suffix = "_" + GetCacheSuffix(fused_node_name, trt_node_name_with_precision); + cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; + } else { + cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); + } + + std::string cache_hw_compat = "_sm" + compute_capability_; +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // Enable hardware compatility mode if assigned + if (engine_cache_enable_ && engine_hw_compatible_) { + trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); + cache_hw_compat = "_sm80+"; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + } +#endif + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if + // they share the same compute capacity + const std::string cache_path_prefix = cache_path + cache_hw_compat; + std::string engine_cache_path = cache_path_prefix + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path_prefix + ".profile"; + + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. + const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit_ = true; + } + + // Generate file name for dumping ep context model + if (dump_ep_context_model_ && ctx_model_path_.empty()) { + ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); + } + + if (!has_dynamic_shape) { + std::string timing_cache_path = ""; + bool engine_update = false; + if (timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); + } + { + // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs + // lock protection to prevent race condition when inferencing with multithreading. + auto lock = GetApiLock(); + + // If explicit profile flag is on and engine cache enable flag is on, + // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to + // rebuild the engine. 
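  // Descriptive note: CompareProfiles() reloads the serialized .profile file and reports whether the
  // cached shape ranges differ from the user-supplied min/max/opt shapes; a mismatch sets engine_update
  // below and forces the cached engine to be rebuilt.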
+ if (has_explicit_profile && engine_cache_enable_) { + engine_update = + CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + if (engine_update) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; + } else { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; + } + } + + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + trt_engine = + std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + if (trt_engine == nullptr) { + std::string err_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + } else if (engine_decryption_enable_ && engine_cache_enable_ && + std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { + // Decrypt engine + size_t engine_size = 0; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + std::string err_msg = "TensorRT EP could not get engine buffer size"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + std::string err_msg = "TensorRT EP could not call engine decryption function decrypt"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + // Deserialize engine + trt_engine = + std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + if (trt_engine == nullptr) { + std::string err_msg = "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } else { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set INT8 per tensor dynamic range + if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { + trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (!SetDynamicRange(*trt_network, dynamic_range_map)) { + std::string err_msg = "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + + // Load timing cache from file. 
Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (timing_cache_enable_) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), + loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not create timing cache: " + timing_cache_path); + } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + } + } + + // Build engine + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); + } + std::unique_ptr serialized_engine{ + trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; + if (serialized_engine == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, EP_FAIL, + "TensorRT EP failed to create engine from network for fused node: " + fused_node.Name()); + } + trt_engine = std::unique_ptr( + runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP failed to deserialize engine for fused node: " + fused_node.Name()); + } + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + //LOGS_DEFAULT(INFO) + // << "TensorRT engine build for " << trt_node_name_with_precision << " took: " + // << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() + // << "ms" << std::endl; + } + if (engine_cache_enable_) { + // Serialize engine profile if it has explicit profiles + if (has_explicit_profile) { + SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + } + + if (engine_decryption_enable_) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available + // first. + if (engine_encryption_ != nullptr) { + if (!engine_encryption_(encrypted_engine_cache_path.c_str(), + reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP call to engine encryption library failed"); + } + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + //LOGS_DEFAULT(WARNING) + // << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; + } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; + } + } + // serialize and save timing cache + if (timing_cache_enable_) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not serialize timing cache: " + timing_cache_path); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + } + } + // dump EP context node model + if (dump_ep_context_model_) { + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir) + .append(cache_file_name.string()) + .string(); + } + std::string compute_capability_hw_compat = compute_capability_; + if (engine_cache_enable_ && engine_hw_compatible_) { + compute_capability_hw_compat = "80+"; + } + std::unique_ptr model_proto{ + CreateCtxModel(graph_body_viewer, ep_cache_context_attr_, + reinterpret_cast(serialized_engine->data()), serialized_engine->size(), + ep_context_embed_mode_, compute_capability_hw_compat, model_path_, GetLogger())}; + DumpCtxModel(model_proto.get(), ctx_model_path_); + } + } + } + + if (weight_stripped_engine_refit_) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build"; + char* onnx = string_buf.data(); + size_t onnx_size = string_buf.size(); + auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, + false /* path check for security */, onnx, onnx_size, trt_engine.get(), + true /* serialize refitted engine to disk */, detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { + // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. 
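  // Illustrative sketch (assumed local names) of the sharing scheme this enables at compute time:
  //   size_t mem_size = trt_engine->getDeviceMemorySize();   // per-engine scratch requirement
  //   if (mem_size > *max_context_mem_size_ptr) {
  //     // grow the single shared buffer via the session allocator and remember the new maximum
  //   }
  //   trt_context->setDeviceMemory(shared_context_memory);   // hand the same buffer to every context
  // Until that allocator is available, only the bookkeeping is reset here: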
+ max_ctx_mem_size_ = 0; + context_memory_ = nullptr; +#if NV_TENSORRT_MAJOR < 10 + trt_context = + std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr( + trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); + } + } + + // Create input to index map + for (int i = 0; i < num_inputs; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + const auto& iter = input_map.find(input_name); + if (iter != input_map.end()) { + input_indexes[input_name] = iter->second; + } + } + + // Create output to index and type maps + const auto& graph_output = model_proto->graph().output(); + for (int i = 0; i < num_outputs; ++i) { + const std::string& output_name = trt_network->getOutput(i)->getName(); + const auto& iter = output_map.find(output_name); + if (iter != output_map.end()) { + output_indexes[output_name] = iter->second; + } + const auto& tensor_type = graph_output[i].type().tensor_type(); + output_types[output_name] = tensor_type.elem_type(); + } + + // Save TRT engine, other TRT objects and input/output info to map + parsers_.emplace(fused_node.Name(), std::move(trt_parser)); + engines_.emplace(fused_node.Name(), std::move(trt_engine)); + contexts_.emplace(fused_node.Name(), std::move(trt_context)); + networks_.emplace(fused_node.Name(), std::move(trt_network)); + input_info_[fused_node.Name()].push_back(input_indexes); + output_info_[fused_node.Name()].push_back(output_indexes); + output_info_[fused_node.Name()].push_back(output_types); + input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges; + profiles_.emplace(fused_node.Name(), std::move(trt_profiles)); + + // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty + // engine. TRT EP will serialize the model at inference time due to engine can be updated and the updated engine + // should be included in the model. However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize + // it here. 
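  // Descriptive note: with embed_mode 0 the "ep_cache_context" attribute carries only the relative
  // engine path, so the context model can be dumped immediately; a non-zero embed_mode embeds the
  // serialized engine bytes, which for dynamic-shape graphs only exist once the engine is built (or
  // updated) at inference time, hence the deferred serialization mentioned above.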
+ if (dump_ep_context_model_ && has_dynamic_shape) { + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir) + .append(cache_file_name.string()) + .string(); + } + std::string compute_capability_hw_compat = compute_capability_; + if (engine_cache_enable_ && engine_hw_compatible_) { + compute_capability_hw_compat = "80+"; + } + model_proto_.reset(CreateCtxModel(graph_body_viewer, ep_cache_context_attr_, nullptr, 0, ep_context_embed_mode_, + compute_capability_hw_compat, model_path_, GetLogger())); + if (ep_context_embed_mode_ == 0) { + DumpCtxModel(model_proto_.get(), ctx_model_path_); + } + } + + std::unique_ptr func_state = std::make_unique(); + // translate tactic sources string to nvinfer1::TacticSources + nvinfer1::TacticSources tactics = 0; + if (!tactic_sources_.empty()) { + tactics = GetTacticSourceFromString(tactic_sources_); + } + *func_state = { + fused_node_name, + builder_.get(), + &parsers_[fused_node_name], + &engines_[fused_node_name], + &contexts_[fused_node_name], + &networks_[fused_node_name], + input_info_[fused_node_name], + output_info_[fused_node_name], + input_shape_ranges_[fused_node_name], + &tensorrt_mu_, + fp16_enable_, + bf16_enable_, + int8_enable_, + int8_calibration_cache_available_, + dla_enable_, + dla_core_, + trt_node_name_with_precision, + engine_cache_enable_, + cache_path_, + runtime_.get(), + profiles_[fused_node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &context_memory_, + dynamic_range_map, + engine_decryption_enable_, + engine_decryption_, + engine_encryption_, + timing_cache_enable_, + global_cache_path_, + force_timing_cache_match_, + detailed_build_log_, + build_heuristics_enable_, + sparsity_enable_, + builder_optimization_level_, + auxiliary_streams_, + !tactic_sources_.empty(), + tactics, + cuda_graph_enable_, + cache_prefix_, + cache_suffix, + engine_hw_compatible_, + preview_features_}; + ep->func_states.emplace(fused_node_name, std::move(func_state)); + + // Update the OrtNodeComputeInfo associated with the graph. 
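// Editor's note: a hedged sketch of what translating the tactic_sources_ string into an
// nvinfer1::TacticSources bitmask (GetTacticSourceFromString above) might look like. The
// accepted token names and the "+/-" prefix convention are assumptions; several of these
// enumerators are deprecated in recent TensorRT releases.
#include <cstdint>
#include <sstream>
#include <string>
#include <unordered_map>
#include <NvInfer.h>

inline nvinfer1::TacticSources ParseTacticSources(const std::string& spec) {
  static const std::unordered_map<std::string, nvinfer1::TacticSource> kNames = {
      {"CUBLAS", nvinfer1::TacticSource::kCUBLAS},
      {"CUBLAS_LT", nvinfer1::TacticSource::kCUBLAS_LT},
      {"CUDNN", nvinfer1::TacticSource::kCUDNN},
      {"EDGE_MASK_CONVOLUTIONS", nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS},
      {"JIT_CONVOLUTIONS", nvinfer1::TacticSource::kJIT_CONVOLUTIONS}};
  nvinfer1::TacticSources sources = 0;
  std::stringstream ss(spec);
  std::string token;
  while (std::getline(ss, token, ',')) {       // e.g. "+CUBLAS,-CUDNN"
    if (token.empty()) continue;
    const bool enable = token[0] != '-';
    const std::string name = (token[0] == '+' || token[0] == '-') ? token.substr(1) : token;
    auto it = kNames.find(name);
    if (it == kNames.end()) continue;          // this sketch silently skips unknown names
    const uint32_t bit = 1U << static_cast<uint32_t>(it->second);
    sources = enable ? (sources | bit) : (sources & ~bit);
  }
  return sources;
}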
+ auto node_compute_info = std::make_unique(*ep); + + node_compute_info->CreateStateImpl = [=](ComputeContext* context, FunctionState* state) { + std::unique_ptr p = std::make_unique(); + // translate tactic sources string to nvinfer1::TacticSources + nvinfer1::TacticSources tactics = 0; + if (!tactic_sources_.empty()) { + tactics = GetTacticSourceFromString(tactic_sources_); + } + *p = {context->allocate_func, + context->release_func, + context->allocator_handle, + context->node_name, + builder_.get(), + &parsers_[context->node_name], + &engines_[context->node_name], + &contexts_[context->node_name], + &networks_[context->node_name], + input_info_[context->node_name], + output_info_[context->node_name], + input_shape_ranges_[context->node_name], + &tensorrt_mu_, + fp16_enable_, + bf16_enable_, + int8_enable_, + int8_calibration_cache_available_, + dla_enable_, + dla_core_, + trt_node_name_with_precision, + engine_cache_enable_, + cache_path_, + runtime_.get(), + profiles_[context->node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &context_memory_, + dynamic_range_map, + engine_decryption_enable_, + engine_decryption_, + engine_encryption_, + timing_cache_enable_, + global_cache_path_, + force_timing_cache_match_, + detailed_build_log_, + build_heuristics_enable_, + sparsity_enable_, + builder_optimization_level_, + auxiliary_streams_, + !tactic_sources_.empty(), + tactics, + cuda_graph_enable_, + cache_prefix_, + cache_suffix, + engine_hw_compatible_, + preview_features_}; + *state = p.release(); + return 0; + }; + + // Release function state + compute_info.release_state_func = [](FunctionState state) { delete static_cast(state); }; + + // Create compute function + compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + Ort::KernelContext ctx(context); + + TensorrtFuncState* trt_state = reinterpret_cast(state); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel + // function state, access one builder, create/serialize/save engine, save profile and serialize/save timing cache. + // Therefore, those operations should be synchronized across different threads when ORT is using multithreading. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto fused_node_name = trt_state->fused_node_name; + // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. 
+ // The info is used for both shape tensor and execution tensor: + // tensor name->(dimension->[min, max, opt]) + auto& shape_ranges = trt_state->input_shape_ranges; + std::unordered_map> + shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this + // inference run + std::unordered_map> + shape_tensor_values_int64; // same as above but for int64 shape tensor input + auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; + auto trt_builder = trt_state->builder; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto trt_profiles = trt_state->profiles; + auto context_memory = trt_state->context_memory; + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + int num_inputs = static_cast(input_indexes.size()); + int num_outputs = static_cast(output_indexes.size()); + bool engine_update = false; + bool context_update = false; + std::unordered_set input_names; + + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); + if (alloc_ == nullptr) { + Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); + } + OrtAllocator* alloc = alloc_; + + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even + // if they share the same compute capacity Prepare cache name + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; + } else { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); + } + + // Enable hardware compatility mode if assigned + std::string cache_hw_compat = "_sm" + compute_capability_; +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + if (engine_cache_enable_ && engine_hw_compatible_) { + cache_hw_compat = "_sm80+"; + //LOGS_DEFAULT(VERBOSE) + // << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + } +#endif + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even + // if they share the same compute capacity + const std::string cache_path_prefix = cache_path + cache_hw_compat; + std::string engine_cache_path = cache_path_prefix + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path_prefix + ".profile"; + std::string timing_cache_path = ""; + if (timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); + } + + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. 
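// Editor's note: a small sketch of the cache-file naming scheme described above
// (prefix + "_sm" + compute capability, or "_sm80+" in hardware-compatible mode, with
// ".engine"/".stripped.engine"/".profile" suffixes). The exact prefix comes from the EP
// options; this helper is illustrative only.
#include <filesystem>
#include <string>

struct CachePaths {
  std::string engine;
  std::string stripped_engine;
  std::string profile;
};

inline CachePaths MakeCachePaths(const std::string& cache_dir, const std::string& node_name,
                                 const std::string& compute_capability, bool hw_compatible) {
  const std::string sm = hw_compatible ? "_sm80+" : "_sm" + compute_capability;
  const std::string prefix = (std::filesystem::path(cache_dir) / (node_name + sm)).string();
  return {prefix + ".engine", prefix + ".stripped.engine", prefix + ".profile"};
}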
+ const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit_ = true; + } + + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && !trt_state->engine_decryption_enable && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + + } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && + profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); + } + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } + } + + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); + + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile + // shape values have not yet resolved. 
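// Editor's note: a condensed sketch of the "load serialized engine" path above: read the cache
// file into memory and deserialize it with the shared IRuntime. Error handling is reduced to
// returning nullptr.
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include <NvInfer.h>

inline std::unique_ptr<nvinfer1::ICudaEngine> LoadEngineFromCache(nvinfer1::IRuntime& runtime,
                                                                  const std::string& engine_cache_path) {
  std::ifstream file(engine_cache_path, std::ios::binary | std::ios::in);
  if (!file) return nullptr;
  file.seekg(0, std::ios::end);
  std::vector<char> buf(static_cast<size_t>(file.tellg()));
  file.seekg(0, std::ios::beg);
  file.read(buf.data(), static_cast<std::streamsize>(buf.size()));
  // Deserializing from one IRuntime is thread safe per the TensorRT threading docs.
  return std::unique_ptr<nvinfer1::ICudaEngine>(runtime.deserializeCudaEngine(buf.data(), buf.size()));
}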
TRT EP will help determine the min/max/opt profile values based on current + // input tensor value. + if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, + shape_tensor_values, shape_tensor_values_int64, stream, + &engine_update); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP failed to parse input tensor and generate optimization profiles."); + } + } + } + + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined + // behavior. + trt_state->context->reset(); + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + if (max_workspace_size_ > 0) { + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); + } + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { + trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); + } + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set precision + if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } + if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } + if (trt_state->bf16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } + + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } + + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + 
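// Editor's note: a sketch of how one TRT optimization profile is populated for a dynamic-shape
// input, which is what the deserialized shape ranges above ultimately feed into. The concrete
// min/opt/max values and the input layout are placeholders.
#include <NvInfer.h>

inline void AddDynamicBatchProfile(nvinfer1::IBuilder& builder, nvinfer1::IBuilderConfig& config,
                                   const char* input_name) {
  nvinfer1::IOptimizationProfile* profile = builder.createOptimizationProfile();
  // [min, opt, max] for an input shaped (batch, 3, 224, 224) with a dynamic batch dimension.
  profile->setDimensions(input_name, nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4{1, 3, 224, 224});
  profile->setDimensions(input_name, nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4{8, 3, 224, 224});
  profile->setDimensions(input_name, nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4{32, 3, 224, 224});
  config.addOptimizationProfile(profile);
}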
trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + } +#else + if (trt_state->builder_optimization_level != 3) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (trt_state->auxiliary_streams >= 0) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } +#endif + if (weight_stripped_engine_enable_) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; +#else + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; +#endif + } + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), + loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not create timing cache: " + timing_cache_path); + } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + } + } + +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // Enable hardware compatility mode if assigned + if (trt_state->engine_hw_compatible) { + trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); + //LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; + } +#endif + + // Set preview feature flags + for (auto feature : trt_state->preview_features) { + trt_config->setPreviewFeature(feature, true); + } + + // Build engine + std::unique_ptr serialized_engine; + { + auto lock = GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); + } + serialized_engine = std::unique_ptr( + trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); + if (!serialized_engine) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network."); + } + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine."); + } + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + //LOGS_DEFAULT(INFO) + << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " + << std::chrono::duration_cast(engine_build_stop - 
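// Editor's note: a sketch of the timing-cache load path above: read the cache file (an empty
// buffer yields a fresh cache) and attach it to the builder config. The EP's loadTimingCacheFile
// returns the raw bytes; the inline reader here stands in for it.
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include <NvInfer.h>

inline std::unique_ptr<nvinfer1::ITimingCache> AttachTimingCache(nvinfer1::IBuilderConfig& config,
                                                                 const std::string& timing_cache_path,
                                                                 bool ignore_mismatch) {
  std::vector<char> data;
  std::ifstream file(timing_cache_path, std::ios::binary | std::ios::in);
  if (file) {
    file.seekg(0, std::ios::end);
    data.resize(static_cast<size_t>(file.tellg()));
    file.seekg(0, std::ios::beg);
    file.read(data.data(), static_cast<std::streamsize>(data.size()));
  }
  std::unique_ptr<nvinfer1::ITimingCache> cache(
      config.createTimingCache(data.data(), data.size()));  // empty buffer -> new cache
  if (cache) {
    config.setTimingCache(*cache, ignore_mismatch);
  }
  return cache;  // must outlive the build that uses it
}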
engine_build_start).count() + << "ms" << std::endl; + } + } + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + + // Serialize engine + if (trt_state->engine_decryption_enable) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available + // first. + if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), + reinterpret_cast(serialized_engine->data()), + serialized_engine->size())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine encryption function encrypt"); + } + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + //LOGS_DEFAULT(WARNING) + << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } + } + + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not serialize timing cache: " + timing_cache_path); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + } + } + + // dump ep context model + if (dump_ep_context_model_ && ep_context_embed_mode_) { + UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), + serialized_engine->size()); + DumpCtxModel(model_proto_.get(), ctx_model_path_); + } + context_update = true; + + if (weight_stripped_engine_refit_) { + auto status = + RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, + onnx_model_bytestream_, onnx_model_bytestream_size_, trt_engine, + true /* serialize refitted engine to disk */, detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + } + + if (context_update) { + if (trt_state->context_memory_sharing_enable) { +#if NV_TENSORRT_MAJOR < 10 + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); +#else + *(trt_state->context) = + std::unique_ptr(trt_state->engine->get()->createExecutionContext( + nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + *(trt_state->context) = + std::unique_ptr(trt_state->engine->get()->createExecutionContext()); + } + if (!(*(trt_state->context))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); + } + trt_context = trt_state->context->get(); + } + + // Check before using trt_engine + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found."); + } + + // Get 
input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, + shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, + output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + *context_memory = + IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); + } + trt_context->setDeviceMemory((*context_memory).get()); + } + + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
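// Editor's note: a stripped-down sketch of the CUDA graph capture/replay pattern the EP wraps
// around enqueueV3 (its CudaGraph helper adds the per-thread bookkeeping and the capture-allowed
// checks not shown here). The version guard mirrors the cudaGraphInstantiate signature change
// in CUDA 12.
#include <cuda_runtime_api.h>
#include <NvInfer.h>

inline bool CaptureAndReplay(nvinfer1::IExecutionContext& context, cudaStream_t stream) {
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;

  // Work issued to a capturing stream is recorded, not executed.
  if (cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal) != cudaSuccess) return false;
  const bool enqueued = context.enqueueV3(stream);
  if (cudaStreamEndCapture(stream, &graph) != cudaSuccess) return false;
  if (!enqueued) {
    cudaGraphDestroy(graph);
    return false;
  }
#if CUDART_VERSION >= 12000
  if (cudaGraphInstantiate(&graph_exec, graph, 0) != cudaSuccess) return false;
#else
  if (cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0) != cudaSuccess) return false;
#endif
  // Replay the captured graph so the recorded work actually runs, then wait for it.
  const bool ok = cudaGraphLaunch(graph_exec, stream) == cudaSuccess &&
                  cudaStreamSynchronize(stream) == cudaSuccess;
  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  return ok;
}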
+ if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + //LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_.SetStream(stream); + CaptureBegin(0); + } + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this + * function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent + * issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per + * stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling + * InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each + * thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple + * streams which is not suggested. But, since the whole compute_func() is protected by the lock and if + * cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all + * operations to prevent the concurrent issue mentioned above. However, if cuda graph is enabled, TRT EP won't call + * cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (sync_stream_after_enqueue_) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + } + + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + auto status = + BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); + } + } else { + auto& output_tensor = output_tensors[i]; +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), + output_tensor_ptr, output_dim_sizes[i]); + } + } #endif - // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. 
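// Editor's note: a minimal CUDA sketch of the Impl_Cast step discussed above, which widens the
// INT32 output TensorRT produced into the INT64 buffer ORT expects (the double -> float case is
// analogous). The EP's real, templated kernel lives in cuda/unary_elementwise_ops_impl.cu.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void CastInt32ToInt64Kernel(const int32_t* input, int64_t* output, size_t count) {
  const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < count) {
    output[idx] = static_cast<int64_t>(input[idx]);
  }
}

inline void CastInt32ToInt64(cudaStream_t stream, const int32_t* input, int64_t* output, size_t count) {
  constexpr int kThreadsPerBlock = 256;
  const int blocks = static_cast<int>((count + kThreadsPerBlock - 1) / kThreadsPerBlock);
  CastInt32ToInt64Kernel<<<blocks, kThreadsPerBlock, 0, stream>>>(input, output, count);
}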
- // CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) - default: { - return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } + } } - } - return nullptr; + + // End CUDA graph capture. + // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream + // mentioned in graph capture above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and + // ExecuteGraph() per inference_session.cc. It's safe to start/end CUDA graph capture in compute_func() here since + // cuda graph object is maintained by a per thread basis. + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(0); + // CUDA work issued to a capturing stream doesn�t actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(ReplayGraph(0)); + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); + }; + + node_compute_funcs.push_back(compute_info); + return Status::OK(); } + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) { TensorrtExecutionProvider* ep = static_cast(this_ptr); @@ -1142,11 +2648,9 @@ static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) { return ep->name_.c_str(); } -/// /// -/// Constructor of Plugin TensorRT EP +/// The Plugin TensorRT EP (Implementation of TensorrtExecutionProvider) /// -/// TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device, const OrtSessionOptions& session_options, const OrtLogger& logger) @@ -1618,75 +3122,643 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::st } } -nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const { - if (!builder_) { - { - auto lock = GetApiLock(); - builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); - } + +// +// Implementation of TRTEpNodeComputeInfo +// +TRTEpNodeComputeInfo::TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* TRTEpNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); + auto state_it = ep.GetComputeStates().find(fused_node_name); + if (state_it == ep.GetComputeStates().end()) { + std::string message = "Unable to TensorRT EP's compute state for fused node with name " + fused_node_name; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); } - return builder_.get(); + + TensorrtComputeState& compute_state = *state_it->second; + *compute_state = &compute_state; + return nullptr; } -SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t 
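// Editor's note: a sketch of the lazily created, shared builder shown in GetBuilder() above;
// a plain mutex stands in for the EP's GetApiLock(), since nvinfer1::createInferBuilder is not
// guaranteed to be safe to call concurrently.
#include <memory>
#include <mutex>
#include <NvInfer.h>

inline nvinfer1::IBuilder* GetSharedBuilder(nvinfer1::ILogger& logger) {
  static std::mutex api_mutex;
  static std::unique_ptr<nvinfer1::IBuilder> builder;
  std::lock_guard<std::mutex> lock(api_mutex);
  if (!builder) {
    builder.reset(nvinfer1::createInferBuilder(logger));
  }
  return builder.get();
}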
nodes_vector_input, int iterations, const int max_iterations, - const OrtGraph* graph, bool* early_termination) const { - // Return if iterations are exceeding predefined number - SubGraphCollection_t nodes_list_output; - if (iterations > max_iterations) { - *early_termination = true; - return nodes_list_output; +OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + TensorrtComputeState* trt_state = reinterpret_cast(compute_state); + Ort::KernelContext ctx(kernel_context); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel + // function state, access one builder, create/serialize/save engine, save profile and serialize/save timing cache. + // Therefore, those operations should be synchronized across different threads when ORT is using multithreading. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto fused_node_name = trt_state->fused_node_name; + // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. + // The info is used for both shape tensor and execution tensor: + // tensor name->(dimension->[min, max, opt]) + auto& shape_ranges = trt_state->input_shape_ranges; + std::unordered_map> + shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this + // inference run + std::unordered_map> + shape_tensor_values_int64; // same as above but for int64 shape tensor input + + auto trt_builder = trt_state->builder; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto trt_profiles = trt_state->profiles; + auto context_memory = trt_state->context_memory; + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + int num_inputs = static_cast(input_indexes.size()); + int num_outputs = static_cast(output_indexes.size()); + bool engine_update = false; + bool context_update = false; + std::unordered_set input_names; + + std::unordered_map dds_output_allocator_maps = ep.GetDDSOutputAllocators(); + auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; + + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); + if (alloc_ == nullptr) { + Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); + } + OrtAllocator* alloc = alloc_; + + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even + // if they share the same compute capacity Prepare cache name + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + cache_path = 
GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; + } else { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); } - iterations++; - for (const auto& group : nodes_vector_input) { - // Construct subgraph - if (!group.first.empty()) { - if (group.second) { - nodes_list_output.push_back(group); + // Enable hardware compatility mode if assigned + std::string cache_hw_compat = "_sm" + compute_capability_; +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + if (engine_cache_enable_ && engine_hw_compatible_) { + cache_hw_compat = "_sm80+"; + // LOGS_DEFAULT(VERBOSE) + // << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + } +#endif + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even + // if they share the same compute capacity + const std::string cache_path_prefix = cache_path + cache_hw_compat; + std::string engine_cache_path = cache_path_prefix + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path_prefix + ".profile"; + std::string timing_cache_path = ""; + if (timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); + } + + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. + const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit_ = true; + } + + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && !trt_state->engine_decryption_enable && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + + } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && + profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT 
EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } + } + + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); + + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile + // shape values have not yet resolved. TRT EP will help determine the min/max/opt profile values based on current + // input tensor value. + if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, + shape_tensor_values, shape_tensor_values_int64, stream, + &engine_update); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP failed to parse input tensor and generate optimization profiles."); + } + } + } + + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined + // behavior. 
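// Editor's note: a simplified sketch of the idea behind ApplyProfileShapesFromInputTensorValue
// above: compare the dimensions observed at run time against the recorded [min, max, opt] range
// and widen it (forcing an engine rebuild) when a dimension falls outside. The map layout mirrors
// the "tensor name -> (dimension -> [min, max, opt])" convention described in the comments; the
// opt-tracking heuristic is an assumption.
#include <cstdint>
#include <unordered_map>
#include <vector>

using ShapeRange = std::unordered_map<size_t, std::vector<int64_t>>;  // dim -> [min, max, opt]

inline void WidenShapeRange(ShapeRange& range, const std::vector<int64_t>& observed_dims,
                            bool* engine_update) {
  for (size_t d = 0; d < observed_dims.size(); ++d) {
    auto it = range.find(d);
    if (it == range.end()) continue;  // static dimension, nothing to track
    auto& min_max_opt = it->second;   // [min, max, opt]
    const int64_t v = observed_dims[d];
    if (v < min_max_opt[0]) { min_max_opt[0] = v; *engine_update = true; }
    if (v > min_max_opt[1]) { min_max_opt[1] = v; *engine_update = true; }
    min_max_opt[2] = v;               // keep opt at the most recently seen value
  }
}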
+ trt_state->context->reset(); + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + if (max_workspace_size_ > 0) { + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); + } + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { + trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); + } + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set precision + if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } + if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } + if (trt_state->bf16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } + + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } + + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + } +#else + if (trt_state->builder_optimization_level != 3) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (trt_state->auxiliary_streams >= 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } +#endif + if (weight_stripped_engine_enable_) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + 
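// Editor's note: a sketch of what SetDynamicRange() above does conceptually: walk the network
// tensors and apply per-tensor INT8 ranges read from the calibration table. ITensor::setDynamicRange
// is deprecated in recent TensorRT releases but is what this code path relies on when a
// calibration cache is supplied instead of an int8 calibrator.
#include <string>
#include <unordered_map>
#include <NvInfer.h>

inline bool ApplyDynamicRanges(nvinfer1::INetworkDefinition& network,
                               const std::unordered_map<std::string, float>& ranges) {
  for (int i = 0; i < network.getNbInputs(); ++i) {
    nvinfer1::ITensor* t = network.getInput(i);
    auto it = ranges.find(t->getName());
    if (it != ranges.end() && !t->setDynamicRange(-it->second, it->second)) return false;
  }
  for (int i = 0; i < network.getNbLayers(); ++i) {
    nvinfer1::ILayer* layer = network.getLayer(i);
    for (int j = 0; j < layer->getNbOutputs(); ++j) {
      nvinfer1::ITensor* t = layer->getOutput(j);
      auto it = ranges.find(t->getName());
      if (it != ranges.end() && !t->setDynamicRange(-it->second, it->second)) return false;
    }
  }
  return true;
}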
trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; +#else + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; +#endif + } + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), + loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not create timing cache: " + timing_cache_path); + } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + } + } + +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // Enable hardware compatility mode if assigned + if (trt_state->engine_hw_compatible) { + trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; + } +#endif + + // Set preview feature flags + for (auto feature : trt_state->preview_features) { + trt_config->setPreviewFeature(feature, true); + } + + // Build engine + std::unique_ptr serialized_engine; + { + auto lock = GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); + } + serialized_engine = std::unique_ptr( + trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); + if (!serialized_engine) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network."); + } + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine."); + } + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + // LOGS_DEFAULT(INFO) + << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " + << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" + << std::endl; + } + } + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + + // Serialize engine + if (trt_state->engine_decryption_enable) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available + // first. 
+ if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), + reinterpret_cast(serialized_engine->data()), + serialized_engine->size())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine encryption function encrypt"); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + // LOGS_DEFAULT(WARNING) + << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + } } else { - //const OrtGraphViewer* sub_graph_viewer = nullptr; - //graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } + } - void* buf_data = nullptr; - size_t buf_size = 0; - graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data, &buf_size); + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not serialize timing cache: " + timing_cache_path); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + } + } - // Get supported node list recursively - SubGraphCollection_t parser_nodes_list; - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); - auto trt_builder = GetBuilder(trt_logger); - auto network_flags = 0; -#if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ ? 
0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); + // dump ep context model + if (dump_ep_context_model_ && ep_context_embed_mode_) { + UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), + serialized_engine->size()); + DumpCtxModel(model_proto_.get(), ctx_model_path_); + } + context_update = true; + + if (weight_stripped_engine_refit_) { + auto status = + RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, + onnx_model_bytestream_, onnx_model_bytestream_size_, trt_engine, + true /* serialize refitted engine to disk */, detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + } + + if (context_update) { + if (trt_state->context_memory_sharing_enable) { +#if NV_TENSORRT_MAJOR < 10 + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); +#else + *(trt_state->context) = + std::unique_ptr(trt_state->engine->get()->createExecutionContext( + nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif - network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + } else { + *(trt_state->context) = + std::unique_ptr(trt_state->engine->get()->createExecutionContext()); + } + if (!(*(trt_state->context))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); + } + trt_context = trt_state->context->get(); + } + + // Check before using trt_engine + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found."); + } + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, + shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const 
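// Editor's note: a sketch of the support check the (removed) GetSupportedList path performs:
// create a network plus ONNX parser and ask the parser which node groups it can handle.
// supportsModel is deprecated in newer onnx-tensorrt releases (hence the pragma above), so treat
// this as illustrative of the flow rather than the recommended API.
#include <cstdint>
#include <memory>
#include <NvInfer.h>
#include <NvOnnxParser.h>

inline bool CheckOnnxSupport(nvinfer1::IBuilder& builder, nvinfer1::ILogger& logger,
                             const void* onnx_bytes, size_t onnx_size, const char* model_path,
                             SubGraphCollection_t& supported_groups) {
  const uint32_t flags = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  std::unique_ptr<nvinfer1::INetworkDefinition> network(builder.createNetworkV2(flags));
  std::unique_ptr<nvonnxparser::IParser> parser(nvonnxparser::createParser(*network, logger));
  return parser->supportsModel(onnx_bytes, onnx_size, supported_groups, model_path);
}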
auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, + output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } - auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) #endif - trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); - graph_api_->OrtFreeMem(buf_data); + size_t mem_size = trt_engine->getDeviceMemorySize(); #if defined(_MSC_VER) #pragma warning(pop) #endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + *context_memory = + IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); + } + trt_context->setDeviceMemory((*context_memory).get()); + } - SubGraphCollection_t next_nodes_list; - const size_t* subgraph_node_index = nullptr; - size_t subgraph_node_count = 0; - graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index, &subgraph_node_count); - next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); - for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { - for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { - next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; - } - nodes_list_output.push_back(next_nodes_list[i]); + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_.SetStream(stream); + CaptureBegin(0); + } + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this + * function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent + * issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per + * stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling + * InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each + * thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple + * streams which is not suggested. But, since the whole compute_func() is protected by the lock and if + * cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. 
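// Editor's note: a sketch of the enqueueV3-style binding flow used above: enumerate the engine's
// I/O tensors and bind device pointers by name. Shape handling, DDS outputs and the ORT allocator
// plumbing are omitted; device_ptrs is an assumed name -> device buffer map.
#include <string>
#include <unordered_map>
#include <NvInfer.h>

inline bool BindAllTensors(const nvinfer1::ICudaEngine& engine, nvinfer1::IExecutionContext& context,
                           const std::unordered_map<std::string, void*>& device_ptrs) {
  for (int i = 0, end = engine.getNbIOTensors(); i < end; ++i) {
    const char* name = engine.getIOTensorName(i);
    auto it = device_ptrs.find(name);
    if (it == device_ptrs.end()) return false;
    // Inputs and outputs are both bound through setTensorAddress in the enqueueV3 API.
    if (!context.setTensorAddress(name, it->second)) return false;
  }
  return true;
}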
+ * + * Therefore, TRT EP needs to call cudaStreamSynchronize(), which means waiting until the stream has completed all + * operations, to prevent the concurrency issue mentioned above. However, if cuda graph is enabled, TRT EP won't call + * cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (sync_stream_after_enqueue_) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + } + + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + auto status = + BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); + } + } else { + auto& output_tensor = output_tensors[i]; +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } +#endif + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); } - graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); } } } - return nodes_list_output; + + // End CUDA graph capture. + // Note: One reason we don't put the end of graph capture in OnRunEnd() like CUDA EP does is the cuda stream issue + // mentioned in the graph capture note above; another reason is that OnRunEnd() is not synchronized with OnRunStart() and + // ExecuteGraph() per inference_session.cc. It's safe to start/end CUDA graph capture in compute_func() here since the + // cuda graph object is maintained on a per-thread basis. + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(0); + // CUDA work issued to a capturing stream doesn't actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(ReplayGraph(0)); + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return kernel.Compute(kernel_context); +} + +void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + TensorrtComputeState& state = *reinterpret_cast<TensorrtComputeState*>(compute_state); + (void)state; + // Nothing to do here. 
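  // Note on ownership (an assumption based on the compute_states_ map the EP keeps): the state created in
  // CreateNodeComputeInfoFromGraph is expected to be owned by the TensorrtExecutionProvider and freed when the EP
  // itself is released, not here. If CreateStateImpl ever hands out a heap allocation the EP does not track, this
  // is the place where it would be deleted.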
} diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index d4017837..2b925b1f 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -152,10 +152,7 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { using ShapeRangesMap = std::unordered_map>>>; -struct TensorrtFuncState { - AllocateFunc test_allocate_func = nullptr; - DestroyFunc test_release_func = nullptr; - void* allocator = nullptr; +struct TensorrtComputeState { std::string fused_node_name; nvinfer1::IBuilder* builder; tensorrt_ptr::unique_pointer* parser = nullptr; @@ -200,10 +197,7 @@ struct TensorrtFuncState { }; // Minimum information to construct kernel function state for direct engine load code path -struct TensorrtShortFuncState { - AllocateFunc test_allocate_func = nullptr; - DestroyFunc test_release_func = nullptr; - void* allocator = nullptr; +struct TensorrtComputeStateForEPContext { std::string fused_node_name; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; @@ -226,9 +220,28 @@ struct ApiPtrs { const OrtEpApi& ep_api; }; +/// /// -/// Plugin TensorRT EP that implements OrtEp +/// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. +/// +/// +struct TRTEpNodeComputeInfo : OrtNodeComputeInfo { + explicit TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + TensorrtExecutionProvider& ep; +}; + +/// /// +/// Plugin TensorRT EP that implements OrtEp +/// +/// struct TensorrtExecutionProvider : OrtEp, ApiPtrs { TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device, const OrtSessionOptions& session_options, const OrtLogger& logger); @@ -242,16 +255,24 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, const OrtGraph* graph, bool* early_termination) const; - OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graphs, - const OrtNode* fused_nodes, + OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graph, + const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo* node_compute_infos); + OrtNodeComputeInfo* node_compute_info); - OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graphs, const OrtNode* fused_nodes, + OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo* node_compute_infos); + OrtNodeComputeInfo* node_compute_info); + + std::unordered_map>& GetComputeStates() { return compute_states_; } + + std::unordered_map>& GetComputeStatesForEPContext() { return compute_states_; } + + std::unordered_map& GetDDSOutputAllocators() { + return dds_output_allocator_maps_; + } /* bool IsGraphCaptured(int graph_annotation_id) const { 
return false; } @@ -377,6 +398,9 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { std::unordered_map> profiles_; std::unordered_map dds_output_allocator_maps_; + std::unordered_map> compute_states_; + std::unordered_map> compute_states_for_ep_context; + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture // cudnnHandle_t external_cudnn_handle_ = nullptr; // cublasHandle_t external_cublas_handle_ = nullptr; From be453b11549a6cde12206638fc25c3e9e2020f34 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 1 Jul 2025 17:46:29 -0700 Subject: [PATCH 12/60] add allocator and data transfer --- ...rt_cuda_allocator.cc => cuda_allocator.cc} | 5 +- ...orrt_cuda_allocator.h => cuda_allocator.h} | 12 +- .../tensorrt/tensorrt_execution_provider.cc | 684 +----------------- .../tensorrt/tensorrt_execution_provider.h | 12 +- ...nsorrt_execution_provider_data_transfer.cc | 99 +++ ...ensorrt_execution_provider_data_transfer.h | 30 + .../tensorrt_execution_provider_info.cc | 2 - .../tensorrt_execution_provider_info.h | 4 +- .../tensorrt_execution_provider_utils.h | 44 +- .../tensorrt/tensorrt_provider_factory.cc | 103 ++- .../tensorrt/tensorrt_provider_factory.h | 90 ++- 11 files changed, 322 insertions(+), 763 deletions(-) rename plugin_execution_providers/tensorrt/{tensorrt_cuda_allocator.cc => cuda_allocator.cc} (95%) rename plugin_execution_providers/tensorrt/{tensorrt_cuda_allocator.h => cuda_allocator.h} (90%) create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc b/plugin_execution_providers/tensorrt/cuda_allocator.cc similarity index 95% rename from plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc rename to plugin_execution_providers/tensorrt/cuda_allocator.cc index 89e62dae..058d96f4 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.cc +++ b/plugin_execution_providers/tensorrt/cuda_allocator.cc @@ -3,11 +3,10 @@ #include #include -#include "tensorrt_cuda_allocator.h" +#include "cuda_allocator.h" void CUDA_RETURN_IF_ERROR(cudaError_t res); -namespace onnxruntime { void CUDAAllocator::CheckDevice(bool throw_when_fail) const { #ifndef NDEBUG // check device to match at debug build @@ -75,5 +74,3 @@ void CUDAPinnedAllocator::Free(void* p) { const OrtMemoryInfo* CUDAPinnedAllocator::Info() const { return mem_info_; } - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h similarity index 90% rename from plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h rename to plugin_execution_providers/tensorrt/cuda_allocator.h index 64767d8e..37e7f462 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -7,16 +7,13 @@ #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" -namespace onnxruntime { - -// Following names are originally defined in allocator.h constexpr const char* CUDA_ALLOCATOR = "Cuda"; constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned"; using DeviceId = int16_t; struct CUDAAllocator : OrtAllocator { - CUDAAllocator(DeviceId device_id, const char* name = onnxruntime::CUDA_ALLOCATOR) { + CUDAAllocator(DeviceId device_id, const char* name = CUDA_ALLOCATOR) { OrtAllocator::version = 
ORT_API_VERSION; OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; @@ -31,6 +28,7 @@ struct CUDAAllocator : OrtAllocator { OrtMemType::OrtMemTypeDefault, &mem_info_); } + // TODO: Handle destructor //~CUDAAllocator(); void* Alloc(size_t size); @@ -50,7 +48,7 @@ struct CUDAAllocator : OrtAllocator { }; struct CUDAPinnedAllocator : OrtAllocator { - CUDAPinnedAllocator(const char* name = onnxruntime::CUDA_PINNED_ALLOCATOR) { + CUDAPinnedAllocator(const char* name = CUDA_PINNED_ALLOCATOR) { OrtAllocator::version = ORT_API_VERSION; OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; @@ -62,6 +60,7 @@ struct CUDAPinnedAllocator : OrtAllocator { OrtMemType::OrtMemTypeDefault, &mem_info_); } + // TODO: Handle destructor //~CUDAPinnedAllocator(); void* Alloc(size_t size); @@ -77,6 +76,3 @@ struct CUDAPinnedAllocator : OrtAllocator { DeviceId device_id_ = 0; OrtMemoryInfo* mem_info_ = nullptr; }; - - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index da5e29de..fa532dc2 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -12,7 +12,7 @@ #include "ep_abi_utils.h" #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" -#include "tensorrt_cuda_allocator.h" +#include "cuda_allocator.h" //#include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" #include "cuda/unary_elementwise_ops_impl.h" @@ -81,9 +81,6 @@ struct MemcpyFromHost : OrtCustomOp { }; */ -template -using IAllocatorUniquePtr = std::unique_ptr>; - bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { size_t alloc_size = size; if (alignment == 0) { @@ -1560,13 +1557,14 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } } - std::unique_ptr func_state = std::make_unique(); + std::unique_ptr compute_state = std::make_unique(); + // translate tactic sources string to nvinfer1::TacticSources nvinfer1::TacticSources tactics = 0; if (!tactic_sources_.empty()) { tactics = GetTacticSourceFromString(tactic_sources_); } - *func_state = { + *compute_state = { fused_node_name, builder_.get(), &parsers_[fused_node_name], @@ -1578,7 +1576,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this input_shape_ranges_[fused_node_name], &tensorrt_mu_, fp16_enable_, - bf16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, @@ -1608,676 +1605,11 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this cuda_graph_enable_, cache_prefix_, cache_suffix, - engine_hw_compatible_, - preview_features_}; - ep->func_states.emplace(fused_node_name, std::move(func_state)); + engine_hw_compatible_}; // Update the OrtNodeComputeInfo associated with the graph. 
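  // The TRTEpNodeComputeInfo constructor is assumed to wire the OrtNodeComputeInfo function pointers to the static
  // implementations declared in tensorrt_execution_provider.h, mirroring the wiring pattern used by the factory and
  // TRTEpDataTransfer constructors, roughly:
  //   TRTEpNodeComputeInfo::TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep_) : ep(ep_) {
  //     CreateState = CreateStateImpl;
  //     Compute = ComputeImpl;
  //     ReleaseState = ReleaseStateImpl;
  //   }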
auto node_compute_info = std::make_unique(*ep); - - node_compute_info->CreateStateImpl = [=](ComputeContext* context, FunctionState* state) { - std::unique_ptr p = std::make_unique(); - // translate tactic sources string to nvinfer1::TacticSources - nvinfer1::TacticSources tactics = 0; - if (!tactic_sources_.empty()) { - tactics = GetTacticSourceFromString(tactic_sources_); - } - *p = {context->allocate_func, - context->release_func, - context->allocator_handle, - context->node_name, - builder_.get(), - &parsers_[context->node_name], - &engines_[context->node_name], - &contexts_[context->node_name], - &networks_[context->node_name], - input_info_[context->node_name], - output_info_[context->node_name], - input_shape_ranges_[context->node_name], - &tensorrt_mu_, - fp16_enable_, - bf16_enable_, - int8_enable_, - int8_calibration_cache_available_, - dla_enable_, - dla_core_, - trt_node_name_with_precision, - engine_cache_enable_, - cache_path_, - runtime_.get(), - profiles_[context->node_name], - context_memory_sharing_enable_, - &max_ctx_mem_size_, - &context_memory_, - dynamic_range_map, - engine_decryption_enable_, - engine_decryption_, - engine_encryption_, - timing_cache_enable_, - global_cache_path_, - force_timing_cache_match_, - detailed_build_log_, - build_heuristics_enable_, - sparsity_enable_, - builder_optimization_level_, - auxiliary_streams_, - !tactic_sources_.empty(), - tactics, - cuda_graph_enable_, - cache_prefix_, - cache_suffix, - engine_hw_compatible_, - preview_features_}; - *state = p.release(); - return 0; - }; - - // Release function state - compute_info.release_state_func = [](FunctionState state) { delete static_cast(state); }; - - // Create compute function - compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { - Ort::KernelContext ctx(context); - - TensorrtFuncState* trt_state = reinterpret_cast(state); - - // The whole compute_function should be considered the critical section where multiple threads may update kernel - // function state, access one builder, create/serialize/save engine, save profile and serialize/save timing cache. - // Therefore, those operations should be synchronized across different threads when ORT is using multithreading. - // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); - const std::unordered_map& input_indexes = (trt_state->input_info)[0]; - const std::unordered_map& output_indexes = (trt_state->output_info)[0]; - const std::unordered_map& output_types = (trt_state->output_info)[1]; - auto fused_node_name = trt_state->fused_node_name; - // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. 
- // The info is used for both shape tensor and execution tensor: - // tensor name->(dimension->[min, max, opt]) - auto& shape_ranges = trt_state->input_shape_ranges; - std::unordered_map> - shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this - // inference run - std::unordered_map> - shape_tensor_values_int64; // same as above but for int64 shape tensor input - auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; - auto trt_builder = trt_state->builder; - auto trt_engine = trt_state->engine->get(); - auto trt_context = trt_state->context->get(); - auto trt_profiles = trt_state->profiles; - auto context_memory = trt_state->context_memory; - auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; - int num_inputs = static_cast(input_indexes.size()); - int num_outputs = static_cast(output_indexes.size()); - bool engine_update = false; - bool context_update = false; - std::unordered_set input_names; - - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); - if (alloc_ == nullptr) { - Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); - } - OrtAllocator* alloc = alloc_; - - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even - // if they share the same compute capacity Prepare cache name - std::string cache_path = ""; - // Customize cache prefix if assigned - if (!cache_prefix_.empty()) { - cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; - } else { - cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); - } - - // Enable hardware compatility mode if assigned - std::string cache_hw_compat = "_sm" + compute_capability_; -#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 - if (engine_cache_enable_ && engine_hw_compatible_) { - cache_hw_compat = "_sm80+"; - //LOGS_DEFAULT(VERBOSE) - // << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; - } -#endif - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even - // if they share the same compute capacity - const std::string cache_path_prefix = cache_path + cache_hw_compat; - std::string engine_cache_path = cache_path_prefix + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path_prefix + ".profile"; - std::string timing_cache_path = ""; - if (timing_cache_enable_) { - timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); - } - - // If weight-stripped engine is enabled and refitted engine cache is not present, - // TRT EP will use the engine cache with ".stripped.engine" appended to the end. 
- const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { - engine_cache_path = cache_path_prefix + ".stripped.engine"; - weight_stripped_engine_refit_ = true; - } - - // Load serialized engine - if (trt_state->engine_cache_enable && trt_engine == nullptr) { - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && !trt_state->engine_decryption_enable && profile_file) { - // Deserialize profile - shape_ranges = DeserializeProfileV2(profile_file); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - - // Prepare buffer - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); - if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - - } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && - profile_file) { - shape_ranges = DeserializeProfileV2(profile_file); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Decrypt engine - size_t engine_size = 0; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); - if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); - } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - } - } - - // Check and update shape ranges for dynamic shape inputs. - for (int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_state->network->get()->getInput(i); - const std::string& input_name = input->getName(); - input_names.insert(input_name); - - // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile - // shape values have not yet resolved. 
TRT EP will help determine the min/max/opt profile values based on current - // input tensor value. - if (shape_ranges.find(input_name) != shape_ranges.end()) { - auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, - shape_tensor_values, shape_tensor_values_int64, stream, - &engine_update); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP failed to parse input tensor and generate optimization profiles."); - } - } - } - - // Regenerate engine - if (engine_update) { - // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined - // behavior. - trt_state->context->reset(); - trt_state->engine->reset(); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - if (max_workspace_size_ > 0) { - trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); - } - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); - } -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - // Set INT8 Per Tensor Dynamic range - if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { - trt_config->setInt8Calibrator(nullptr); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); - } - } -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - // Set precision - if (trt_state->int8_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; - } - if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; - } - if (trt_state->bf16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; - } -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - // Set DLA (DLA can only run with FP16 or INT8) - if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(trt_state->dla_core); - } - - // enable sparse weights - if (trt_state->sparsity_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; - } -#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 - // enable builder heuristics - if (trt_state->build_heuristics_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; - } -#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 - // switch optimizaion level - if (trt_state->builder_optimization_level != 3) { - trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; - } - - // limit auxiliary streams - if (trt_state->auxiliary_streams >= 0) { - 
trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; - } -#else - if (trt_state->builder_optimization_level != 3) { - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; - } - if (trt_state->auxiliary_streams >= 0) { - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; - } -#endif - if (weight_stripped_engine_enable_) { -#if NV_TENSORRT_MAJOR >= 10 - trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; - trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; -#else - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; -#endif - } - // limit used tactic sources - if (trt_state->filter_tactic_sources) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= trt_state->tactic_sources; - trt_config->setTacticSources(tactics); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; - } - - // Load timing cache from file. Create a fresh cache if the file doesn't exist - std::unique_ptr timing_cache = nullptr; - if (trt_state->timing_cache_enable) { - std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); - timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), - loaded_timing_cache.size())); - if (timing_cache == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not create timing cache: " + timing_cache_path); - } - trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); - if (detailed_build_log_) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; - } - } - -#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 - // Enable hardware compatility mode if assigned - if (trt_state->engine_hw_compatible) { - trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); - //LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; - } -#endif - - // Set preview feature flags - for (auto feature : trt_state->preview_features) { - trt_config->setPreviewFeature(feature, true); - } - - // Build engine - std::unique_ptr serialized_engine; - { - auto lock = GetApiLock(); - std::chrono::steady_clock::time_point engine_build_start; - if (detailed_build_log_) { - engine_build_start = std::chrono::steady_clock::now(); - } - serialized_engine = std::unique_ptr( - trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); - if (!serialized_engine) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network."); - } - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); - if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine."); - } - if (detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - //LOGS_DEFAULT(INFO) - << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " - << std::chrono::duration_cast(engine_build_stop - 
engine_build_start).count() - << "ms" << std::endl; - } - } - if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - trt_engine = trt_state->engine->get(); - if (trt_state->engine_cache_enable) { - // Serialize engine profile - SerializeProfileV2(profile_cache_path, shape_ranges); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; - - // Serialize engine - if (trt_state->engine_decryption_enable) { - // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available - // first. - if (trt_state->engine_encryption != nullptr) { - if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), - reinterpret_cast(serialized_engine->data()), - serialized_engine->size())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); - } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; - } else { - //LOGS_DEFAULT(WARNING) - << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; - } - } else { - std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); - file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; - } - } - - // serialize and save timing cache - if (trt_state->timing_cache_enable) { - auto timing_cache = trt_config->getTimingCache(); - std::unique_ptr timingCacheHostData{timing_cache->serialize()}; - if (timingCacheHostData == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not serialize timing cache: " + timing_cache_path); - } - saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (detailed_build_log_) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; - } - } - - // dump ep context model - if (dump_ep_context_model_ && ep_context_embed_mode_) { - UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), - serialized_engine->size()); - DumpCtxModel(model_proto_.get(), ctx_model_path_); - } - context_update = true; - - if (weight_stripped_engine_refit_) { - auto status = - RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, - onnx_model_bytestream_, onnx_model_bytestream_size_, trt_engine, - true /* serialize refitted engine to disk */, detailed_build_log_); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); - } - } - } - - if (context_update) { - if (trt_state->context_memory_sharing_enable) { -#if NV_TENSORRT_MAJOR < 10 - *(trt_state->context) = std::unique_ptr( - trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); -#else - *(trt_state->context) = - std::unique_ptr(trt_state->engine->get()->createExecutionContext( - nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif - } else { - *(trt_state->context) = - std::unique_ptr(trt_state->engine->get()->createExecutionContext()); - } - if (!(*(trt_state->context))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); - } - trt_context = trt_state->context->get(); - } - - // Check before using trt_engine - if (trt_engine == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found."); - } - - // Get 
input and output binding names - int total_bindings = trt_engine->getNbIOTensors(); - std::vector input_binding_names, output_binding_names; - for (int i = 0, end = total_bindings; i < end; ++i) { - auto const& name = trt_engine->getIOTensorName(i); - auto const& mode = trt_engine->getTensorIOMode(name); - if (mode == nvinfer1::TensorIOMode::kINPUT) { - input_binding_names.push_back(name); - } else { - output_binding_names.push_back(name); - } - } - - /* - * Set input shapes and bind input buffers - */ - std::vector> scratch_buffers; - for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { - char const* input_name = input_binding_names[i]; - - size_t input_index = 0; - const auto iter = input_indexes.find(input_name); - if (iter != input_indexes.end()) { - input_index = iter->second; - } - auto input_tensor = ctx.GetInput(input_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shapes = tensor_info.GetShape(); - - auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, - shape_tensor_values_int64, scratch_buffers, alloc, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); - } - } - - /* - * Set output shapes and bind output buffers - */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; - - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } - - size_t output_type = 0; - const auto type_iter = output_types.find(output_name); - if (type_iter != output_types.end()) { - output_type = type_iter->second; - } - - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, - output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); - } - } - - // Set execution context memory - if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySize(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > *max_context_mem_size_ptr) { - *max_context_mem_size_ptr = mem_size; - *context_memory = - IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); - } - trt_context->setDeviceMemory((*context_memory).get()); - } - - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
- if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - //LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(0); - } - - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); - } - - /* - * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this - * function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent - * issue might happen: - * - * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per - * stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling - * InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each - * thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple - * streams which is not suggested. But, since the whole compute_func() is protected by the lock and if - * cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. - * - * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all - * operations to prevent the concurrent issue mentioned above. However, if cuda graph is enabled, TRT EP won't call - * cudaStreamSynchronize() since it's not allowed during graph capture. - */ - if (sync_stream_after_enqueue_) { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); - } - - // Assign TRT output back to ORT output - // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) - // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; - - size_t output_type = 0; - const auto& iter = output_types.find(output_name); - if (iter != output_types.end()) { - output_type = iter->second; - } - - if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } - auto status = - BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); - } - } else { - auto& output_tensor = output_tensors[i]; -#if NV_TENSORRT_MAJOR < 10 - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), - output_tensor_ptr, output_dim_sizes[i]); - } - } -#endif - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, - output_dim_sizes[i]); - } - } - } - } - - // End CUDA graph capture. 
- // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream - // mentioned in graph capture above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and - // ExecuteGraph() per inference_session.cc. It's safe to start/end CUDA graph capture in compute_func() here since - // cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(0); - // CUDA work issued to a capturing stream doesn�t actually run on the GPU, - // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph(0)); - } else { - IncrementRegularRunCountBeforeGraphCapture(); - } - } - - return Status::OK(); - }; - - node_compute_funcs.push_back(compute_info); + node_compute_info = node_compute_info.release(); return Status::OK(); } @@ -2630,10 +1962,10 @@ static OrtStatus* ORT_API_CALL CompileImpl(OrtEp* this_ptr, const OrtGraph** gra if (false) { status = ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[graph_idx], fused_node, input_map, - output_map, node_compute_infos[graph_idx]); + output_map, &node_compute_infos[graph_idx]); } else { status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[graph_idx], fused_node, input_map, - output_map, node_compute_infos[graph_idx]); + output_map, &node_compute_infos[graph_idx]); } //if (status != Status::OK()) { // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 2b925b1f..f3f4d131 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -152,6 +152,9 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { using ShapeRangesMap = std::unordered_map>>>; +template +using IAllocatorUniquePtr = std::unique_ptr>; + struct TensorrtComputeState { std::string fused_node_name; nvinfer1::IBuilder* builder; @@ -168,7 +171,6 @@ struct TensorrtComputeState { bool int8_calibration_cache_available = false; bool dla_enable = false; int dla_core = 0; - size_t* max_workspace_size_ptr = nullptr; std::string trt_node_name_with_precision; bool engine_cache_enable = false; std::string engine_cache_path; @@ -176,6 +178,7 @@ struct TensorrtComputeState { std::vector profiles; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; + IAllocatorUniquePtr* context_memory = nullptr; std::unordered_map dynamic_range_map; bool engine_decryption_enable = false; int (*engine_decryption)(const char*, char*, size_t*) = nullptr; @@ -215,11 +218,6 @@ static const std::string k_cc_hw_compatible = "80+"; static const std::string k_ep_ctx_hardware_architecture = "hardware_architecture"; static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; -struct ApiPtrs { - const OrtApi& ort_api; - const OrtEpApi& ep_api; -}; - /// /// /// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. 
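The IAllocatorUniquePtr alias and the context_memory pointer introduced above tie a raw device allocation to the OrtAllocator that produced it. A minimal sketch of a helper in that spirit is shown below; the helper name is hypothetical, it assumes the alias is std::unique_ptr<T, std::function<void(T*)>>, and the patch itself instead calls IAllocator::MakeUniquePtrFromOrtAllocator with a use_reserve flag:

    #include <functional>
    #include <memory>
    #include "onnxruntime_c_api.h"

    template <typename T>
    using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;

    // Hypothetical helper: allocate `bytes` through an OrtAllocator and return ownership that
    // frees through the same allocator (device memory when `alloc` is the EP's CUDA allocator).
    inline IAllocatorUniquePtr<void> MakeBufferFromOrtAllocator(OrtAllocator* alloc, size_t bytes) {
      void* p = alloc->Alloc(alloc, bytes);
      return IAllocatorUniquePtr<void>(p, [alloc](void* ptr) { alloc->Free(alloc, ptr); });
    }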
@@ -346,7 +344,7 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { bool context_memory_sharing_enable_ = false; bool layer_norm_fp32_fallback_ = false; size_t max_ctx_mem_size_ = 0; - // IAllocatorUniquePtr context_memory_ = nullptr; + IAllocatorUniquePtr context_memory_ = nullptr; mutable char model_path_[4096] = {}; // Reserved for max path length bool engine_decryption_enable_ = false; int (*engine_decryption_)(const char*, char*, size_t*) = nullptr; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc new file mode 100644 index 00000000..82f9941e --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "tensorrt_execution_provider_data_transfer.h" + +#include +#include + +void CUDA_RETURN_IF_ERROR(cudaError_t res); + +/*static*/ +bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + auto& impl = *static_cast(this_ptr); + bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info); + bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info); + + return src_is_our_device || dst_is_our_device; +} + +// function to copy one or more tensors. +// implementation can optionally use async copy if a stream is available for the input. +/*static*/ +OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(void* this_ptr, + const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, + OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + + auto src_tensors = gsl::make_span(src_tensors_ptr, num_tensors); + auto dst_tensors = gsl::make_span(dst_tensors_ptr, num_tensors); + auto streams = gsl::make_span(streams_ptr, num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + // NOTE: Stream support will be a separate PR. 
ignore the streams_ptr values for now + + const OrtMemoryDevice* src_device = nullptr; + const OrtMemoryDevice* dst_device = nullptr; + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensors[i], &src_device)); + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensors[i], &dst_device)); + + OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + bool copy_involves_pinned_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || + dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; + + const void* src_data = nullptr; + void* dst_data = nullptr; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensors[i], &src_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensors[i], &dst_data)); + + size_t bytes = 0; + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(reinterpret_cast(src_data), &bytes)); + + // for the sync version of memcpy, launch to cuda default stream + if (dst_device_type == OrtMemoryInfoDeviceType_GPU) { + if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> GPU + // Copy only if the two addresses are different and bytes > 0. + if (dst_data != src_data && bytes > 0) { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } else { + // CPU -> GPU, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + if (src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } + } else if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> CPU, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } else { + // CPU -> CPU involves a copy to/from pinned memory and a synchronize may be required first + //ORT_ENFORCE(dst_data != src_data); + memcpy(dst_data, src_data, bytes); + } + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(void* this_ptr) noexcept { + // In our setup the factory owns a shared TRTEpDataTransfer instance so it will do the cleanup, and we ignore + // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h). + // + // If a new instance were created on each call to OrtEpFactory::CreateDataTransfer, this is where it would be + // deleted, e.g. delete static_cast<TRTEpDataTransfer*>(this_ptr); +} diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h new file mode 100644 index 00000000..a72ff453 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
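// Note: CopyTensorsImpl currently issues synchronous cudaMemcpy calls and ignores the OrtSyncStream arguments
// (stream support is left for a follow-up change, per the note in the .cc file). Once a stream handle is
// available, the copies could plausibly become stream-ordered, e.g. (sketch only; `cuda_stream` would come from
// the OrtSyncStream associated with the request):
//   CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, cuda_stream));
// keeping in mind that an async host-to-device copy is only truly asynchronous when the host buffer is pinned
// (OrtDeviceMemoryType_HOST_ACCESSIBLE).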
+ +#pragma once + +#include "tensorrt_execution_provider_utils.h" + +struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { + TRTEpDataTransfer(ApiPtrs api_ptrs, const OrtMemoryDevice* device_mem_info_, + const OrtMemoryDevice* shared_mem_info_ = nullptr) + : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_}, shared_mem_info{shared_mem_info_} { + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool ORT_API_CALL CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept; + + // function to copy one or more tensors. + // implementation can optionally use async copy if a stream is available for the input. + static OrtStatus* ORT_API_CALL CopyTensorsImpl(void* this_ptr, const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept; + static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept; + + private: + const OrtMemoryDevice* device_mem_info; + const OrtMemoryDevice* shared_mem_info; +}; \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc index 8a34cf0c..c7062af5 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc @@ -7,7 +7,6 @@ #include "provider_options_utils.h" #include "cuda/cuda_common.h" -namespace onnxruntime { namespace tensorrt { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; @@ -336,4 +335,3 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions // trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path); // trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible; //} -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h index 5ca1d6df..16304db1 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h @@ -9,7 +9,6 @@ #define TRT_DEFAULT_OPTIMIZER_LEVEL 3 -namespace onnxruntime { // Information needed to construct trt execution providers. 
struct TensorrtExecutionProviderInfo { int device_id{0}; @@ -55,11 +54,10 @@ struct TensorrtExecutionProviderInfo { std::string engine_cache_prefix{""}; bool engine_hw_compatible{false}; - static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); + static TensorrtExecutionProviderInfo FromProviderOptions(const onnxruntime::ProviderOptions& options); // static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); // static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); // static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy); // // std::vector custom_op_domain_list; }; -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index 2ff91908..54fa0ed4 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -1,3 +1,15 @@ +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include "flatbuffers/idl.h" +#include "ort_trt_int8_cal_table.fbs.h" +// #include "core/providers/cuda/cuda_pch.h" +// #include "core/common/path_string.h" +// #include "core/framework/murmurhash3.h" + +#include"nv_includes.h" + #include #include #include @@ -5,16 +17,29 @@ #include #include #include -#include "flatbuffers/idl.h" -#include "ort_trt_int8_cal_table.fbs.h" -#include -//#include "core/providers/cuda/cuda_pch.h" -//#include "core/common/path_string.h" -//#include "core/framework/murmurhash3.h" -namespace fs = std::filesystem; +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ + } \ + } while (0) + +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; + const OrtModelEditorApi& model_editor_api; +}; -//namespace onnxruntime { +namespace fs = std::filesystem; // Check if cycle exists in the graph after partitioning /* @@ -143,6 +168,7 @@ std::vector SplitToStringVec(std::string const& s, char separator) return splitted; } +/* nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSources disabledTactics = 0; nvinfer1::TacticSources enabledTactics = 0; @@ -197,6 +223,7 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { } return enabledTactics & ~disabledTactics; } +*/ inline std::vector loadTimingCacheFile(const std::string inFileName) { std::ifstream iFile(inFileName, std::ios::in | std::ios::binary); @@ -968,4 +995,3 @@ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string } return ""; } -//} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 4937e1c4..03d4e902 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -3,6 +3,7 @@ #undef ORT_API_MANUAL_INIT #include "tensorrt_provider_factory.h" #include "tensorrt_execution_provider.h" +#include "cuda_allocator.h" #include #include @@ -12,19 +13,61 @@ #include #include -//struct TensorrtExecutionProvider; 
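The RETURN_IF_ERROR / RETURN_IF macros defined in tensorrt_execution_provider_utils.h above are the error-propagation idiom used throughout the factory and data-transfer code; a small usage sketch follows (the helper function itself is hypothetical, only the macros and the OrtApi call are real):

    // Illustrative only: early-return an OrtStatus* on failure, as the factory code does.
    static OrtStatus* GetKernelInputCount(const OrtApi& ort_api, const OrtKernelContext* ctx, size_t* count) {
      RETURN_IF(ctx == nullptr || count == nullptr, ort_api, "invalid argument");  // produces an ORT_EP_FAIL status
      RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(ctx, count));            // propagates any OrtStatus* error
      return nullptr;                                                              // nullptr signals success
    }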
+TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis) + : ApiPtrs(apis), ep_name_{ep_name} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + + GetSupportedDevices = GetSupportedDevicesImpl; + + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + // Default GPU allocator OrtMemoryInfo + OrtMemoryInfo* mem_info = nullptr; + auto* status = ort_api.CreateMemoryInfo_V2("TensorRT EP GPU", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); + assert(status == nullptr); // should never fail. + default_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + + // CUDA PINNED allocator OrtMemoryInfo + // HOST_ACCESSIBLE memory should use the non-CPU device type + mem_info = nullptr; + status = ort_api.CreateMemoryInfo_V2("TensorRT EP GPU pinned", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); + assert(status == nullptr); // should never fail. + host_accessible_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + + // Create gpu data transfer + data_transfer_impl_ = std::make_unique<TRTEpDataTransfer>( + apis, + ep_api.MemoryInfo_GetMemoryDevice(default_gpu_memory_info_.get()), // device memory + ep_api.MemoryInfo_GetMemoryDevice(host_accessible_gpu_memory_info_.get()) // shared memory + ); 
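  // The factory keeps ownership of data_transfer_impl_; CreateDataTransferImpl hands out this same shared
  // instance, which is why TRTEpDataTransfer::ReleaseImpl does not free it.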
+} -static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) { +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) { const auto* factory = static_cast(this_ptr); return factory->ep_name_.c_str(); } -static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) { +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVendorImpl(const OrtEpFactory* this_ptr) { const auto* factory = static_cast(this_ptr); return factory->vendor_.c_str(); } -static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, OrtEpDevice** ep_devices, @@ -77,7 +120,8 @@ static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, return nullptr; } -static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( + OrtEpFactory* this_ptr, _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, _In_ size_t num_devices, @@ -110,11 +154,58 @@ static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, return nullptr; } -static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { +void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { TensorrtExecutionProvider* trt_ep = static_cast(ep); delete trt_ep; } +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl( + OrtEpFactory* this_ptr, const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto& factory = *static_cast(this_ptr); + *allocator = nullptr; + + // NOTE: The factory implementation can return a shared OrtAllocator* instead of creating a new instance on each call. + // To do this just make ReleaseAllocatorImpl a no-op. + + // NOTE: If OrtMemoryInfo has allocator type (call MemoryInfoGetType) of OrtArenaAllocator, an ORT BFCArena + // will be added to wrap the returned OrtAllocator. The EP is free to implement its own arena, and if it + // wants to do this the OrtMemoryInfo MUST be created with an allocator type of OrtDeviceAllocator. + + // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based + // matching should work. + if (memory_info == factory.default_gpu_memory_info_.get()) { + // create a CUDA allocator + auto cuda_allocator = std::make_unique(memory_info); + *allocator = cuda_allocator.release(); + } else if (memory_info == factory.host_accessible_gpu_memory_info_.get()) { + // create a CUDA PINNED allocator + auto cuda_pinned_allocator = std::make_unique(memory_info); + *allocator = cuda_pinned_allocator.release(); + } else { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. 
" + "Value did not come directly from an OrtEpDevice returned by this factory."); + } + + return nullptr; +} + +void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this*/, + OrtAllocator* allocator) noexcept { + delete static_cast(allocator); +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + auto& factory = *static_cast(this_ptr); + *data_transfer = factory.data_transfer_impl_.get(); + + return nullptr; +} + // To make symbols visible on macOS/iOS #ifdef __APPLE__ #define EXPORT_SYMBOL __attribute__((visibility("default"))) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 94283548..a3dd58b9 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -1,56 +1,50 @@ -#define ORT_API_MANUAL_INIT -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - -#define RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* status = (fn); \ - if (status != nullptr) { \ - return status; \ - } \ - } while (0) - -#define RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ - } \ - } while (0) - -static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr); -static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr); -static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices); -static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, - _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ OrtEp** ep); -static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep); - -struct ApiPtrs { - const OrtApi& ort_api; - const OrtEpApi& ep_api; -}; +#include "tensorrt_execution_provider_utils.h" +#include "tensorrt_execution_provider_data_transfer.h" /// /// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices. /// struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs { - TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { - ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
- GetName = GetNameImpl; - GetVendor = GetVendorImpl; - GetSupportedDevices = GetSupportedDevicesImpl; - CreateEp = CreateEpImpl; - ReleaseEp = ReleaseEpImpl; - } + public: + TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis); + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, size_t num_devices, + OrtEpDevice** ep_devices, size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, size_t num_devices, + const OrtSessionOptions* session_options, const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name + + // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. + using MemoryInfoUniquePtr = std::unique_ptr>; + //MemoryInfoUniquePtr cpu_memory_info_; + + // GPU memory and pinned/shared memory are required for data transfer, these are the + // OrtMemoryInfo instance required for that. + MemoryInfoUniquePtr default_gpu_memory_info_; + MemoryInfoUniquePtr host_accessible_gpu_memory_info_; + + std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; \ No newline at end of file From 3d6fa57dfa9e2048059f48050050d0d89cd097e3 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 2 Jul 2025 12:10:20 -0700 Subject: [PATCH 13/60] fix a lot of compile errors --- .../cuda/cu_inc/unary_elementwise_impl.cuh | 2 - .../cuda/unary_elementwise_ops_impl.cu | 2 - .../cuda/unary_elementwise_ops_impl.h | 3 - .../tensorrt/tensorrt_execution_provider.cc | 797 +++++++----------- .../tensorrt/tensorrt_execution_provider.h | 103 ++- .../tensorrt_execution_provider_utils.h | 43 +- .../tensorrt/tensorrt_provider_factory.cc | 8 +- .../tensorrt/tensorrt_provider_factory.h | 3 +- 8 files changed, 426 insertions(+), 535 deletions(-) diff --git a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh index 87cf7c83..7b16741b 100644 --- a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh +++ b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh @@ -4,7 +4,6 @@ #pragma once #include -namespace onnxruntime { namespace cuda { // We would like to use 64-bit integer to support large matrices. 
However, CUDA seems to support only 32-bit integer @@ -75,4 +74,3 @@ void UnaryElementWiseImpl( } } // namespace cuda -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu index ad515a23..0ceb9454 100644 --- a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu @@ -9,7 +9,6 @@ #endif #include -namespace onnxruntime { namespace cuda { @@ -90,4 +89,3 @@ IMPL_CAST_IMPL_FROM(bool) //IMPL_CAST_IMPL_FROM(BFloat16) } // namespace cuda -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h index 392cf46f..184426a9 100644 --- a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h @@ -7,7 +7,6 @@ #include #include -namespace onnxruntime { namespace cuda { // Cast @@ -50,5 +49,3 @@ void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, si } } // namespace cuda - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index fa532dc2..2e4b3915 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -10,8 +10,8 @@ #undef ORT_API_MANUAL_INIT #include "ep_abi_utils.h" +//#include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider.h" -#include "tensorrt_execution_provider_utils.h" #include "cuda_allocator.h" //#include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" @@ -29,85 +29,14 @@ #define LIBFUNC(lib, fn) dlsym((lib), (fn)) #endif +const OrtApi* g_ort_api = nullptr; +const OrtEpApi* g_ep_api = nullptr; +const OrtModelEditorApi* g_model_editor_api = nullptr; + void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } -const OrtApi& ort_api = Ort::GetApi(); - -/* -struct MemcpyFromHost : OrtCustomOp { - MemcpyFromHost() { - OrtCustomOp::version = ORT_API_VERSION; - OrtCustomOp::GetName = [](const struct OrtCustomOp* op) { return "MemcpyFromHost"; }; - OrtCustomOp::GetExecutionProviderType = [](const struct OrtCustomOp* op) { return tensorrtEp.c_str(); }; - OrtCustomOp::CreateKernelV2 = [](const struct OrtCustomOp* op, const OrtApi* api, const OrtKernelInfo* info, void** kernel) -> OrtStatusPtr { - return nullptr; - }; - OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - void* stream = nullptr; - api->KernelContext_GetGPUComputeStream(context, &stream); - - const OrtValue* input = nullptr; - api->KernelContext_GetInput(context, 0, &input); - OrtTensorTypeAndShapeInfo* shape_info; - api->GetTensorTypeAndShape(input, &shape_info); - size_t dim_count = 0; - api->GetDimensionsCount(shape_info, &dim_count); - std::vector dim(dim_count, 0); - api->GetDimensions(shape_info, dim.data(), dim_count); - - OrtValue* output = nullptr; - api->KernelContext_GetOutput(context, 0, dim.data(), dim.size(), &output); - - void *input_raw = nullptr, *output_raw = nullptr; - api->GetTensorMutableData(const_cast(input), &input_raw); - api->GetTensorMutableData(output, &output_raw); - - 
size_t count = dim[0]; - for (size_t i = 1; i < dim_count; i++) count *= dim[i]; - cudaMemcpyAsync(output_raw, input_raw, count * sizeof(float), cudaMemcpyHostToDevice, static_cast(stream)); // TODO(leca): other data type - - return nullptr; - }; - OrtCustomOp::GetInputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; }; - OrtCustomOp::GetOutputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; }; - OrtCustomOp::GetInputMemoryType = [](const struct OrtCustomOp* op, size_t index) { return OrtMemType::OrtMemTypeCPUInput; }; - OrtCustomOp::GetInputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; - OrtCustomOp::GetOutputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; // TODO(leca): other data type - OrtCustomOp::GetStartVersion = [](const struct OrtCustomOp* op) { return 1; }; - } -}; -*/ - -bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { - size_t alloc_size = size; - if (alignment == 0) { - *out = alloc_size * nmemb; - } else { - size_t alignment_mask = alignment - 1; - *out = (alloc_size * nmemb + alignment_mask) & ~static_cast(alignment_mask); - } - return true; -} - -template -IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) { - size_t alloc_size = count_or_bytes; - // if T is not void, 'count_or_bytes' == number of items so allow for that - if constexpr (!std::is_void::value) { - // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't - // reachable if T is void. use std::conditional to 'use' void* in the sizeof call - constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); - CalcMemSizeForArrayWithAlignment(count_or_bytes, size, 0, &alloc_size); - } - - T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); - - return IAllocatorUniquePtr{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }}; -} - #if NV_TENSORRT_MAJOR >= 10 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { @@ -380,7 +309,7 @@ OrtStatusPtr ApplyProfileShapesFromInputTensorValue(std::vectorCreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); } } @@ -543,7 +472,7 @@ OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, size_t input_index, std::unordered_map>& shape_tensor_values, std::unordered_map>& shape_tensor_values_int64, - std::vector>& scratch_buffers, + std::vector>& scratch_buffers, OrtAllocator* alloc, cudaStream_t stream) { auto input_tensor = ctx.GetInput(input_index); @@ -584,7 +513,7 @@ OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, std::string error_msg = "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; - return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + return g_ort_api->CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } break; } @@ -604,13 +533,13 @@ OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, std::string error_msg = "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; - return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); 
+ return g_ort_api->CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } break; } default: { std::string error_input_name = input_name; - return ort_api.CreateStatus(ORT_EP_FAIL, std::string("The data type of shape tensor should be INT32 or INT64. Please check the data type of " + error_input_name).c_str()); + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("The data type of shape tensor should be INT32 or INT64. Please check the data type of " + error_input_name).c_str()); } } } else { @@ -622,7 +551,7 @@ OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, } if (!trt_context->setInputShape(input_name, dims)) { std::string error_input_name = input_name; - return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'").c_str()); + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'").c_str()); } // Bind "execution tensor" input buffer @@ -645,16 +574,9 @@ OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) #endif // Cast double input to float because TensorRT doesn't support double - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - auto input_tensor_ptr = input_tensor.GetTensorData(); if (input_tensor_ptr != nullptr && elem_cnt > 0) { - scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(float))); data = scratch_buffers.back().get(); cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); - } - else { - scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); data = scratch_buffers.back().get(); - } break; -} + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { - return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); } } trt_context->setTensorAddress(input_name, data); @@ -672,7 +594,7 @@ OrtStatusPtr BindContextOutput(Ort::KernelContext& ctx, std::unordered_map& output_tensors, std::unordered_map& output_dim_sizes, DDSOutputAllocatorMap& dds_output_allocator_map, - std::vector>& scratch_buffers, + std::vector>& scratch_buffers, OrtAllocator* alloc, std::unordered_map& buffers) { // Get output shape @@ -724,7 +646,7 @@ OrtStatusPtr BindContextOutput(Ort::KernelContext& ctx, // Allocate float CUDA memory for double output type because TensorRT doesn't support double CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { - return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); } } trt_context->setTensorAddress(output_name, buffers[output_name]); @@ -783,7 +705,7 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. 
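For readability, roughly what the CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) case introduced above stands for, reconstructed from the inline code that hunk removes. The macro definition itself is not shown in this patch, and the new macro presumably allocates through the updated scratch-buffer helper rather than MakeUniquePtrFromOrtAllocator.

case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
  auto input_tensor_ptr = input_tensor.GetTensorData<double>();
  if (input_tensor_ptr != nullptr && elem_cnt > 0) {
    // Stage a float copy in scratch memory, since TensorRT has no double support; the scratch
    // buffer is parked in scratch_buffers so it outlives the enqueue call.
    scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator<void>(alloc, elem_cnt * sizeof(float)));
    data = scratch_buffers.back().get();
    cuda::Impl_Cast<double, float>(stream, input_tensor_ptr, reinterpret_cast<float*>(data), elem_cnt);
  } else {
    scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator<void>(alloc, 1));  // placeholder for empty input
    data = scratch_buffers.back().get();
  }
  break;
}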
// CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { - return ort_api.CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); } } return nullptr; @@ -794,67 +716,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect const OrtGraph* graph, bool* early_termination) const { // Return if iterations are exceeding predefined number SubGraphCollection_t nodes_list_output; - if (iterations > max_iterations) { - *early_termination = true; - return nodes_list_output; - } - - iterations++; - for (const auto& group : nodes_vector_input) { - // Construct subgraph - if (!group.first.empty()) { - if (group.second) { - nodes_list_output.push_back(group); - } else { - // const OrtGraphViewer* sub_graph_viewer = nullptr; - // graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); - - void* buf_data = nullptr; - size_t buf_size = 0; - graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data, &buf_size); - - // Get supported node list recursively - SubGraphCollection_t parser_nodes_list; - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); - auto trt_builder = GetBuilder(trt_logger); - auto network_flags = 0; -#if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ - ? 0 - : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); -#endif - network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); - - auto trt_parser = - tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); - graph_api_->OrtFreeMem(buf_data); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - - SubGraphCollection_t next_nodes_list; - const size_t* subgraph_node_index = nullptr; - size_t subgraph_node_count = 0; - graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index, - &subgraph_node_count); - next_nodes_list = - GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); - for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { - for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { - next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; - } - nodes_list_output.push_back(next_nodes_list[i]); - } - graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); - } - } - } return nodes_list_output; } @@ -863,7 +724,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo* node_compute_info) { + /* out */OrtNodeComputeInfo** node_compute_info) { TensorrtExecutionProvider* ep = static_cast(this_ptr); /* @@ -892,7 +753,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= (fp16_enable_ || 
int8_enable_ || bf16_enable_) + network_flags |= (fp16_enable_ || int8_enable_) ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else @@ -912,7 +773,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(push) #pragma warning(disable : 4996) #endif - if ((fp16_enable_ || bf16_enable_) && layer_norm_fp32_fallback_) { + if (fp16_enable_ && layer_norm_fp32_fallback_) { for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { auto layer = trt_network->getLayer(idx); auto next_layer = trt_network->getLayer(idx + 1); @@ -1076,7 +937,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } // Check platform availability for low precision - if (fp16_enable_ || bf16_enable_) { + if (fp16_enable_) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -1086,7 +947,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(pop) #endif fp16_enable_ = false; - bf16_enable_ = false; //LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but " // "platform doesn't support fast native fp16/bf16"; } @@ -1131,11 +991,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this trt_node_name_with_precision += "_fp16"; //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; } - if (bf16_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); - trt_node_name_with_precision += "_bf16"; - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; - } if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); trt_node_name_with_precision += "_int8"; @@ -1232,11 +1087,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; } - // Set preview feature flags - for (auto feature : preview_features_) { - trt_config->setPreviewFeature(feature, true); - } - // Build TRT engine (if needed) and load TRT engine if: // (1) Graph has no dynamic shape input // (2) All the dynamic shape inputs have associated explicit profiles specified by user @@ -1282,10 +1132,12 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this weight_stripped_engine_refit_ = true; } + /* // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); } + */ if (!has_dynamic_shape) { std::string timing_cache_path = ""; @@ -1371,8 +1223,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); if (timing_cache == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not create timing cache: " + timing_cache_path); + std::string err_msg = "TensorRT EP could not create timing cache: " + timing_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); if (detailed_build_log_) { @@ -1388,15 +1240,14 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this std::unique_ptr serialized_engine{ trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; if (serialized_engine == nullptr) { - return ORT_MAKE_STATUS( - 
ONNXRUNTIME, EP_FAIL, - "TensorRT EP failed to create engine from network for fused node: " + fused_node.Name()); + std::string err_msg = "TensorRT EP failed to create engine from network for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } trt_engine = std::unique_ptr( runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (trt_engine == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP failed to deserialize engine for fused node: " + fused_node.Name()); + std::string err_msg = "TensorRT EP failed to deserialize engine for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); @@ -1418,7 +1269,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (engine_encryption_ != nullptr) { if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP call to engine encryption library failed"); + std::string err_msg = "TensorRT EP call to engine encryption library failed"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; } else { @@ -1436,8 +1288,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this auto timing_cache = trt_config->getTimingCache(); std::unique_ptr timingCacheHostData{timing_cache->serialize()}; if (timingCacheHostData == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not serialize timing cache: " + timing_cache_path); + std::string err_msg = "TensorRT EP could not serialize timing cache: " + timing_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); if (detailed_build_log_) { @@ -1457,11 +1309,13 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (engine_cache_enable_ && engine_hw_compatible_) { compute_capability_hw_compat = "80+"; } + /* std::unique_ptr model_proto{ CreateCtxModel(graph_body_viewer, ep_cache_context_attr_, reinterpret_cast(serialized_engine->data()), serialized_engine->size(), ep_context_embed_mode_, compute_capability_hw_compat, model_path_, GetLogger())}; DumpCtxModel(model_proto.get(), ctx_model_path_); + */ } } } @@ -1473,8 +1327,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, onnx, onnx_size, trt_engine.get(), true /* serialize refitted engine to disk */, detailed_build_log_); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + if (status != nullptr) { + return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); } } @@ -1496,8 +1350,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this trt_context = std::unique_ptr(trt_engine->createExecutionContext()); } if (!trt_context) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); + std::string err_msg = "TensorRT EP could not build execution context for fused node: " + fused_node_name; + 
return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } } @@ -1524,16 +1378,17 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } // Save TRT engine, other TRT objects and input/output info to map - parsers_.emplace(fused_node.Name(), std::move(trt_parser)); - engines_.emplace(fused_node.Name(), std::move(trt_engine)); - contexts_.emplace(fused_node.Name(), std::move(trt_context)); - networks_.emplace(fused_node.Name(), std::move(trt_network)); - input_info_[fused_node.Name()].push_back(input_indexes); - output_info_[fused_node.Name()].push_back(output_indexes); - output_info_[fused_node.Name()].push_back(output_types); - input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges; - profiles_.emplace(fused_node.Name(), std::move(trt_profiles)); + parsers_.emplace(fused_node_name, std::move(trt_parser)); + engines_.emplace(fused_node_name, std::move(trt_engine)); + contexts_.emplace(fused_node_name, std::move(trt_context)); + networks_.emplace(fused_node_name, std::move(trt_network)); + input_info_[fused_node_name].push_back(input_indexes); + output_info_[fused_node_name].push_back(output_indexes); + output_info_[fused_node_name].push_back(output_types); + input_shape_ranges_[fused_node_name] = input_implicit_shape_ranges; + profiles_.emplace(fused_node_name, std::move(trt_profiles)); + /* // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty // engine. TRT EP will serialize the model at inference time due to engine can be updated and the updated engine // should be included in the model. However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize @@ -1556,6 +1411,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this DumpCtxModel(model_proto_.get(), ctx_model_path_); } } + */ std::unique_ptr compute_state = std::make_unique(); @@ -1575,6 +1431,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this output_info_[fused_node_name], input_shape_ranges_[fused_node_name], &tensorrt_mu_, + compute_capability_, + max_workspace_size_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, @@ -1603,18 +1461,26 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this !tactic_sources_.empty(), tactics, cuda_graph_enable_, + weight_stripped_engine_enable_, + weight_stripped_engine_refit_, + model_path_, + onnx_model_folder_path_, + onnx_model_bytestream_, + onnx_model_bytestream_size_, cache_prefix_, cache_suffix, - engine_hw_compatible_}; + engine_hw_compatible_, + sync_stream_after_enqueue_}; // Update the OrtNodeComputeInfo associated with the graph. 
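To connect this with the runtime callbacks defined further down in this patch, a short sketch of the per-fused-node state lifecycle being set up here. GetComputeStates() and NodeComputeContext_NodeName are taken from this patch; the member that backs the map is not shown in this excerpt and is assumed.

// Compile / CreateNodeComputeInfoFromGraph (here):
//   store the compute state just built, keyed by fused_node_name, in the map behind GetComputeStates().
// TRTEpNodeComputeInfo::CreateStateImpl (later in this patch):
//   fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context);
//   look the state up via ep.GetComputeStates().find(fused_node_name) and hand it back via *compute_state.
// TRTEpNodeComputeInfo::ComputeImpl (later in this patch):
//   cast compute_state back to the state struct and use it to (re)build, bind, and enqueue the TRT engine.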
- auto node_compute_info = std::make_unique(*ep); - node_compute_info = node_compute_info.release(); - return Status::OK(); + auto ep_node_compute_info = std::make_unique(*ep); + *node_compute_info = ep_node_compute_info.release(); + + return nullptr; } -static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -1891,8 +1757,13 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph reinterpret_cast(&supported_node))); supported_nodes.push_back(supported_node); } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), - supported_nodes.size())); + supported_nodes.size(), &node_fusion_options)); number_of_trt_nodes += static_cast(group.first.size()); } } @@ -1909,11 +1780,15 @@ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph return nullptr; } -static OrtStatus* ORT_API_CALL CompileImpl(OrtEp* this_ptr, const OrtGraph** graphs, const OrtNode** fused_nodes, - size_t count, OrtNodeComputeInfo** node_compute_infos) { +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) { TensorrtExecutionProvider* ep = static_cast(this_ptr); + gsl::span result(node_compute_infos, count); + for (size_t graph_idx = 0; graph_idx < count; graph_idx++) { auto fused_node = fused_nodes[graph_idx]; @@ -1962,10 +1837,10 @@ static OrtStatus* ORT_API_CALL CompileImpl(OrtEp* this_ptr, const OrtGraph** gra if (false) { status = ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[graph_idx], fused_node, input_map, - output_map, &node_compute_infos[graph_idx]); + output_map, &result[graph_idx]); } else { status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[graph_idx], fused_node, input_map, - output_map, &node_compute_infos[graph_idx]); + output_map, &result[graph_idx]); } //if (status != Status::OK()) { // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); @@ -1975,18 +1850,121 @@ static OrtStatus* ORT_API_CALL CompileImpl(OrtEp* this_ptr, const OrtGraph** gra return nullptr; } -static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) { +const char* ORT_API_CALL TensorrtExecutionProvider::GetNameImpl(const OrtEp* this_ptr) noexcept { const auto* ep = static_cast(this_ptr); return ep->name_.c_str(); } +/** + * Refit the weight-stripped engine + */ +OrtStatus* TensorrtExecutionProvider::RefitEngine( + std::string onnx_model_filename, std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, + bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log) { +#if NV_TENSORRT_MAJOR >= 10 + bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; + std::filesystem::path onnx_model_path{onnx_model_folder_path}; + 
if (refit_from_file) { + if (!onnx_model_filename.empty()) { + onnx_model_path.append(onnx_model_filename); + } + if (onnx_model_path.empty()) { + std::string err_msg = "The ONNX model was not provided as path. Please use provide an ONNX bytestream to enable refitting the weightless engine."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } else { + /* + // check if file path to ONNX is legal + if (path_check && IsAbsolutePath(onnx_model_path.string())) { + std::string err_msg = + "For security purpose, the ONNX model path should be set with a relative path, but it is an absolute path: " + onnx_model_path.string(); + "weightless engine."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { + std::string err_msg = + "The ONNX model path has '..'. For security purpose, it's not allowed to point outside the directory."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + */ + + if (!(std::filesystem::exists(onnx_model_path) && std::filesystem::is_regular_file(onnx_model_path))) { + std::string err_msg = "The ONNX model " + onnx_model_path.string() + " does not exist."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + } + + // weight-stripped engine refit logic + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log); + auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); + auto parser_refitter = + std::unique_ptr(nvonnxparser::createParserRefitter(*refitter, trt_logger)); + if (refit_from_file) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string(); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " + "weights contained in: " + + onnx_model_path.string(); + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } else { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array"; + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + std::string err_msg = + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " + "weights contained in the provided bytestraem"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + if (refitter->refitCudaEngine()) { + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine."; + } else { + std::string err_msg = + "TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + + onnx_model_path.string(); + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // serialize the refitted engine to disk + if (serialize_refitted_engine) { + std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); + nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); + std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); + engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialize the refitted engine to " << refitted_engine_cache; + } + return nullptr; +#else + std::string err_msg = "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); +#endif +} + +/// /// -/// 
The Plugin TensorRT EP (Implementation of TensorrtExecutionProvider) +/// Plugin TensorRT EP that implements OrtEp /// -TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, - const OrtHardwareDevice& device, - const OrtSessionOptions& session_options, const OrtLogger& logger) - : ApiPtrs(apis), name_{name}, hardware_device_{device}, session_options_{session_options}, logger_{logger} { +/// +TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, + const std::string& name, + const OrtHardwareDevice& device, + const OrtSessionOptions& session_options, + const OrtLogger& logger) + : ApiPtrs{static_cast(factory)}, + factory_(factory), + name_{name}, + hardware_device_{device}, + session_options_{session_options}, + logger_{logger} { + + // Implementation of OrtEp interfaces + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + Compile = CompileImpl; + ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + // Initialize the execution provider. auto status = ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -1995,12 +1973,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::st // ignore status for now (void)status; - // Implementation of OrtEp interfaces - ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. - GetName = GetNameImpl; - GetCapability = GetCapabilityImpl; - Compile = CompileImpl; - // ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + // populate apis as global for utility functions + g_ort_api = &ort_api; + g_ep_api = &ep_api; + g_model_editor_api = &model_editor_api; // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to // the session option configurations with the key prefix "ep..". @@ -2031,7 +2007,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::st } }; - // Get environment variables + // get provider options if (info_.has_trt_options) { max_partition_iterations_ = info_.max_partition_iterations; min_subgraph_size_ = info_.min_subgraph_size; @@ -2089,198 +2065,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::st cuda_graph_enable_ = info_.cuda_graph_enable; engine_hw_compatible_ = info_.engine_hw_compatible; } else { - try { - // const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); - // if (!max_partition_iterations_env.empty()) { - // max_partition_iterations_ = std::stoi(max_partition_iterations_env); - // } - - // const std::string min_subgraph_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMinSubgraphSize); - // if (!min_subgraph_size_env.empty()) { - // min_subgraph_size_ = std::stoi(min_subgraph_size_env); - // } - - // const std::string max_workspace_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxWorkspaceSize); - // if (!max_workspace_size_env.empty()) { - // max_workspace_size_ = std::stoull(max_workspace_size_env); - // } - - // const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kFP16Enable); - // if (!fp16_enable_env.empty()) { - // fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? 
false : true); - // } - - // const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8Enable); - // if (!int8_enable_env.empty()) { - // int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); - // } - - // if (int8_enable_) { - // const std::string int8_calibration_cache_name_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8CalibrationTableName); - // if (!int8_calibration_cache_name_env.empty()) { - // int8_calibration_cache_name_ = int8_calibration_cache_name_env; - // } - - // const std::string int8_use_native_tensorrt_calibration_table_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8UseNativeTensorrtCalibrationTable); - // if (!int8_use_native_tensorrt_calibration_table_env.empty()) { - // int8_use_native_tensorrt_calibration_table_ = (std::stoi(int8_use_native_tensorrt_calibration_table_env) == 0 ? false : true); - // } - // } - - // if (fp16_enable_ || int8_enable_) { // DLA can only be enabled with FP16 or INT8 - // const std::string dla_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLAEnable); - // if (!dla_enable_env.empty()) { - // dla_enable_ = (std::stoi(dla_enable_env) == 0 ? false : true); - // } - - // if (dla_enable_) { - // const std::string dla_core_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLACore); - // if (!dla_core_env.empty()) { - // dla_core_ = std::stoi(dla_core_env); - // } - // } - // } - - // const std::string dump_subgraphs_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpSubgraphs); - // if (!dump_subgraphs_env.empty()) { - // dump_subgraphs_ = (std::stoi(dump_subgraphs_env) == 0 ? false : true); - // } - - // const std::string engine_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCacheEnable); - // if (!engine_cache_enable_env.empty()) { - // engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true); - // } - - // const std::string weight_stripped_engine_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kWeightStrippedEngineEnable); - // if (!weight_stripped_engine_enable_env.empty()) { - // weight_stripped_engine_enable_ = std::stoi(weight_stripped_engine_enable_env) != 0; - // } - - // const std::string onnx_model_folder_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOnnxModelFolderPath); - // if (!onnx_model_folder_path_env.empty()) { - // onnx_model_folder_path_ = onnx_model_folder_path_env; - // } - - // const std::string timing_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCacheEnable); - // if (!timing_cache_enable_env.empty()) { - // timing_cache_enable_ = (std::stoi(timing_cache_enable_env) == 0 ? false : true); - // } - - // const std::string detailed_build_log_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDetailedBuildLog); - // if (!detailed_build_log_env.empty()) { - // detailed_build_log_ = (std::stoi(detailed_build_log_env) == 0 ? false : true); - // } - - // const std::string timing_force_match_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kForceTimingCache); - // if (!timing_force_match_env.empty()) { - // force_timing_cache_match_ = (std::stoi(timing_force_match_env) == 0 ? false : true); - // } - - // const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel); - // if (!dump_ep_context_model_env.empty()) { - // dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? 
false : true); - // } - - // const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable); - // if (!ep_context_file_path_env.empty()) { - // ep_context_file_path_ = ep_context_file_path_env; - // } - - // const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode); - // if (!ep_context_embed_mode_env.empty()) { - // ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); - // } - // // incase the EP context is dumped the engine cache has to be enabled - // if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { - // engine_cache_enable_ = true; - // } - - // enable_engine_cache_for_ep_context_model(); - - // if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { - // const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); - // cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath); - // cache_prefix_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePrefix); - // if (!engine_cache_path.empty() && cache_path_.empty()) { - // cache_path_ = engine_cache_path; - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path"; - // } - // } - // if (timing_cache_enable_) { - // std::string timing_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCachePath); - // // use a more global cache if given - // if (!timing_cache_path.empty()) { - // global_cache_path_ = timing_cache_path; - // } else { - // global_cache_path_ = cache_path_; - // } - // } - - // const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable); - // if (!engine_decryption_enable_env.empty()) { - // engine_decryption_enable_ = (std::stoi(engine_decryption_enable_env) == 0 ? false : true); - // } - - // if (engine_decryption_enable_) { - // engine_decryption_lib_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath); - // } - - // const std::string force_sequential_engine_build_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kForceSequentialEngineBuild); - // if (!force_sequential_engine_build_env.empty()) { - // force_sequential_engine_build_ = (std::stoi(force_sequential_engine_build_env) == 0 ? false : true); - // } - - // const std::string context_memory_sharing_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kContextMemorySharingEnable); - // if (!context_memory_sharing_enable_env.empty()) { - // context_memory_sharing_enable_ = (std::stoi(context_memory_sharing_enable_env) == 0 ? false : true); - // } - - // const std::string layer_norm_fp32_fallback_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kLayerNormFP32Fallback); - // if (!layer_norm_fp32_fallback_env.empty()) { - // layer_norm_fp32_fallback_ = (std::stoi(layer_norm_fp32_fallback_env) == 0 ? false : true); - // } - - // const std::string build_heuristics_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBuildHeuristics); - // if (!build_heuristics_env.empty()) { - // build_heuristics_enable_ = (std::stoi(build_heuristics_env) == 0 ? false : true); - // } - - // const std::string sparsity_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kSparsityEnable); - // if (!sparsity_enable_env.empty()) { - // sparsity_enable_ = (std::stoi(sparsity_enable_env) == 0 ? 
false : true); - // } - - // const std::string builder_optimization_level_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBuilderOptimizationLevel); - // if (!builder_optimization_level_env.empty()) { - // builder_optimization_level_ = std::stoi(builder_optimization_level_env); - // } - - // const std::string auxiliary_streams_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kAuxiliaryStreams); - // if (!auxiliary_streams_env.empty()) { - // auxiliary_streams_ = std::stoi(auxiliary_streams_env); - // } - - // const std::string tactic_sources_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTacticSources); - // if (!tactic_sources_env.empty()) { - // tactic_sources_ = tactic_sources_env; - // } - - // profile_min_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMinShapes); - // profile_max_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMaxShapes); - // profile_opt_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesOptShapes); - - // const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCudaGraphEnable); - // if (!cuda_graph_enable_env.empty()) { - // cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true); - // } - - } catch (const std::invalid_argument& ex) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what(); - } catch (const std::out_of_range& ex) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Out Of Range Error (from environment variables): " << ex.what(); - } catch (...) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Unknown Exception (from environment variables)"; - } + // deprecate env provider option } // Validate setting @@ -2308,6 +2093,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::st } } + /* // If dump_ep_context_model_ is enable, TRT EP forces cache_path_ to be the relative path of ep_context_file_path_. 
// For example, // - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir" @@ -2329,6 +2115,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::st // Make cache_path_ to be the relative path of ep_context_file_path_ cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string(); } + */ // Hardware compatibility: pre-check on environment if (engine_cache_enable_ && engine_hw_compatible_) { @@ -2454,6 +2241,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(ApiPtrs apis, const std::st } } +void ORT_API_CALL TensorrtExecutionProvider::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) { + (void)this_ptr; + for (size_t i = 0; i < num_node_compute_infos; i++) { + delete node_compute_infos[i]; + } +} + // // Implementation of TRTEpNodeComputeInfo @@ -2469,7 +2264,7 @@ OrtStatus* TRTEpNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, O void** compute_state) { auto* node_compute_info = static_cast(this_ptr); TensorrtExecutionProvider& ep = node_compute_info->ep; - + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); auto state_it = ep.GetComputeStates().find(fused_node_name); if (state_it == ep.GetComputeStates().end()) { @@ -2509,12 +2304,31 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input + auto max_workspace_size = trt_state->max_workspace_size; auto trt_builder = trt_state->builder; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; auto context_memory = trt_state->context_memory; auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + auto cache_prefix = trt_state->cache_prefix; + auto compute_capability = trt_state->compute_capability; + auto engine_cache_enable = trt_state->engine_cache_enable; + auto engine_hw_compatible = trt_state->engine_hw_compatible; + auto timing_cache_enable = trt_state->timing_cache_enable; + auto force_timing_cache_match = trt_state->force_timing_cache; + auto global_cache_path = trt_state->timing_cache_path; + auto detailed_build_log = trt_state->detailed_build_log; + + auto weight_stripped_engine_enable = trt_state->weight_stripped_engine_enable; + auto weight_stripped_engine_refit = trt_state->weight_stripped_engine_refit; + auto model_path = trt_state->model_path; + auto onnx_model_folder_path = trt_state->onnx_model_folder_path; + auto onnx_model_bytestream = trt_state->onnx_model_bytestream; + auto onnx_model_bytestream_size = trt_state->onnx_model_bytestream_size; + + auto sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; + int num_inputs = static_cast(input_indexes.size()); int num_outputs = static_cast(output_indexes.size()); bool engine_update = false; @@ -2523,17 +2337,19 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::unordered_map dds_output_allocator_maps = ep.GetDDSOutputAllocators(); auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; - - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); - if (alloc_ == nullptr) { - Ort::ThrowOnError(api->KernelContext_GetAllocator(context, 
&mem_info, &alloc_)); + + // Get default OrtMemoryInfo from factory + // Get allocator from OrtKernelContext + OrtMemoryInfo* mem_info = ep.factory_.GetDefaultMemInfo(); + OrtAllocator* alloc = nullptr; + ep.GetAllocator(&alloc); + if (alloc == nullptr) { + Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, mem_info, &alloc)); + ep.SetAllocator(alloc); } - OrtAllocator* alloc = alloc_; void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache @@ -2541,16 +2357,16 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // if they share the same compute capacity Prepare cache name std::string cache_path = ""; // Customize cache prefix if assigned - if (!cache_prefix_.empty()) { + if (!cache_prefix.empty()) { cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; } else { cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); } // Enable hardware compatility mode if assigned - std::string cache_hw_compat = "_sm" + compute_capability_; + std::string cache_hw_compat = "_sm" + compute_capability; #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 - if (engine_cache_enable_ && engine_hw_compatible_) { + if (engine_cache_enable && engine_hw_compatible) { cache_hw_compat = "_sm80+"; // LOGS_DEFAULT(VERBOSE) // << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; @@ -2565,16 +2381,16 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; const std::string profile_cache_path = cache_path_prefix + ".profile"; std::string timing_cache_path = ""; - if (timing_cache_enable_) { - timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); + if (timing_cache_enable) { + timing_cache_path = GetTimingCachePath(global_cache_path, compute_capability); } // If weight-stripped engine is enabled and refitted engine cache is not present, // TRT EP will use the engine cache with ".stripped.engine" appended to the end. 
const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + if (weight_stripped_engine_enable && !std::filesystem::exists(engine_cache_fs_path)) { engine_cache_path = cache_path_prefix + ".stripped.engine"; - weight_stripped_engine_refit_ = true; + weight_stripped_engine_refit = true; } // Load serialized engine @@ -2600,7 +2416,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* *(trt_state->engine) = std::unique_ptr( trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + std::string err_msg = "TensorRT EP Failed to Build Engine."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; trt_engine = trt_state->engine->get(); @@ -2613,11 +2430,13 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // Decrypt engine size_t engine_size = 0; if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not get engine buffer size"); + std::string err_msg = "TensorRT EP could not get engine buffer size"; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } std::unique_ptr engine_buf{new char[engine_size]}; if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); + std::string err_msg = "TensorRT EP could not call engine decryption function decrypt"; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } // Deserialize engine // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc @@ -2626,9 +2445,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* *(trt_state->engine) = std::unique_ptr( trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); + std::string err_msg = "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; trt_engine = trt_state->engine->get(); @@ -2649,9 +2467,9 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP failed to parse input tensor and generate optimization profiles."); + if (status != nullptr) { + std::string err_msg = "TensorRT EP failed to parse input tensor and generate optimization profiles."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } } } @@ -2663,8 +2481,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* trt_state->context->reset(); trt_state->engine->reset(); auto 
trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - if (max_workspace_size_ > 0) { - trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); + if (max_workspace_size > 0) { + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size); } for (auto trt_profile : trt_profiles) { trt_config->addOptimizationProfile(trt_profile); @@ -2680,7 +2498,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* #pragma warning(pop) #endif if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); + std::string err_msg = "TensorRT EP failed to set INT8 dynamic range."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } } #if defined(_MSC_VER) @@ -2696,10 +2515,6 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; } - if (trt_state->bf16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; - } #if defined(_MSC_VER) #pragma warning(pop) #endif @@ -2742,7 +2557,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; } #endif - if (weight_stripped_engine_enable_) { + if (weight_stripped_engine_enable) { #if NV_TENSORRT_MAJOR >= 10 trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; @@ -2767,10 +2582,11 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); if (timing_cache == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not create timing cache: " + timing_cache_path); + std::string err_msg = "TensorRT EP could not create timing cache: " + timing_cache_path; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); - if (detailed_build_log_) { + trt_config->setTimingCache(*timing_cache, force_timing_cache_match); + if (detailed_build_log) { // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; } } @@ -2783,39 +2599,37 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } #endif - // Set preview feature flags - for (auto feature : trt_state->preview_features) { - trt_config->setPreviewFeature(feature, true); - } - // Build engine std::unique_ptr serialized_engine; { - auto lock = GetApiLock(); + auto lock = ep.GetApiLock(); std::chrono::steady_clock::time_point engine_build_start; - if (detailed_build_log_) { + if (detailed_build_log) { engine_build_start = std::chrono::steady_clock::now(); } serialized_engine = std::unique_ptr( trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); if (!serialized_engine) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network."); + std::string err_msg = "TensorRT EP failed to create engine from network."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } *(trt_state->engine) = std::unique_ptr( 
trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine."); + std::string err_msg = "TensorRT EP failed to deserialize engine."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - if (detailed_build_log_) { + if (detailed_build_log) { auto engine_build_stop = std::chrono::steady_clock::now(); // LOGS_DEFAULT(INFO) - << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " - << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" - << std::endl; + // << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " + // << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" + // << std::endl; } } if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + std::string err_msg = "TensorRT EP Failed to Build Engine."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } trt_engine = trt_state->engine->get(); if (trt_state->engine_cache_enable) { @@ -2831,13 +2645,13 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); + std::string err_msg = "TensorRT EP could not call engine encryption function encrypt"; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; } else { // LOGS_DEFAULT(WARNING) - << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + // << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; } } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); @@ -2851,30 +2665,32 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* auto timing_cache = trt_config->getTimingCache(); std::unique_ptr timingCacheHostData{timing_cache->serialize()}; if (timingCacheHostData == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not serialize timing cache: " + timing_cache_path); + std::string err_msg = "TensorRT EP could not serialize timing cache: " + timing_cache_path; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (detailed_build_log_) { + if (detailed_build_log) { // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; } } + /* // dump ep context model if (dump_ep_context_model_ && ep_context_embed_mode_) { UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size()); DumpCtxModel(model_proto_.get(), ctx_model_path_); } + */ context_update = true; - if (weight_stripped_engine_refit_) { + if (weight_stripped_engine_refit) { auto status = - RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, - onnx_model_bytestream_, onnx_model_bytestream_size_, trt_engine, - true /* serialize refitted engine to disk */, detailed_build_log_); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + ep.RefitEngine(model_path, onnx_model_folder_path, engine_cache_path, false /* path check for security */, + onnx_model_bytestream, onnx_model_bytestream_size, trt_engine, + true /* serialize refitted engine to disk */, detailed_build_log); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); } } } @@ -2894,14 +2710,16 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::unique_ptr(trt_state->engine->get()->createExecutionContext()); } if (!(*(trt_state->context))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); + std::string err_msg = "TensorRT EP failed to create context."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } trt_context = trt_state->context->get(); } // Check before using trt_engine if (trt_engine == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found."); + std::string err_msg = "No engine is found."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } // Get input and output binding names @@ -2920,7 +2738,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* /* * Set input shapes and bind input buffers */ - std::vector> scratch_buffers; + std::vector> scratch_buffers; for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { char const* input_name = input_binding_names[i]; @@ -2935,8 +2753,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextInput failed."); } } @@ -2966,10 +2784,10 @@ OrtStatus* 
TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* output_type = type_iter->second; } - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, - output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + auto status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, + output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextOutput failed."); } } @@ -2985,12 +2803,12 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* #endif if (mem_size > *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; - *context_memory = - IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); + *context_memory = MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true); } trt_context->setDeviceMemory((*context_memory).get()); } + /* // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. @@ -2999,10 +2817,12 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* cuda_graph_.SetStream(stream); CaptureBegin(0); } + */ // Run TRT inference if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); + std::string err_msg = "TensorRT EP execution context enqueue failed."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } /* @@ -3021,7 +2841,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* * operations to prevent the concurrent issue mentioned above. However, if cuda graph is enabled, TRT EP won't call * cudaStreamSynchronize() since it's not allowed during graph capture. */ - if (sync_stream_after_enqueue_) { + if (sync_stream_after_enqueue) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } @@ -3043,10 +2863,9 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* if (index_iter != output_indexes.end()) { output_index = index_iter->second; } - auto status = - BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); + auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindKernelOutput failed."); } } else { auto& output_tensor = output_tensors[i]; @@ -3069,6 +2888,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } } + /* // End CUDA graph capture. 
// Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream // mentioned in graph capture above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and @@ -3084,8 +2904,9 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* IncrementRegularRunCountBeforeGraphCapture(); } } + */ - return kernel.Compute(kernel_context); + return nullptr; } void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index f3f4d131..a595bcf8 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -4,6 +4,7 @@ #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT +#include "tensorrt_provider_factory.h" #include "utils/provider_options.h" #include "tensorrt_execution_provider_info.h" #include "nv_includes.h" @@ -150,11 +151,6 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { std::vector output_shapes; }; -using ShapeRangesMap = std::unordered_map>>>; - -template -using IAllocatorUniquePtr = std::unique_ptr>; - struct TensorrtComputeState { std::string fused_node_name; nvinfer1::IBuilder* builder; @@ -166,6 +162,8 @@ struct TensorrtComputeState { std::vector> output_info; std::unordered_map>>> input_shape_ranges; std::mutex* tensorrt_mu_ptr = nullptr; + std::string compute_capability; + size_t max_workspace_size = 1 << 30; // 1GB; bool fp16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; @@ -178,7 +176,7 @@ struct TensorrtComputeState { std::vector profiles; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; - IAllocatorUniquePtr* context_memory = nullptr; + AllocatorUniquePtr* context_memory = nullptr; std::unordered_map dynamic_range_map; bool engine_decryption_enable = false; int (*engine_decryption)(const char*, char*, size_t*) = nullptr; @@ -193,10 +191,17 @@ struct TensorrtComputeState { int auxiliary_streams = -1; bool filter_tactic_sources = false; nvinfer1::TacticSources tactic_sources; - bool cuda_graph_enable = 0; + bool cuda_graph_enable = false; + bool weight_stripped_engine_enable = false; + bool weight_stripped_engine_refit = false; + char* model_path; + std::string onnx_model_folder_path; + const void* onnx_model_bytestream; + size_t onnx_model_bytestream_size; std::string cache_prefix; std::string cache_suffix; bool engine_hw_compatible = false; + bool sync_stream_after_enqueue = true; }; // Minimum information to construct kernel function state for direct engine load code path @@ -211,6 +216,7 @@ struct TensorrtComputeStateForEPContext { std::mutex* tensorrt_mu_ptr = nullptr; }; +using ShapeRangesMap = std::unordered_map>>>; using DDSOutputAllocatorMap = std::unordered_map>; std::string GetWeightRefittedEnginePath(std::string engine_cache_path); @@ -220,54 +226,51 @@ static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; /// /// -/// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. 
-/// -/// -struct TRTEpNodeComputeInfo : OrtNodeComputeInfo { - explicit TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep); - - static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, - void** compute_state); - static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, - OrtKernelContext* kernel_context); - static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); - - TensorrtExecutionProvider& ep; -}; - -/// -/// -/// Plugin TensorRT EP that implements OrtEp +/// Plugin TensorRT EP /// /// -struct TensorrtExecutionProvider : OrtEp, ApiPtrs { - TensorrtExecutionProvider(ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device, - const OrtSessionOptions& session_options, const OrtLogger& logger); +struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { + TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, const std::string& name, + const OrtHardwareDevice& device, const OrtSessionOptions& session_options, + const OrtLogger& logger); ~TensorrtExecutionProvider(); + TensorrtExecutionProviderFactory& factory_; std::string name_; const OrtHardwareDevice& hardware_device_; const OrtSessionOptions& session_options_; const OrtLogger& logger_; - SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, - const OrtGraph* graph, bool* early_termination) const; + SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, + const int max_iterations, const OrtGraph* graph, bool* early_termination) const; OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo* node_compute_info); + OrtNodeComputeInfo** node_compute_info); OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo* node_compute_info); + OrtNodeComputeInfo** node_compute_info); + + OrtStatus* RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, bool path_check, + const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, + bool detailed_build_log); std::unordered_map>& GetComputeStates() { return compute_states_; } - std::unordered_map>& GetComputeStatesForEPContext() { return compute_states_; } + std::unordered_map>& GetComputeStatesForEPContext() { + return compute_states_; + } + + void GetAllocator(OrtAllocator** alloc) const { *alloc = alloc_; } + void SetAllocator(OrtAllocator* alloc) { alloc_ = alloc; } + std::unordered_map& GetDDSOutputAllocators() { return dds_output_allocator_maps_; } @@ -312,6 +315,19 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { std::unordered_map cache_suffix_; private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info); + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, 
+ _Out_writes_(count) OrtNode** ep_context_nodes); + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos); + + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, + /*out*/ gsl::span ep_context_nodes); + mutable TensorrtExecutionProviderInfo info_; bool external_stream_ = false; cudaStream_t stream_ = nullptr; @@ -331,6 +347,8 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { bool weight_stripped_engine_enable_ = false; bool weight_stripped_engine_refit_ = false; std::string onnx_model_folder_path_; + const void* onnx_model_bytestream_; + size_t onnx_model_bytestream_size_; bool build_heuristics_enable_ = false; bool sparsity_enable_ = false; int builder_optimization_level_ = 3; @@ -344,7 +362,7 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { bool context_memory_sharing_enable_ = false; bool layer_norm_fp32_fallback_ = false; size_t max_ctx_mem_size_ = 0; - IAllocatorUniquePtr context_memory_ = nullptr; + AllocatorUniquePtr context_memory_ = nullptr; mutable char model_path_[4096] = {}; // Reserved for max path length bool engine_decryption_enable_ = false; int (*engine_decryption_)(const char*, char*, size_t*) = nullptr; @@ -419,3 +437,20 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs { nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; }; + +/// +/// +/// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. +/// +/// +struct TRTEpNodeComputeInfo : OrtNodeComputeInfo { + explicit TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + TensorrtExecutionProvider& ep; +}; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index 54fa0ed4..e03111e0 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -1,3 +1,5 @@ +#pragma once + #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT @@ -9,6 +11,7 @@ // #include "core/framework/murmurhash3.h" #include"nv_includes.h" +#include "gsl/narrow" #include #include @@ -41,6 +44,42 @@ struct ApiPtrs { namespace fs = std::filesystem; +template +using AllocatorUniquePtr = std::unique_ptr>; + +bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { + size_t alloc_size = size; + if (alignment == 0) { + *out = alloc_size * nmemb; + } else { + size_t alignment_mask = alignment - 1; + *out = (alloc_size * nmemb + alignment_mask) & ~static_cast(alignment_mask); + } + return true; +} + +template +AllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes, + bool use_reserve = false) { + size_t alloc_size = count_or_bytes; + // if T is not void, 'count_or_bytes' == number of items so allow for that + if constexpr (!std::is_void::value) { + // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't + // reachable if T is void. 
use std::conditional to 'use' void* in the sizeof call + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + CalcMemSizeForArrayWithAlignment(count_or_bytes, size, 0, &alloc_size); + } + + T* p = nullptr; + if (use_reserve) { + p = static_cast(ort_allocator->Reserve(ort_allocator, alloc_size)); + } else { + p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + } + + return AllocatorUniquePtr{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }}; +} + // Check if cycle exists in the graph after partitioning /* bool FindCycleHelper(size_t i, gsl::span> adjacency_map, gsl::span visited, @@ -168,7 +207,6 @@ std::vector SplitToStringVec(std::string const& s, char separator) return splitted; } -/* nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSources disabledTactics = 0; nvinfer1::TacticSources enabledTactics = 0; @@ -184,7 +222,7 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { const auto toUpper = [](std::string& sourceName) { std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), - [](char c) { return onnxruntime::narrow(std::toupper(c)); }); + [](char c) { return gsl::narrow(std::toupper(c)); }); return sourceName; }; @@ -223,7 +261,6 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { } return enabledTactics & ~disabledTactics; } -*/ inline std::vector loadTimingCacheFile(const std::string inFileName) { std::ifstream iFile(inFileName, std::ios::in | std::ios::binary); diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 03d4e902..dcf9e30d 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -206,6 +206,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl return nullptr; } +OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultMemInfo() const { + return default_gpu_memory_info_.get(); +} + // To make symbols visible on macOS/iOS #ifdef __APPLE__ #define EXPORT_SYMBOL __attribute__((visibility("default"))) @@ -221,10 +225,10 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); const OrtEpApi* ort_ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); // Factory could use registration_name or define its own EP name. - std::unique_ptr factory = std::make_unique(registration_name, - ApiPtrs{*ort_api, *ort_ep_api}); + std::unique_ptr factory = std::make_unique(registration_name, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index a3dd58b9..e4222e92 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -4,9 +4,10 @@ /// /// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices. 
/// -struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs { +struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { public: TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis); + OrtMemoryInfo* GetDefaultMemInfo() const; private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; From c8e3d6f8fedaf8526323fed993f22422b9a8aaa4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 2 Jul 2025 14:30:53 -0700 Subject: [PATCH 14/60] call EpDevice_AddAllocatorInfo in GetSupportedDevicesImpl --- .../tensorrt/tensorrt_provider_factory.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index dcf9e30d..f8116a32 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -92,8 +92,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3"); // OrtEpDevice copies ep_metadata and ep_options. + OrtEpDevice* ep_device = nullptr; auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, - &ep_devices[num_ep_devices++]); + &ep_device); factory->ort_api.ReleaseKeyValuePairs(ep_metadata); factory->ort_api.ReleaseKeyValuePairs(ep_options); @@ -101,6 +102,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp if (status != nullptr) { return status; } + + // register the allocator info required by the EP. + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_gpu_memory_info_.get())); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->host_accessible_gpu_memory_info_.get())); + + ep_devices[num_ep_devices++] = ep_device; } // C++ API equivalent. Throws on error. From 3c4302997659659c4ee336a5ec4db833fb2c5bfe Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 2 Jul 2025 20:52:37 -0700 Subject: [PATCH 15/60] temporary way to get provider option without proper API --- .../tensorrt/tensorrt_execution_provider.cc | 25 ++++++++++++++++--- .../tensorrt/utils/provider_options.h | 4 --- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 2e4b3915..0e1fe33d 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -1980,10 +1980,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to // the session option configurations with the key prefix "ep..". - const std::string key_prefix = OrtSessionOptions::GetProviderOptionPrefix(name_.c_str()); - const ConfigOptions& config_options = session_options.GetConfigOptions(); - const std::unordered_map& config_options_map = config_options.GetConfigOptionsMap(); + // We extract those EP options to create a new "provider options" key/value map. 
+ std::string lowercase_ep_name = name_.c_str(); + std::transform(lowercase_ep_name.begin(), lowercase_ep_name.end(), lowercase_ep_name.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to + // the session option configurations with the key prefix "ep..". + std::string key_prefix = "ep." + lowercase_ep_name + "."; + + /* // Get provider options as key-value pair strings ProviderOptions provider_options; for (const auto& [key, value] : config_options_map) { @@ -1991,6 +1997,19 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa provider_options[key.substr(key_prefix.size())] = value; } } + */ + + // Get all the provider options as session config from sesson + ProviderOptions provider_options; + int has_session_config_entry = 0; + std::string provider_option = key_prefix + "trt_engine_cache_enable"; + auto status = ort_api.HasSessionConfigEntry(&session_options, provider_option.c_str(), & has_session_config_entry); + if (has_session_config_entry) { + char* value = nullptr; + size_t size = 0; + status = ort_api.GetSessionConfigEntry(&session_options, provider_option.c_str(), value, &size); + provider_options[provider_option.substr(key_prefix.size())] = value; + } // Provider options to TensorrtExecutionProviderInfo info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options); diff --git a/plugin_execution_providers/tensorrt/utils/provider_options.h b/plugin_execution_providers/tensorrt/utils/provider_options.h index aab13e80..33beba2f 100644 --- a/plugin_execution_providers/tensorrt/utils/provider_options.h +++ b/plugin_execution_providers/tensorrt/utils/provider_options.h @@ -7,12 +7,8 @@ #include #include -namespace onnxruntime { - // data types for execution provider options using ProviderOptions = std::unordered_map; using ProviderOptionsVector = std::vector; using ProviderOptionsMap = std::unordered_map; - -} // namespace onnxruntime From 549b29d373a3885496ee02b746d389792e4c046b Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 7 Jul 2025 15:04:42 -0700 Subject: [PATCH 16/60] Clean up cmake file to remove dependencies that built with ORT --- .../tensorrt/CMakeLists.txt | 62 ++++++++++++------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index bf2d7b0c..9456f7cf 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -1,6 +1,6 @@ # usage: # cd build/ -# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DORT_HOME=/home/lochi/repos/ort -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.22.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") # cmake --build ./ --config Debug cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) @@ -13,7 +13,7 @@ add_definitions(-DONNX_NAMESPACE=onnx) add_definitions(-DONNX_ML) add_definitions(-DNV_TENSORRT_MAJOR=10) add_definitions(-DNOMINMAX) -file(GLOB 
tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu") +file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu" "./*.h") add_library(TensorRTEp SHARED ${tensorrt_src}) if (NOT ORT_HOME) @@ -32,6 +32,24 @@ endif() # Add dependencies include(FetchContent) +# Add protobuf +FetchContent_Declare( + protobuf + GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git + GIT_TAG v21.12 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(protobuf) + +# Add ONNX +FetchContent_Declare( + onnx + GIT_REPOSITORY https://github.com/onnx/onnx.git + GIT_TAG v1.18.0 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(onnx) + # Add GSL FetchContent_Declare( gsl @@ -50,17 +68,18 @@ FetchContent_Declare( FetchContent_MakeAvailable(flatbuffers) -if (WIN32) - set(PLATFORM "Windows") - set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/${CMAKE_BUILD_TYPE}/onnxruntime.lib") - set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps") +set(DEPS_PATH "${CMAKE_BINARY_DIR}/_deps") + +if (WIN32) # Windows + set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.lib") + #set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.dll") set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib" "${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib" "${TENSORRT_HOME}/lib/nvonnxparser_10.lib") set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib" - "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" - "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") - + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") + if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(DEPS_LIBS ${DEPS_LIBS} "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib" @@ -70,17 +89,15 @@ if (WIN32) "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib" "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib") endif() -else() - set(PLATFORM "Linux") - set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/libonnxruntime.so") - set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps") +else() # Linux + set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so") set(TRT_LIBS "${TENSORRT_HOME}/lib/libnvinfer.so" "${TENSORRT_HOME}/lib/libnvinfer_plugin.so" "${TENSORRT_HOME}/lib/libnvonnxparser.so") set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/libflatbuffers.a" "${DEPS_PATH}/onnx-build/libonnx.a" "${DEPS_PATH}/onnx-build/libonnx_proto.a") - + if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(DEPS_LIBS ${DEPS_LIBS} "${DEPS_PATH}/protobuf-build/libprotobufd.a" @@ -93,17 +110,14 @@ else() endif() MESSAGE(STATUS "Looking for following dependencies ...") -MESSAGE(STATUS "Platform : ${PLATFORM}") -MESSAGE(STATUS "ORT home : ${ORT_HOME}") MESSAGE(STATUS "ORT lib : ${ORT_LIB}") -MESSAGE(STATUS "Deps path: ${DEPS_PATH}") -MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}") MESSAGE(STATUS "TRT libs : ${TRT_LIBS}") +MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}") -target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include/onnxruntime/core/session/" +target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include" "./utils" "/usr/local/cuda/include" - ${TENSORRT_HOME}/include + "${TENSORRT_HOME}/include" "${DEPS_PATH}/flatbuffers-src/include" "${DEPS_PATH}/gsl-src/include" "${DEPS_PATH}/onnx-src" @@ -114,6 +128,8 @@ target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include/onnxruntime/co target_link_libraries(TensorRTEp PUBLIC ${ORT_LIB} ${TRT_LIBS} CUDA::cudart - 
${DEPS_LIBS} - GSL - flatbuffers) + protobuf + onnx + gsl + flatbuffers +) From 3ad7736f8590b56d4d0eba0fd17df1804a9f6344 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 7 Jul 2025 18:43:49 -0700 Subject: [PATCH 17/60] Update CompileImpl --- .../tensorrt/tensorrt_execution_provider.cc | 80 +++++++++---------- 1 file changed, 37 insertions(+), 43 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 0e1fe33d..f56d2232 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -13,7 +13,7 @@ //#include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider.h" #include "cuda_allocator.h" -//#include "onnx_ctx_model_helper.h" +#include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" #include "cuda/unary_elementwise_ops_impl.h" @@ -1480,8 +1480,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } -OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) { +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -1780,71 +1779,66 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this return nullptr; } -OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, - _In_ const OrtNode** fused_nodes, _In_ size_t count, +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, + _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, + _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) { TensorrtExecutionProvider* ep = static_cast(this_ptr); + const OrtApi& ort_api = ep->ort_api; gsl::span result(node_compute_infos, count); - for (size_t graph_idx = 0; graph_idx < count; graph_idx++) { - auto fused_node = fused_nodes[graph_idx]; - - // Gets node's inputs and outputs as pointer array - OrtArrayOfConstObjects* inputs_array = nullptr; - OrtArrayOfConstObjects* outputs_array = nullptr; - DeferOrtRelease release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects); - DeferOrtRelease release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects); - - RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(fused_node, &inputs_array)); - RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(fused_node, &outputs_array)); - - // Gets node's inputs and outputs as OrtValueInfo in gsl::span - gsl::span node_inputs{}; - gsl::span node_outputs{}; - - GetSpanFromArrayOfConstObjects(inputs_array, node_inputs); - GetSpanFromArrayOfConstObjects(outputs_array, node_outputs); + for (size_t fused_node_idx = 0; fused_node_idx < count; fused_node_idx++) { + auto fused_node = fused_nodes[fused_node_idx]; // Gets number of node's inputs and outputs size_t num_node_inputs = 0; - size_t num_node_outputs = 0; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_node_inputs)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_node_outputs)); + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_node_inputs)); + + std::vector 
node_inputs(num_node_inputs); + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, node_inputs.data(), node_inputs.size())); // Builds map from input name to its index in input list std::unordered_map input_map; input_map.reserve(num_node_inputs); for (size_t i = 0; i < num_node_inputs; i++) { - // TODO: Add ValueInfo_GetName() c api - //std::string& name = node_inputs[i]->GetName(); - //input_map[name] = i; + const OrtValueInfo* value_info = node_inputs[i]; + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name)); + + input_map.emplace(name, i); } + // Gets number of node's outputs + size_t num_node_outputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_node_outputs)); + + std::vector node_outputs(num_node_outputs); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, node_outputs.data(), node_outputs.size())); + // Builds map from output name to its index in output list std::unordered_map output_map; - input_map.reserve(num_node_outputs); + output_map.reserve(num_node_outputs); for (size_t i = 0; i < num_node_outputs; i++) { - // TODO: Add ValueInfo_GetName() c api - //std::string& name = node_outputs[i]->GetName(); - //output_map[name] = i; - } + const OrtValueInfo* value_info = node_outputs[i]; + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name)); + output_map.emplace(name, i); + } + OrtStatus* status; - //if (GraphHasCtxNode(graph_body_viewer)) { - if (false) { - status = ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[graph_idx], fused_node, + if (GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { + RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, input_map, - output_map, &result[graph_idx]); + output_map, &result[fused_node_idx])); } else { - status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[graph_idx], fused_node, input_map, - output_map, &result[graph_idx]); + RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map, + output_map, &result[fused_node_idx])); } - //if (status != Status::OK()) { - // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); - //} } return nullptr; From 3ced4cfe0de141d403f8086c6a4955cebf4028ae Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 8 Jul 2025 11:42:50 -0700 Subject: [PATCH 18/60] add ort_graph_to_proto.h and leverage OrtGraphToProto utilities --- .../tensorrt/tensorrt_execution_provider.cc | 330 +++++++- .../tensorrt/utils/ort_graph_to_proto.h | 718 ++++++++++++++++++ 2 files changed, 1029 insertions(+), 19 deletions(-) create mode 100644 plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index f56d2232..8bdd2a7f 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -9,6 +9,9 @@ #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT +#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL +#include "ort_graph_to_proto.h" + #include "ep_abi_utils.h" //#include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider.h" @@ -716,6 +719,267 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect const OrtGraph* graph, bool* early_termination) const { // Return if iterations are exceeding predefined number 
SubGraphCollection_t nodes_list_output; + if (iterations > max_iterations) { + *early_termination = true; + return nodes_list_output; + } + + // Get parent graph output names + std::unordered_set graph_output_names; + for (const auto* output_arg : graph.GetOutputs()) { + graph_output_names.insert(output_arg->Name()); + } + + iterations++; + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + for (const auto& group : nodes_vector_input) { + // Construct subgraph + if (!group.first.empty()) { + if (group.second) { + nodes_list_output.push_back(group); + } else { + auto model_build = graph.CreateModel(*GetLogger()); + auto& graph_build = model_build->MainGraph(); + bool has_control_flow_op = false; + + // Add node and node args + // If node output is also parent graph output, the output will be added to the + // subgraph's output list + std::vector subgraph_output_names; + for (const auto& index : group.first) { + // Initializers that refer to a memory location in OrtValue + // can not be handled by TRT (unlike those that are on disk). + // This prevents us from sharing the data and we have to make a copy here. + constexpr const bool load_initializers_inline_true = true; + const auto& node = graph.GetNode(node_index[index]); + std::vector inputs, outputs; + for (auto input : node->InputDefs()) { + auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + inputs.push_back(&n_input); + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); + } + + for (auto input : node->ImplicitInputDefs()) { + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); + } + for (auto output : node->OutputDefs()) { + auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); + outputs.push_back(&n_output); + const auto name = output->Name(); + if (graph_output_names.find(name) != graph_output_names.end()) { + subgraph_output_names.push_back(name); + } + } + + if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { + has_control_flow_op = true; + } + + // If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node + // attributes are not in sync because of graph optimization. Therefore, we need to force GraphProto attributes + // to be updated in order to get the valid GraphProto. + if (node->GetAttributes().size() > 0) { + auto node_proto = ONNX_NAMESPACE::NodeProto::Create(); + // we need to update any GraphProto attributes for subgraphs so that any changes made by things + // such as the optimizers are captured. otherwise we can end up saving an invalid graph. + node->ToProto(*node_proto, /* update_subgraphs */ true); + const int num_attributes = node_proto->attribute_size(); + auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); + node_attributes->reserve(num_attributes); + + for (int i = 0; i < num_attributes; ++i) { + auto& attr = node_proto->attribute(i); + node_attributes->emplace(attr.name(), attr); + } + + // The GraphProto attributes are the updated ones. + graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, + node_attributes.get(), node->Domain()); + } else { + // The GraphProto attributes are the original ones. 
+ graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, + &node->GetAttributes(), node->Domain()); + } + } + + // Only if the newly built graph has control flow op as well as it has parent node, + // it needs to handle outer scope values before calling graph.Resolve(). + if (has_control_flow_op && graph.ParentNode()) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); + BuildSubGraphContext(graph_build); + SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); + SetAllGraphInputs(graph_build); + } + + ORT_ENFORCE(graph_build.Resolve().IsOK()); + + // Add parent graph output to the subgraph + int i = 0; + std::vector subgraph_outputs; + subgraph_outputs.resize(subgraph_output_names.size()); + for (auto& name : subgraph_output_names) { + auto output_arg = graph.GetNodeArg(name); + auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + subgraph_outputs[i] = &subgraph_output_arg; + ++i; + } + auto& graph_build_outputs = graph_build.GetOutputs(); + subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); + graph_build.SetOutputs(graph_build_outputs); + ORT_ENFORCE(graph_build.Resolve().IsOK()); + + // Check if input tensors have shapes + if (iterations > 1) { + auto graph_inputs = graph_build.GetInputs(); + for (auto input_arg : graph_inputs) { + bool has_dim_value_or_param = true; + auto input_shape = input_arg->Shape(); + if (input_shape != nullptr) { + auto dim_size = input_shape->dim_size(); + for (int i = 0; i < dim_size; ++i) { + auto& dim = input_shape->dim(i); + if (!dim.has_dim_value() && !dim.has_dim_param()) { + has_dim_value_or_param = false; + break; + } + } + } + + if (input_shape == nullptr || !has_dim_value_or_param) { + ORT_THROW_IF_ERROR( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "TensorRT input: " + input_arg->Name() + " has no shape specified. " + + "Please run shape inference on the onnx model first. Details can be found in " + + "https://onnxruntime.ai/docs/execution-providers/" + "TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs")); + } + } + } + + + /* + //Save initializers to external file + std::string ext_ini_file_path = "model_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path]( + const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external, + std::string& location, int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. 
+ + return Ort::Status{nullptr}; + }; + */ + + // Construct ModelProto from OrtGraph + ONNX_NAMESPACE::ModelProto model_proto; + + // add back handle_initializer_data to save initializer to external file + OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); + + std::string string_buf; + model_proto.SerializeToString(&string_buf); + + if (dump_subgraphs_) { + // Dump TensorRT subgraph for debugging + std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", + std::ios::out | std::ios::trunc | std::ios::binary); + model_proto.SerializeToOstream(&dump); + } + + // Get supported node list recursively + SubGraphCollection_t parser_nodes_list; + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) + ? 0 + : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#else + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif + + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_parser = + tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + + // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined + // behavior. + auto num_subgraphs = trt_parser->getNbSubgraphs(); + parser_nodes_list.reserve(num_subgraphs); + + for (int64_t i = 0; i < num_subgraphs; ++i) { + int64_t subgraph_len = 0; + int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len); + parser_nodes_list.emplace_back(); + parser_nodes_list.back().first.reserve(subgraph_len); + for (int64_t j = 0; j < subgraph_len; ++j) { + parser_nodes_list.back().first.push_back(nodes[j]); + } + parser_nodes_list.back().second = is_model_supported ? true : false; + } +#else + trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); +#endif + + SubGraphCollection_t next_nodes_list; + const std::vector& subgraph_node_index = + graph_viewer->GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + next_nodes_list = + GetSupportedList(parser_nodes_list, iterations, max_iterations, *graph_viewer, early_termination); + for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + /* + * Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT + * TRT. + * + * TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model + * proto and to string buffer) to onnx-tensorrt parser. The node index in the list returning from + * onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a + * node index mapping table here. + * + * The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node + * order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), however, once the graph is + * converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls + * model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort. 
+ * + * The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table: + * subgraph_node_index[node index from onnx-tensorrt parser] = index in group.first + * + * In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result + * not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. With the change of using ORT's + * priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of + * 0, 1, ... n-1 for most of the cases, therefore subgraph_node_index as a mapping table is not needed + * anymore. + * + * TODO: Remove the subgraph_node_index + */ + next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; + } + nodes_list_output.push_back(next_nodes_list[i]); + } + } + } + } return nodes_list_output; } @@ -728,26 +992,50 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this TensorrtExecutionProvider* ep = static_cast(this_ptr); /* - // Reconstruct graph proto from fused node's function body - auto model = graph_body_viewer.CreateModel(*GetLogger()); - auto model_proto = model->ToProto(); - - // ORT's default topological sort is using reversed DFS. - // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. - // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating - // the model proto that has different node ordering compared to original onnx model. - graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + //Save initializers to external file + std::string ext_ini_file_path = "model_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path]( + const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external, + std::string& location, int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. 
+ + return Ort::Status{nullptr}; + }; + */ + + // Construct ModelProto from OrtGraph + ONNX_NAMESPACE::ModelProto model_proto; + + // add back handle_initializer_data to save initializer to external file + OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); + std::string string_buf; - model_proto->SerializeToString(string_buf); + model_proto.SerializeToString(&string_buf); if (dump_subgraphs_) { // Dump TensorRT subgraphs - std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(dump); + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &name)); + std::string subgraph_name = name; + subgraph_name += ".onnx"; + std::fstream dump(subgraph_name, std::ios::out | std::ios::trunc | std::ios::binary); + model_proto.SerializeToOstream(&dump); } - */ - std::string string_buf; TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); auto trt_builder = GetBuilder(trt_logger); @@ -1356,6 +1644,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } // Create input to index map + // TRT network input -> ORT fused_node input index for (int i = 0; i < num_inputs; ++i) { auto input = trt_network->getInput(i); const std::string& input_name = input->getName(); @@ -1366,6 +1655,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } // Create output to index and type maps + // TRT network output -> ORT fused_node output index const auto& graph_output = model_proto->graph().output(); for (int i = 0; i < num_outputs; ++i) { const std::string& output_name = trt_network->getOutput(i)->getName(); @@ -1789,7 +2079,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; - gsl::span result(node_compute_infos, count); + gsl::span node_compute_infos_result(node_compute_infos, count); + gsl::span ep_context_nodes_result(ep_context_nodes, count); for (size_t fused_node_idx = 0; fused_node_idx < count; fused_node_idx++) { auto fused_node = fused_nodes[fused_node_idx]; @@ -1833,11 +2124,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ OrtStatus* status; if (GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, - input_map, - output_map, &result[fused_node_idx])); + input_map, output_map, + &node_compute_infos_result[fused_node_idx])); } else { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map, - output_map, &result[fused_node_idx])); + output_map, &node_compute_infos_result[fused_node_idx]), + &ep_context_nodes_result[fused_node_idx]); } } diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h new file mode 100644 index 00000000..37665542 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h @@ -0,0 +1,718 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* + SUMMARY: + Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider + implementations that need to convert an OrtGraph instance into an ONNX protobuf model. + + Users may copy this file and modify as needed. 
+ + USAGE: + This is a header-only implementation that includes both the function declarations and definitions. Copy this file + into a project that links with both ONNX Runtime and ONNX. + + Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ + file to define the implementation. Example: + + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + Other compilation units that depend on these utilities should include this file without defining the + preprocessor macro. + + Example program snippets are shown below. Refer to the function declarations for detailed usage information. + + EXAMPLE SNIPPET (initializers stored within TensorProto): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + onnx::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); + + // graph_proto stores initializers internally + } + ``` + + EXAMPLE SNIPPET (large initializers stored in external file): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + std::string external_file_path = "weights.bin"; + std::ofstream out_file(external_file_path, std::ios::binary); + + auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = out_file.tellp(); + location = external_file_path; + out_file.write(static_cast(data), bytes); + out_file.flush(); + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto stores large initializers in an external file + } + ``` +*/ + +#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ +#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +#include +#include "core/session/onnxruntime_cxx_api.h" +#include "onnx/onnx_pb.h" + +namespace OrtEpUtils { + +/// +/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. +/// +/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data +/// within the TensorProto as raw_data. +/// +/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the +/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the +/// location and offset returned via the `location` and `offset` output parameters. +/// +/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure +/// ONNX shape inference works correctly with the serialized ONNX model. +/// +/// OrtValueInfo for the initializer. 
Can be used to query the initializer's name, type, shape,
+/// and consumer nodes.
+/// `data`: Opaque pointer to the initializer data.
+/// `bytes`: Size in bytes of the initializer data.
+/// `is_external`: Output parameter set to true if the initializer data is stored externally. The
+/// implementer is responsible for writing the initializer data to a file. If set to false,
+/// the initializer will be stored within the TensorProto.
+/// `location`: Output parameter set to the location (e.g., file) in which the initializer is stored
+/// by the implementer of this function. Ignored if `is_external` is set to false.
+/// `offset`: Output parameter set to the offset (e.g., file offset) at which the initializer is stored
+/// by the implementer of this function. Ignored if `is_external` is set to false.
+/// Returns: An Ort::Status indicating success or an error. Serialization stops if this returns an error.
+using HandleInitializerDataFunc = std::function<Ort::Status(const OrtValueInfo* value_info,
+                                                            const void* data,
+                                                            size_t bytes,
+                                                            bool& is_external,
+                                                            std::string& location,
+                                                            int64_t& offset)>;
+
+///
+/// Serializes the provided OrtGraph to an onnx::GraphProto.
+/// Allows the caller to provide a function that specifies whether an initializer should be stored
+/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX).
+///
+/// `ort_graph`: OrtGraph instance to serialize.
+/// `graph_proto`: Destination GraphProto into which to serialize the input OrtGraph.
+/// `handle_initializer_data_func`: Optional function called to allow the user to determine
+/// where the initializer data is stored.
+/// Returns: An Ort::Status indicating success or an error.
+Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
+                            onnx::GraphProto& graph_proto,
+                            HandleInitializerDataFunc handle_initializer_data_func = nullptr);
+
+///
+/// Serializes the provided top-level OrtGraph to an onnx::ModelProto.
+/// Allows the caller to provide a function that specifies whether an initializer should be stored
+/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX).
+///
+/// `ort_graph`: OrtGraph instance to serialize.
+/// `model_proto`: Destination ModelProto into which to serialize the input OrtGraph.
+/// `handle_initializer_data_func`: Optional function called to allow the user to determine
+/// where the initializer data is stored.
+/// Returns: An Ort::Status indicating success or an error.
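+///
+/// Example usage (a minimal sketch; `ort_graph` and the output file name are placeholders, and error
+/// handling is elided):
+///
+///   onnx::ModelProto model_proto;
+///   Ort::Status status = OrtEpUtils::OrtGraphToProto(*ort_graph, model_proto);
+///   if (status.IsOK()) {
+///     std::ofstream out("serialized_model.onnx", std::ios::binary);
+///     model_proto.SerializeToOstream(&out);
+///   }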
+Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); +} // namespace OrtEpUtils + +// End of header +#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +// +// IMPLEMENTATION BELOW +// +#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + +#include +#include +#include +#include +#include +#include + +#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return Ort::Status{_status}; \ + } \ + } while (0) + +#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ + do { \ + Ort::Status _status = (fn); \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ + } \ + } while (0) + +namespace OrtEpUtils { + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims); +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // + // Set GraphProto metadata + // + const char* graph_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + size_t num_graph_inputs = 0; + size_t num_graph_outputs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); + + std::vector graph_inputs(num_graph_inputs); + std::vector graph_outputs(num_graph_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); + + for (const OrtValueInfo* ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + for (const OrtValueInfo* ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. 
+ auto collect_value_info = [&ort_api, &value_infos, + &initializer_value_infos](const OrtValueInfo& ort_value_info, + /*out*/ const char** value_name_out = nullptr) -> Ort::Status { + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + + if (value_name_out != nullptr) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return Ort::Status{nullptr}; // Already processed this OrtValueInfo. + } + + bool is_required_graph_input = false; + bool is_optional_graph_input = false; + bool is_graph_output = false; + bool is_constant_initializer = false; + bool is_from_outer_scope = false; + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, &ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, &ort_value_info); + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. + } + + return Ort::Status{nullptr}; + }; + + size_t num_nodes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + + std::vector nodes(num_nodes); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. 
+ for (size_t i = 0; i < num_nodes; i++) { + const OrtNode* ort_node = nodes[i]; + onnx::NodeProto* node_proto = graph_proto.add_node(); + + const char* node_name = nullptr; + const char* node_domain = nullptr; + const char* node_op_type = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + size_t num_inputs = 0; + size_t num_implicit_inputs = 0; + size_t num_outputs = 0; + size_t num_attrs = 0; + size_t num_subgraphs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); + + // Handle node attributes + if (num_attrs > 0) { + std::vector ort_attrs(num_attrs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); + + for (const OrtOpAttr* ort_attr : ort_attrs) { + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + + Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (!status.IsOK()) { + // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. + // Can use Node_GetSubgraphs to get subgraphs. + continue; + } + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + } + } + + // Handle node subgraphs + if (num_subgraphs > 0) { + std::vector ort_subgraphs(num_subgraphs); + std::vector subgraph_attr_names(num_subgraphs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), + subgraph_attr_names.data())); + + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; + const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); + + attr_proto->set_name(subgraph_attr_name); + attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); + } + } + + // Handle node inputs + if (num_inputs > 0) { + std::vector ort_inputs(num_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_inputs) { + if (ort_value_info == nullptr) { + // missing optional input. + node_proto->add_input(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_input(value_name); + } + } + + // Handle implicit inputs to this node. 
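+    // (Implicit inputs are outer-scope values consumed by this node's subgraphs, e.g. If/Loop bodies.
+    // They are not listed in NodeProto.input, but their value infos are still collected so the enclosing
+    // GraphProto declares them.)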
+ if (num_implicit_inputs > 0) { + std::vector ort_implicit_inputs(num_implicit_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), + ort_implicit_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { + assert(ort_value_info != nullptr); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + } + } + + // Handle node outputs + if (num_outputs > 0) { + std::vector ort_outputs(num_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_outputs) { + if (ort_value_info == nullptr) { + // missing optional output. + node_proto->add_output(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_output(value_name); + } + } + } + + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const std::pair& entry : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); + } + + // Add initializers to GraphProto as TensorProto objects. + for (const std::pair& entry : initializer_value_infos) { + const OrtValueInfo* initializer_value_info = entry.second; + std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } + + const OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + + const void* data = nullptr; + size_t data_bytes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } + } + + 
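+  // Note: for an initializer routed to external storage by the callback above, the TensorProto carries only
+  // a "location"/"offset" pair (illustrative example: location = "weights.bin", offset = 4096); readers of
+  // the serialized model resolve "location" relative to the directory of the model file.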
return Ort::Status{nullptr}; +} + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // Check that OrtGraph is a top-level graph (no parent node). + const OrtNode* parent_node = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); + model_proto.set_ir_version(ir_version); + + // Set operator sets. + size_t num_operator_sets = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); + ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); + + std::vector domains(num_operator_sets, nullptr); + std::vector opset_versions(num_operator_sets); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), + num_operator_sets)); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (size_t i = 0; i < num_operator_sets; ++i) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(domains[i]); + operator_set->set_version(opset_versions[i]); + } + + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + + return Ort::Status{nullptr}; +} + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims) { + const OrtApi& ort_api = Ort::GetApi(); + + const OrtTypeInfo* ort_type_info = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); + + ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + + const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); + + size_t num_dims = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + + std::vector ort_dims(num_dims, 0); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + + elem_type = ort_elem_type; + dims = std::move(ort_dims); + + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), + ort_dim_syms.size())); + + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + 
symbolic_dims.push_back(sym_dim); + } + } + + return Ort::Status{nullptr}; +} + +// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, + onnx::ValueInfoProto& value_info_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + std::vector ort_dims; + std::vector ort_dim_syms; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, + ort_elem_type, ort_dims, ort_dim_syms)); + + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + value_info_proto.set_name(value_name); + + onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + type_proto_tensor->set_elem_type(ort_elem_type); + + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; + + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. + if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } + } + } + + return Ort::Status{nullptr}; +} + +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + const char* attr_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); + attr_proto.set_name(attr_name); + + size_t total_attr_bytes = 0; + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); + + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + + int64_t i_val = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); + attr_proto.set_i(i_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector i_vals(total_attr_bytes / sizeof(int64_t)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* ints = attr_proto.mutable_ints(); + for (int64_t val : i_vals) { + ints->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + + float f_val = 0.0f; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
+ Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector f_vals(total_attr_bytes / sizeof(float)); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* floats = attr_proto.mutable_floats(); + for (float val : f_vals) { + floats->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::string* str = attr_proto.mutable_s(); + + str->resize(total_attr_bytes, '\0'); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, + &total_attr_bytes)); + + str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector chars(total_attr_bytes, '\0'); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* strs = attr_proto.mutable_strings(); + + // Strings are all in a single buffer, each separated with a '\0'. + // Extract each string and add it to the STRINGS attribute array. + char* at = chars.data(); + char* end = at + chars.size(); + + while (at < end) { + char* str_begin = at; + + while (*at && at < end) { + at++; + } + + strs->Add()->assign(str_begin, at - str_begin); + if (at < end) { + assert(*at == '\0'); + at++; // Skip '\0' to get to the beginning of the next string. + } + } + + break; + } + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + return Ort::Status{nullptr}; +} + +} // namespace OrtEpUtils +#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL From 081de367627dbe523d73cd09ff7c18055e282b55 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 9 Jul 2025 22:07:43 -0700 Subject: [PATCH 19/60] update EP context model helper --- .../tensorrt/onnx_ctx_model_helper.cc | 303 ++++--------- .../tensorrt/onnx_ctx_model_helper.h | 78 +--- .../tensorrt/tensorrt_execution_provider.cc | 413 +++--------------- .../tensorrt/tensorrt_execution_provider.h | 3 +- 4 files changed, 157 insertions(+), 640 deletions(-) diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index 1b29f626..0a479ebc 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -1,258 +1,111 @@ -#include +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
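+
+// Helpers for reading and creating "EPContext" nodes, which let the TensorRT EP embed or reference a
+// precompiled engine instead of rebuilding it from the ONNX graph.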
+ #include #include +#include + +#include "tensorrt_execution_provider_utils.h" #include "onnx_ctx_model_helper.h" -#include "tensorrt_execution_provider.h" -#include "path_string.h" -namespace onnxruntime { - -bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) { - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); - int maxNodeIndex = 0; - graph_api->OrtGraph_MaxNodeIndex(graph_viewer, &maxNodeIndex); - for (int i = 0; i < maxNodeIndex; ++i) { - const OrtNode* node = nullptr; - graph_api->OrtGraph_GetOrtNode(graph_viewer, i, &node); - if (node == nullptr) { - continue; - } - const char* opType = nullptr; - graph_api->OrtNode_GetOpType(node, &opType); - if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) { - return true; - } - } - return false; -} +extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); /* - * Return the directory where the ep context model locates + * Check whether the graph has the EP context node. + * The node can contain the precompiled engine info for TRT EP to directly load the engine. + * + * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc */ -std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) { - if (ep_context_file_path.empty()) { - return std::filesystem::path(); - } - std::filesystem::path ctx_path(ep_context_file_path); - if (std::filesystem::is_directory(ep_context_file_path)) { - return ctx_path; - } else { - return ctx_path.parent_path(); - } -} +bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api) { + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); -std::string GetCtxModelPath(const std::string& ep_context_file_path, - const std::string& original_model_path) { - std::string ctx_model_path; + std::vector nodes(num_nodes); - if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) { - ctx_model_path = ep_context_file_path; - } else { - std::filesystem::path model_path = original_model_path; - std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name - std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx"; + for (size_t i = 0; i < num_nodes; ++i) { + auto node = nodes[i]; - if (std::filesystem::is_directory(ep_context_file_path)) { - std::filesystem::path model_directory = ep_context_file_path; - ctx_model_path = model_directory.append(ctx_model_name).string(); - } else { - ctx_model_path = ctx_model_name; + const char* op_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type)); + if (node != nullptr && op_type == "EPContext") { + return true; } } - return ctx_model_path; -} - -bool IsAbsolutePath(const std::string& path_string) { -#ifdef _WIN32 - onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); - auto path = std::filesystem::path(ort_path_string.c_str()); - return path.is_absolute(); -#else - if (!path_string.empty() && path_string[0] == '/') { - return true; - } return false; -#endif -} - -// Like "../file_path" -bool IsRelativePathToParentPath(const std::string& path_string) { -#ifdef _WIN32 - onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); - auto path = std::filesystem::path(ort_path_string.c_str()); - auto relative_path = path.lexically_normal().make_preferred().wstring(); - if (relative_path.find(L"..", 0) != std::string::npos) { - return true; - } - return 
false; -#else - if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) { - return true; - } - return false; -#endif } /* - * Get the weight-refitted engine cache path from a weight-stripped engine cache path - * - * Weight-stipped engine: - * An engine with weights stripped and its size is smaller than a regualr engine. - * The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine - * - * Weight-refitted engine: - * An engine that its weights have been refitted and it's simply a regular engine. - * The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine + * Create EPContext OrtNode from a fused_node */ -std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { - std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); - std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; - return refitted_engine_cache_path; -} - -bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { - // The weight-stripped engine cache has the naming of xxx.stripped.engine - return engine_cache_path.stem().extension().string() == ".stripped"; -} - -OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphViewer* graph_viewer) { - if (!ValidateEPCtxNode(graph_viewer)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "It's not a valid EP Context node"); - } - const OrtNode* node = nullptr; - graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); - - int64_t embed_mode = -1; - graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); - if (embed_mode) { - // Get engine from byte stream. - const char* context_binary_cstr = nullptr; - size_t size; - graph_api_->OrtNode_GetAttributeStrWithSize(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &size); - std::string context_binary(context_binary_cstr, size); - *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), - static_cast(context_binary.length()))); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; - if (!(*trt_engine_)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); +OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string& compute_capability, + const std::string& onnx_model_path, + OrtNode** ep_context_node) { + + // Helper to collect input or output names from an array of OrtValueInfo instances. + auto collect_input_output_names = [&](gsl::span value_infos, + std::vector& result) -> OrtStatus* { + size_t num_values = value_infos.size(); + std::vector value_names(num_values); + + for (size_t i = 0; i < num_values; ++i) { + const OrtValueInfo* value_info = value_infos[i]; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i])); } - } else { - // Get engine from cache file. - const char* cache_path_cstr = nullptr; - graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &cache_path_cstr); - std::string cache_path(cache_path_cstr); - // For security purpose, in the case of running context model, TRT EP won't allow - // engine cache path to be the relative path like "../file_path" or the absolute path. 
- // It only allows the engine cache to be in the same directory or sub directory of the context model. - if (IsAbsolutePath(cache_path)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path).c_str()); - } - if (IsRelativePathToParentPath(cache_path)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory."); - } + result = std::move(value_names); + return nullptr; + }; - // The engine cache and context model (current model) should be in the same directory - std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); - auto engine_cache_path = ctx_model_dir.append(cache_path); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + const char* fused_node_name = nullptr; - // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled - if (!weight_stripped_engine_refit_) { - weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); - } + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node_, &fused_node_name)); - // If the serialized refitted engine is present, use it directly without refitting the engine again - if (weight_stripped_engine_refit_) { - const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); - if (std::filesystem::exists(refitted_engine_cache_path)) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; - engine_cache_path = refitted_engine_cache_path.string(); - weight_stripped_engine_refit_ = false; - } - } + size_t num_fused_node_inputs = 0; + size_t num_fused_node_outputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node_, &num_fused_node_inputs)); + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node_, &num_fused_node_outputs)); - if (!std::filesystem::exists(engine_cache_path)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("TensorRT EP can't find engine cache: " + engine_cache_path.string() + - ". 
Please make sure engine cache is in the same directory or sub-directory of context model.").c_str()); - } + std::vector fused_node_inputs(num_fused_node_inputs); + std::vector fused_node_outputs(num_fused_node_outputs); + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node_, fused_node_inputs.data(), fused_node_inputs.size())); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node_, fused_node_outputs.data(), fused_node_outputs.size())); - std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - if (!(*trt_engine_)) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()).c_str()); - } -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); + std::vector input_names; + std::vector output_names; - if (weight_stripped_engine_refit_) { - const char* onnx_model_filename_cstr = nullptr; - graph_api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str(), &onnx_model_filename_cstr); - const std::string onnx_model_filename(onnx_model_filename_cstr); - std::string weight_stripped_engine_cache = engine_cache_path.string(); - auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, - onnx_model_folder_path_, - weight_stripped_engine_cache, - true /* path check for security */, - (*trt_engine_).get(), - true /* serialize refitted engine to disk */, - detailed_build_log_); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); - } - } - } - return nullptr; -} + RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names)); + RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names)); -bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_viewer) { - int node_count = 0; - graph_api_->OrtGraph_NumberOfNodes(graph_viewer, &node_count); - assert(node_count == 1); - const OrtNode* node = nullptr; - graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); - const char* opType = nullptr; - graph_api_->OrtNode_GetOpType(node, &opType); - assert(strcmp(opType, EPCONTEXT_OP.c_str()) == 0); + // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. + std::array attributes = {}; + DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); - size_t key_count = 0; - graph_api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str(), &key_count); - // Show the warning if compute capability is not matched - if (key_count > 0) { - const char* model_compute_capability = nullptr; - graph_api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str(), &model_compute_capability); - // Verify if engine was compiled with ampere+ hardware compatibility enabled - if (strcmp(model_compute_capability, "80+") == 0) { -// if (std::stoi(compute_capability_) < 80) { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] However, this GPU doesn't match. 
The compute capability of the GPU: " << compute_capability_; -// } - } else if (strcmp(model_compute_capability, compute_capability_.c_str()) != 0) { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal"; -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability; -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_; + RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[0])); + + std::string engine_data_str = ""; + if (embed_mode) { + if (size > 0) { + engine_data_str.assign(engine_data, size); } + RETURN_IF_ERROR( + ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1])); + } else { + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1])); } - // "embed_mode" attr and "ep_cache_context" attr should be present - graph_api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str(), &key_count); - assert(key_count > 0); - graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str(), &key_count); - assert(key_count > 0); + + ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[2]); + ort_api.CreateOpAttr("onnx_model_filename", std::filesystem::path(onnx_model_path).filename().string().c_str(), 1, + ORT_OP_ATTR_STRING, &attributes[3]); - int64_t embed_mode = -1; - graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); - if (embed_mode == 1) { - // engine binary data -// LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; - } - return true; -} + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, input_names.data(), + input_names.size(), output_names.data(), output_names.size(), + attributes.data(), attributes.size(), ep_context_node)); + + return nullptr; } diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h index 77efc11f..1b1d2891 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -3,64 +3,34 @@ #pragma once +#include "tensorrt_execution_provider_utils.h" +#include "tensorrt_execution_provider.h" +#include "nv_includes.h" + #include #include #include -#include "onnxruntime_c_api.h" -#include "nv_includes.h" - -namespace onnxruntime { +#include -static const std::string EPCONTEXT_OP = "EPContext"; -static const std::string EMBED_MODE = "embed_mode"; -static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; -static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; -static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename"; -static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; -static const std::string EPCONTEXT_WARNING = - "It's suggested to set the ORT graph optimization level to 0 and \ - make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ - for the best model loading time"; - -bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer); -std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); -std::string GetCtxModelPath(const std::string& ep_context_file_path, - const std::string& original_model_path); -bool IsAbsolutePath(const std::string& path_string); -bool 
IsRelativePathToParentPath(const std::string& path_string); - -class TensorRTCacheModelHandler { +class EPContextNodeHelper : public ApiPtrs { public: - TensorRTCacheModelHandler(std::unique_ptr* trt_engine, - nvinfer1::IRuntime* trt_runtime, - std::string ep_context_model_path, - std::string compute_capability, - bool weight_stripped_engine_refit, - std::string onnx_model_folder_path, - bool detailed_build_log) - : trt_engine_(trt_engine), - trt_runtime_(trt_runtime), - ep_context_model_path_(ep_context_model_path), - compute_capability_(compute_capability), - weight_stripped_engine_refit_(weight_stripped_engine_refit), - onnx_model_folder_path_(onnx_model_folder_path), - detailed_build_log_(detailed_build_log) { - api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); - graph_api_ = api_->GetGraphApi(ORT_API_VERSION); - } - bool ValidateEPCtxNode(const OrtGraphViewer* graph_viewer); - - OrtStatusPtr GetEpContextFromGraph(const OrtGraphViewer* graph_viewer); + EPContextNodeHelper(TensorrtExecutionProvider& ep, + const OrtGraph* graph, + const OrtNode* fused_node) + : ApiPtrs{static_cast(ep)}, graph_(graph), fused_node_(fused_node) {} + + static bool GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api); + + OrtStatus* CreateEPContextNode(const std::string& engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string& compute_capability, + const std::string& onnx_model_path, + OrtNode** ep_context_node + ); private: - std::unique_ptr* trt_engine_; - nvinfer1::IRuntime* trt_runtime_; - std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory - std::string compute_capability_; - bool weight_stripped_engine_refit_; - std::string onnx_model_folder_path_; - bool detailed_build_log_; - const OrtApi* api_; - const OrtGraphApi* graph_api_; -}; // TRTCacheModelHandler -} + const OrtGraph* graph_ = nullptr; + const OrtNode* fused_node_ = nullptr; +}; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 8bdd2a7f..fb5a0bc9 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -32,10 +32,6 @@ #define LIBFUNC(lib, fn) dlsym((lib), (fn)) #endif -const OrtApi* g_ort_api = nullptr; -const OrtEpApi* g_ep_api = nullptr; -const OrtModelEditorApi* g_model_editor_api = nullptr; - void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } @@ -719,267 +715,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect const OrtGraph* graph, bool* early_termination) const { // Return if iterations are exceeding predefined number SubGraphCollection_t nodes_list_output; - if (iterations > max_iterations) { - *early_termination = true; - return nodes_list_output; - } - - // Get parent graph output names - std::unordered_set graph_output_names; - for (const auto* output_arg : graph.GetOutputs()) { - graph_output_names.insert(output_arg->Name()); - } - - iterations++; - const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); - for (const auto& group : nodes_vector_input) { - // Construct subgraph - if (!group.first.empty()) { - if (group.second) { - nodes_list_output.push_back(group); - } else { - auto model_build = graph.CreateModel(*GetLogger()); - auto& graph_build = model_build->MainGraph(); - bool 
has_control_flow_op = false; - - // Add node and node args - // If node output is also parent graph output, the output will be added to the - // subgraph's output list - std::vector subgraph_output_names; - for (const auto& index : group.first) { - // Initializers that refer to a memory location in OrtValue - // can not be handled by TRT (unlike those that are on disk). - // This prevents us from sharing the data and we have to make a copy here. - constexpr const bool load_initializers_inline_true = true; - const auto& node = graph.GetNode(node_index[index]); - std::vector inputs, outputs; - for (auto input : node->InputDefs()) { - auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); - inputs.push_back(&n_input); - graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), - load_initializers_inline_true); - } - - for (auto input : node->ImplicitInputDefs()) { - graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), - load_initializers_inline_true); - } - for (auto output : node->OutputDefs()) { - auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); - outputs.push_back(&n_output); - const auto name = output->Name(); - if (graph_output_names.find(name) != graph_output_names.end()) { - subgraph_output_names.push_back(name); - } - } - - if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { - has_control_flow_op = true; - } - - // If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node - // attributes are not in sync because of graph optimization. Therefore, we need to force GraphProto attributes - // to be updated in order to get the valid GraphProto. - if (node->GetAttributes().size() > 0) { - auto node_proto = ONNX_NAMESPACE::NodeProto::Create(); - // we need to update any GraphProto attributes for subgraphs so that any changes made by things - // such as the optimizers are captured. otherwise we can end up saving an invalid graph. - node->ToProto(*node_proto, /* update_subgraphs */ true); - const int num_attributes = node_proto->attribute_size(); - auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); - node_attributes->reserve(num_attributes); - - for (int i = 0; i < num_attributes; ++i) { - auto& attr = node_proto->attribute(i); - node_attributes->emplace(attr.name(), attr); - } - - // The GraphProto attributes are the updated ones. - graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, - node_attributes.get(), node->Domain()); - } else { - // The GraphProto attributes are the original ones. - graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, - &node->GetAttributes(), node->Domain()); - } - } - - // Only if the newly built graph has control flow op as well as it has parent node, - // it needs to handle outer scope values before calling graph.Resolve(). 
- if (has_control_flow_op && graph.ParentNode()) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); - BuildSubGraphContext(graph_build); - SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); - SetAllGraphInputs(graph_build); - } - - ORT_ENFORCE(graph_build.Resolve().IsOK()); - - // Add parent graph output to the subgraph - int i = 0; - std::vector subgraph_outputs; - subgraph_outputs.resize(subgraph_output_names.size()); - for (auto& name : subgraph_output_names) { - auto output_arg = graph.GetNodeArg(name); - auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); - subgraph_outputs[i] = &subgraph_output_arg; - ++i; - } - auto& graph_build_outputs = graph_build.GetOutputs(); - subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); - graph_build.SetOutputs(graph_build_outputs); - ORT_ENFORCE(graph_build.Resolve().IsOK()); - - // Check if input tensors have shapes - if (iterations > 1) { - auto graph_inputs = graph_build.GetInputs(); - for (auto input_arg : graph_inputs) { - bool has_dim_value_or_param = true; - auto input_shape = input_arg->Shape(); - if (input_shape != nullptr) { - auto dim_size = input_shape->dim_size(); - for (int i = 0; i < dim_size; ++i) { - auto& dim = input_shape->dim(i); - if (!dim.has_dim_value() && !dim.has_dim_param()) { - has_dim_value_or_param = false; - break; - } - } - } - - if (input_shape == nullptr || !has_dim_value_or_param) { - ORT_THROW_IF_ERROR( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "TensorRT input: " + input_arg->Name() + " has no shape specified. " + - "Please run shape inference on the onnx model first. Details can be found in " + - "https://onnxruntime.ai/docs/execution-providers/" - "TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs")); - } - } - } - - - /* - //Save initializers to external file - std::string ext_ini_file_path = "model_serialized.bin"; - std::filesystem::remove(ext_ini_file_path); - std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); - auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path]( - const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external, - std::string& location, int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, - // node consumers, etc. - (void)value_info; - - if (bytes <= 127) { - is_external = false; // Keep small initializers stored inside the TensorProto. - return Ort::Status{nullptr}; - } - - offset = ext_ini_ofs.tellp(); - location = ext_ini_file_path; - ext_ini_ofs.write(static_cast(data), bytes); - ext_ini_ofs.flush(); - is_external = true; // True if is external initializer. 
- - return Ort::Status{nullptr}; - }; - */ - - // Construct ModelProto from OrtGraph - ONNX_NAMESPACE::ModelProto model_proto; - - // add back handle_initializer_data to save initializer to external file - OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); - - std::string string_buf; - model_proto.SerializeToString(&string_buf); - - if (dump_subgraphs_) { - // Dump TensorRT subgraph for debugging - std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", - std::ios::out | std::ios::trunc | std::ios::binary); - model_proto.SerializeToOstream(&dump); - } - - // Get supported node list recursively - SubGraphCollection_t parser_nodes_list; - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); - auto trt_builder = GetBuilder(trt_logger); - auto network_flags = 0; -#if NV_TENSORRT_MAJOR > 8 - network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) - ? 0 - : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); -#else - network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); -#endif - - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); - auto trt_parser = - tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - -#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 - auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); - - // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined - // behavior. - auto num_subgraphs = trt_parser->getNbSubgraphs(); - parser_nodes_list.reserve(num_subgraphs); - - for (int64_t i = 0; i < num_subgraphs; ++i) { - int64_t subgraph_len = 0; - int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len); - parser_nodes_list.emplace_back(); - parser_nodes_list.back().first.reserve(subgraph_len); - for (int64_t j = 0; j < subgraph_len; ++j) { - parser_nodes_list.back().first.push_back(nodes[j]); - } - parser_nodes_list.back().second = is_model_supported ? true : false; - } -#else - trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); -#endif - SubGraphCollection_t next_nodes_list; - const std::vector& subgraph_node_index = - graph_viewer->GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); - next_nodes_list = - GetSupportedList(parser_nodes_list, iterations, max_iterations, *graph_viewer, early_termination); - for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { - for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { - /* - * Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT - * TRT. - * - * TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model - * proto and to string buffer) to onnx-tensorrt parser. The node index in the list returning from - * onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a - * node index mapping table here. - * - * The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node - * order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), however, once the graph is - * converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls - * model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort. 
- * - * The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table: - * subgraph_node_index[node index from onnx-tensorrt parser] = index in group.first - * - * In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result - * not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. With the change of using ORT's - * priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of - * 0, 1, ... n-1 for most of the cases, therefore subgraph_node_index as a mapping table is not needed - * anymore. - * - * TODO: Remove the subgraph_node_index - */ - next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; - } - nodes_list_output.push_back(next_nodes_list[i]); - } - } - } - } return nodes_list_output; } @@ -988,7 +724,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - /* out */OrtNodeComputeInfo** node_compute_info) { + /* out */OrtNodeComputeInfo** node_compute_info, + /* out */OrtNode** ep_context_node) { TensorrtExecutionProvider* ep = static_cast(this_ptr); /* @@ -1420,12 +1157,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this weight_stripped_engine_refit_ = true; } - /* - // Generate file name for dumping ep context model - if (dump_ep_context_model_ && ctx_model_path_.empty()) { - ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); - } - */ + std::unique_ptr serialized_engine = nullptr; if (!has_dynamic_shape) { std::string timing_cache_path = ""; @@ -1525,8 +1257,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (detailed_build_log_) { engine_build_start = std::chrono::steady_clock::now(); } - std::unique_ptr serialized_engine{ - trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; + + serialized_engine = + std::make_unique(trt_builder->buildSerializedNetwork(*trt_network, *trt_config)); + if (serialized_engine == nullptr) { std::string err_msg = "TensorRT EP failed to create engine from network for fused node: " + fused_node_name; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); @@ -1584,27 +1318,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; } } - // dump EP context node model - if (dump_ep_context_model_) { - // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir) - .append(cache_file_name.string()) - .string(); - } - std::string compute_capability_hw_compat = compute_capability_; - if (engine_cache_enable_ && engine_hw_compatible_) { - compute_capability_hw_compat = "80+"; - } - /* - std::unique_ptr model_proto{ - CreateCtxModel(graph_body_viewer, ep_cache_context_attr_, - reinterpret_cast(serialized_engine->data()), serialized_engine->size(), - ep_context_embed_mode_, compute_capability_hw_compat, model_path_, GetLogger())}; - DumpCtxModel(model_proto.get(), ctx_model_path_); - */ - } } } @@ -1656,7 +1369,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // Create output to 
index and type maps // TRT network output -> ORT fused_node output index - const auto& graph_output = model_proto->graph().output(); + const auto& graph_output = model_proto.graph().output(); for (int i = 0; i < num_outputs; ++i) { const std::string& output_name = trt_network->getOutput(i)->getName(); const auto& iter = output_map.find(output_name); @@ -1703,6 +1416,22 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } */ + std::unique_ptr ep_ctx_node_helper = std::make_unique(graph, fused_node); + if (dump_ep_context_model_) { + std::string compute_capability_hw_compat = compute_capability_; + if (engine_cache_enable_ && engine_hw_compatible_) { + compute_capability_hw_compat = "80+"; + } + + ep_ctx_node_helper->CreateEPContextNode(engine_cache_path, + reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + ep_context_embed_mode_, + compute_capability_hw_compat, + model_path_, + ep_context_node); + } + std::unique_ptr compute_state = std::make_unique(); // translate tactic sources string to nvinfer1::TacticSources @@ -1773,33 +1502,18 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; - - // Get ModelPath - /* - const std::filesystem::path* model_path = nullptr; - graph_api_->OrtGraph_GetModelPath(graph, reinterpret_cast(&model_path)); - const auto& path_string = model_path->string(); -#ifdef _WIN32 - strncpy_s(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); -#else - strncpy(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); -#endif - p->model_path_[sizeof(p->model_path_) - 1] = '\0'; - - int node_count = 0; - graph_api_->OrtGraph_NumberOfNodes(graph, &node_count); - if (node_count == 1 && GraphHasCtxNode(graph)) { - SubGraph_t supported_node_vector = {{0}, true}; - std::unique_ptr sub_graph = p->GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); - *cnt = 1; - *indexed_sub_graph = new OrtIndexedSubGraph*[1]; - (*indexed_sub_graph)[0] = sub_graph.release(); - return; - } - */ - // Generate unique kernel name for TRT graph - // HashValue model_hash = TRTGenerateId(ort_api, graph, std::to_string(trt_version_), std::to_string(cuda_version_)); + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + // Get all the nodes from the graph + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; + bool new_subgraph = true; + + std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; // Get pre-excluded op list from provider options auto get_exclude_ops_set = [&](std::string node_list_to_exclude) -> std::set { @@ -1814,25 +1528,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this return set; }; - //auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); + // auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); auto exclude_ops_set = get_exclude_ops_set(""); - // Get all Ort nodes - OrtArrayOfConstObjects* nodes_container = nullptr; - DeferOrtRelease release_nodes(&nodes_container, - ep->ort_api.ReleaseArrayOfConstObjects); - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, &nodes_container)); - - 
gsl::span nodes{}; - GetSpanFromArrayOfConstObjects(nodes_container, nodes); - // using ORT's priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of 0, 1, ... n-1 - // RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, /*order*/ 1, nodes.data(), nodes.size())); - - SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; - bool new_subgraph = true; - - std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; - /* Iterate all the nodes and exclude the node if: * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. * 2. Its op type is in the exclusion list. @@ -1857,26 +1555,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this auto supported_control_flow_op = [&](const OrtNode* node) { OrtStatus* status = nullptr; size_t num_subgraphs = 0; - OrtArrayOfConstObjects* node_subgraphs_container = nullptr; - DeferOrtRelease release_node_subgraphs(&node_subgraphs_container, - ep->ort_api.ReleaseArrayOfConstObjects); - - RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.Node_GetSubgraphs(node, &node_subgraphs_container), ep->ort_api); - RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(node_subgraphs_container, &num_subgraphs), ep->ort_api); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs), ort_api); + + std::vector node_subgraphs(num_subgraphs); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr), ort_api); + + // Iterate the node's subgraphs for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* subgraph = nullptr; - RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(node_subgraphs_container, subgraph_idx, - reinterpret_cast(&subgraph)), - ep->ort_api); + const OrtGraph* subgraph = node_subgraphs[subgraph_idx]; // Get number of subgraph's nodes size_t num_subgraph_nodes = 0; - OrtArrayOfConstObjects* subgraph_nodes_container = nullptr; - DeferOrtRelease release_subgraph_nodes(&subgraph_nodes_container, - ep->ort_api.ReleaseArrayOfConstObjects); - RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.Graph_GetNodes(subgraph, &subgraph_nodes_container), ep->ort_api); - RETURN_FALSE_AND_PRINT_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(subgraph_nodes_container, &num_subgraph_nodes), ep->ort_api); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes), ort_api); // TRT EP should consider the empty subgraph is fully supported by TRT. 
if (num_subgraph_nodes == 0) { @@ -1925,6 +1616,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this supported_nodes_vector.clear(); } + // Temporarily make all nodes supported + supported_nodes_vector = parser_nodes_vector; + // Remove subgraphs if its size is less than the predefined minimal size for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { const size_t subgraph_size = it->first.size(); @@ -2039,11 +1733,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this int number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { - std::vector supported_nodes; + std::vector supported_nodes(group.first.size()); for (const auto& index : group.first) { - const OrtNode* supported_node = nullptr; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(nodes_container, index, - reinterpret_cast(&supported_node))); + const OrtNode* supported_node = nodes[index]; + supported_nodes.push_back(supported_node); } @@ -2122,14 +1815,14 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ } OrtStatus* status; - if (GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { + if (EPContextNodeHelper::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, input_map, output_map, &node_compute_infos_result[fused_node_idx])); } else { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map, - output_map, &node_compute_infos_result[fused_node_idx]), - &ep_context_nodes_result[fused_node_idx]); + output_map, &node_compute_infos_result[fused_node_idx], + &ep_context_nodes_result[fused_node_idx])); } } diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index a595bcf8..9de811dd 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -253,7 +253,8 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_info); + OrtNodeComputeInfo** node_compute_info, + OrtNode** ep_context_node); OrtStatus* RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, bool path_check, From 75240a410974af01e631ae5b29252dd2d6722f93 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 10 Jul 2025 15:02:26 -0700 Subject: [PATCH 20/60] Convert onnxruntime::Status to OrtStatus --- .../tensorrt_execution_provider_info.cc | 16 ++-- .../tensorrt_execution_provider_info.h | 7 +- .../tensorrt_execution_provider_utils.h | 82 +++++++++++++++++-- .../tensorrt/utils/cuda/cuda_call.h | 21 ++--- .../tensorrt/utils/cuda/cuda_common.h | 6 +- .../tensorrt/utils/make_string.h | 4 - .../tensorrt/utils/parse_string.h | 12 +-- .../tensorrt/utils/provider_options_utils.h | 44 +++++----- 8 files changed, 118 insertions(+), 74 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc index c7062af5..c9154bbf 100644 --- 
a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc @@ -59,29 +59,29 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions TensorrtExecutionProviderInfo info{}; void* user_compute_stream = nullptr; - ORT_THROW_IF_ERROR( + THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( tensorrt::provider_option_names::kDeviceId, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + [&info](const std::string& value_str) -> OrtStatus* { + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); int num_devices{}; CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&num_devices)); - ORT_RETURN_IF_NOT( + RETURN_IF_NOT( 0 <= info.device_id && info.device_id < num_devices, "Invalid device ID: ", info.device_id, ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); - return Status::OK(); + return nullptr; }) .AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations) .AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) .AddValueParser( tensorrt::provider_option_names::kUserComputeStream, - [&user_compute_stream](const std::string& value_str) -> Status { + [&user_compute_stream](const std::string& value_str) -> OrtStatus* { size_t address; - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); user_compute_stream = reinterpret_cast(address); - return Status::OK(); + return nullptr; }) .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size) .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h index 16304db1..3dec464d 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h @@ -3,9 +3,10 @@ #pragma once -#include +#include "tensorrt_execution_provider_utils.h" #include "provider_options.h" -#include "common.h" + +#include #define TRT_DEFAULT_OPTIMIZER_LEVEL 3 @@ -54,7 +55,7 @@ struct TensorrtExecutionProviderInfo { std::string engine_cache_prefix{""}; bool engine_hw_compatible{false}; - static TensorrtExecutionProviderInfo FromProviderOptions(const onnxruntime::ProviderOptions& options); + static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); // static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); // static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); // static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy); diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index e03111e0..0bdac3cb 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -6,6 +6,7 @@ #include "flatbuffers/idl.h" #include "ort_trt_int8_cal_table.fbs.h" +#include "make_string.h" // #include 
"core/providers/cuda/cuda_pch.h" // #include "core/common/path_string.h" // #include "core/framework/murmurhash3.h" @@ -21,6 +22,26 @@ #include #include +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; + const OrtModelEditorApi& model_editor_api; +}; + +const OrtApi* g_ort_api = nullptr; +const OrtEpApi* g_ep_api = nullptr; +const OrtModelEditorApi* g_model_editor_api = nullptr; + +#define ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(MakeString(__VA_ARGS__)); \ + } \ + } while (false) + +#define THROW(...) \ + throw std::runtime_error(MakeString(__VA_ARGS__)); + #define RETURN_IF_ERROR(fn) \ do { \ OrtStatus* _status = (fn); \ @@ -29,17 +50,60 @@ } \ } while (0) -#define RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ - } \ +/* +template +std::string ComposeString(Args&&... args) { + std::ostringstream oss; + (oss << ... << args); + return oss.str(); +}; +*/ + +#define RETURN_IF(cond, ...) \ + do { \ + if ((cond)) { \ + return Ort::GetApi().CreateStatus(ORT_EP_FAIL, MakeString(__VA_ARGS__).c_str()); \ + } \ + } while (0) + +#define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__) + +#define MAKE_STATUS(error_code, msg) \ + Ort::GetApi().CreateStatus(error_code, (msg)); + +#define THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if (_status != nullptr) { \ + std::ostringstream oss; \ + oss << Ort::GetApi().GetErrorMessage(_status); \ + Ort::GetApi().ReleaseStatus(_status); \ + throw std::runtime_error(oss.str()); \ + } \ } while (0) -struct ApiPtrs { - const OrtApi& ort_api; - const OrtEpApi& ep_api; - const OrtModelEditorApi& model_editor_api; +// Helper to release Ort one or more objects obtained from the public C API at the end of their scope. +template +struct DeferOrtRelease { + DeferOrtRelease(T** object_ptr, std::function release_func) + : objects_(object_ptr), count_(1), release_func_(release_func) {} + + DeferOrtRelease(T** objects, size_t count, std::function release_func) + : objects_(objects), count_(count), release_func_(release_func) {} + + ~DeferOrtRelease() { + if (objects_ != nullptr && count_ > 0) { + for (size_t i = 0; i < count_; ++i) { + if (objects_[i] != nullptr) { + release_func_(objects_[i]); + objects_[i] = nullptr; + } + } + } + } + T** objects_ = nullptr; + size_t count_ = 0; + std::function release_func_ = nullptr; }; namespace fs = std::filesystem; diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h index 81d5975c..ada25ab7 100644 --- a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h +++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h @@ -2,9 +2,6 @@ // Licensed under the MIT License. 
#pragma once -#include "../common.h" - -namespace onnxruntime { // ----------------------------------------------------------------------- // Error handling @@ -12,11 +9,11 @@ namespace onnxruntime { // template const char* CudaErrString(ERRTYPE) { - ORT_NOT_IMPLEMENTED(); + THROW(); } template -std::conditional_t CudaCall( +std::conditional_t CudaCall( ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { if (retCode != successCode) { try { @@ -41,22 +38,20 @@ std::conditional_t CudaCall( file, line, exprString, msg); if constexpr (THRW) { // throw an exception with the error info - ORT_THROW(str); + THROW(str); } else { - //LOGS_DEFAULT(ERROR) << str; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + return MAKE_STATUS(ORT_EP_FAIL, str); } } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error if constexpr (THRW) { - ORT_THROW(e.what()); + THROW(e.what()); } else { - //LOGS_DEFAULT(ERROR) << e.what(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); + return MAKE_STATUS(ORT_EP_FAIL, e.what()); } } } if constexpr (!THRW) { - return Status::OK(); + return nullptr; } } @@ -65,5 +60,3 @@ std::conditional_t CudaCall( //ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); #define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h index b00ef3f9..38f9d147 100644 --- a/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h +++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h @@ -5,10 +5,8 @@ #include "cuda_call.h" -namespace onnxruntime { namespace cuda { -#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) +#define CUDA_RETURN_IF_ERROR(expr) RETURN_IF_ERROR(CUDA_CALL(expr)) -} // namespace cuda -} // namespace onnxruntime +} // namespace cuda \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/utils/make_string.h b/plugin_execution_providers/tensorrt/utils/make_string.h index 826898de..a21be30b 100644 --- a/plugin_execution_providers/tensorrt/utils/make_string.h +++ b/plugin_execution_providers/tensorrt/utils/make_string.h @@ -21,8 +21,6 @@ #include #include -namespace onnxruntime { - namespace detail { inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept { @@ -122,5 +120,3 @@ inline std::string MakeStringWithClassicLocale(const std::string& str) { inline std::string MakeStringWithClassicLocale(const char* cstr) { return cstr; } - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/parse_string.h b/plugin_execution_providers/tensorrt/utils/parse_string.h index ce404607..b10d0dfc 100644 --- a/plugin_execution_providers/tensorrt/utils/parse_string.h +++ b/plugin_execution_providers/tensorrt/utils/parse_string.h @@ -8,10 +8,6 @@ #include #include -#include "common.h" - -namespace onnxruntime { - /** * Tries to parse a value from an entire string. */ @@ -67,9 +63,9 @@ inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) { * Parses a value from an entire string. 
*/ template -Status ParseStringWithClassicLocale(std::string_view s, T& value) { - ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\""); - return Status::OK(); +OrtStatus* ParseStringWithClassicLocale(std::string_view s, T& value) { + RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\""); + return nullptr; } /** @@ -81,5 +77,3 @@ T ParseStringWithClassicLocale(std::string_view s) { ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value)); return value; } - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h index c7380b36..f190f20c 100644 --- a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h +++ b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h @@ -9,12 +9,11 @@ #include #include -#include "common.h" +#include "onnxruntime_c_api.h" +#include "../tensorrt_execution_provider_utils.h" #include "parse_string.h" #include "provider_options.h" -namespace onnxruntime { - template using EnumNameMapping = std::vector>; @@ -22,23 +21,23 @@ using EnumNameMapping = std::vector>; * Given a mapping and an enumeration value, gets the corresponding name. */ template -Status EnumToName(const EnumNameMapping& mapping, TEnum value, std::string& name) { +OrtStatus* EnumToName(const EnumNameMapping& mapping, TEnum value, std::string& name) { const auto it = std::find_if( mapping.begin(), mapping.end(), [&value](const std::pair& entry) { return entry.first == value; }); - ORT_RETURN_IF( + RETURN_IF( it == mapping.end(), "Failed to map enum value to name: ", static_cast::type>(value)); name = it->second; - return Status::OK(); + return nullptr; } template std::string EnumToName(const EnumNameMapping& mapping, TEnum value) { std::string name; - ORT_THROW_IF_ERROR(EnumToName(mapping, value, name)); + THROW_IF_ERROR(EnumToName(mapping, value, name)); return name; } @@ -46,24 +45,24 @@ std::string EnumToName(const EnumNameMapping& mapping, TEnum value) { * Given a mapping and a name, gets the corresponding enumeration value. 
*/ template -Status NameToEnum( +OrtStatus* NameToEnum( const EnumNameMapping& mapping, const std::string& name, TEnum& value) { const auto it = std::find_if( mapping.begin(), mapping.end(), [&name](const std::pair& entry) { return entry.second == name; }); - ORT_RETURN_IF( + RETURN_IF( it == mapping.end(), "Failed to map enum name to value: ", name); value = it->first; - return Status::OK(); + return nullptr; } template TEnum NameToEnum(const EnumNameMapping& mapping, const std::string& name) { TEnum value; - ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value)); + THROW_IF_ERROR(NameToEnum(mapping, name, value)); return value; } @@ -83,7 +82,7 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddValueParser( const std::string& name, ValueParserType value_parser) { - ORT_ENFORCE( + ENFORCE( value_parsers_.emplace(name, ValueParser{value_parser}).second, "Provider option \"", name, "\" already has a value parser."); return *this; @@ -106,7 +105,7 @@ class ProviderOptionsParser { const std::string& name, ValueType& dest) { return AddValueParser( name, - [&dest](const std::string& value_str) -> Status { + [&dest](const std::string& value_str) -> OrtStatus* { return ParseStringWithClassicLocale(value_str, dest); }); } @@ -130,7 +129,7 @@ class ProviderOptionsParser { const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { return AddValueParser( name, - [&mapping, &dest](const std::string& value_str) -> Status { + [&mapping, &dest](const std::string& value_str) -> OrtStatus* { return NameToEnum(mapping, value_str, dest); }); } @@ -138,27 +137,26 @@ class ProviderOptionsParser { /** * Parses the given provider options. */ - Status Parse(const ProviderOptions& options) const { + OrtStatus* Parse(const ProviderOptions& options) const { for (const auto& option : options) { const auto& name = option.first; const auto& value_str = option.second; const auto value_parser_it = value_parsers_.find(name); - ORT_RETURN_IF( + RETURN_IF( value_parser_it == value_parsers_.end(), "Unknown provider option: \"", name, "\"."); const auto parse_status = value_parser_it->second(value_str); - ORT_RETURN_IF_NOT( - parse_status.IsOK(), - "Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); + RETURN_IF_NOT( + (parse_status == nullptr), + "Failed to parse provider option \"", name, "\": "); + //"Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); } - return Status::OK(); + return nullptr; } private: - using ValueParser = std::function; + using ValueParser = std::function; std::unordered_map value_parsers_; }; - -} // namespace onnxruntime From f73420f686071688c59c8e0125821449122dfa08 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 10 Jul 2025 15:43:50 -0700 Subject: [PATCH 21/60] remove unused files --- .../tensorrt/ep_abi_utils.cc | 12 --- .../tensorrt/ep_abi_utils.h | 77 ------------------- 2 files changed, 89 deletions(-) delete mode 100644 plugin_execution_providers/tensorrt/ep_abi_utils.cc delete mode 100644 plugin_execution_providers/tensorrt/ep_abi_utils.h diff --git a/plugin_execution_providers/tensorrt/ep_abi_utils.cc b/plugin_execution_providers/tensorrt/ep_abi_utils.cc deleted file mode 100644 index bc0b6eeb..00000000 --- a/plugin_execution_providers/tensorrt/ep_abi_utils.cc +++ /dev/null @@ -1,12 +0,0 @@ -#define ORT_API_MANUAL_INIT -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - -#include -#include -#include -#include -#include -#include -#include - diff --git 
a/plugin_execution_providers/tensorrt/ep_abi_utils.h b/plugin_execution_providers/tensorrt/ep_abi_utils.h deleted file mode 100644 index 308a49da..00000000 --- a/plugin_execution_providers/tensorrt/ep_abi_utils.h +++ /dev/null @@ -1,77 +0,0 @@ -#pragma once - -#include -#include - -#include "onnxruntime_c_api.h" - -#define RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* status = (fn); \ - if (status != nullptr) { \ - return status; \ - } \ - } while (0) - -#define RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ - } \ - } while (0) - -#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn, ort_api) \ - do { \ - OrtStatus* status = (fn); \ - if (status != nullptr) { \ - std::cerr << (ort_api).GetErrorMessage(status) << std::endl ; \ - return false; \ - } \ - } while (0) - -struct OrtArrayOfConstObjects { - OrtArrayOfConstObjects() = default; - explicit OrtArrayOfConstObjects(OrtTypeTag object_type) : object_type(object_type) {} - OrtArrayOfConstObjects(OrtTypeTag object_type, size_t size, const void* initial_val = nullptr) - : object_type(object_type), storage(size, initial_val) {} - - OrtTypeTag object_type = OrtTypeTag::ORT_TYPE_TAG_Void; - std::vector storage; -}; - -// Convert an OrtArrayOfConstObjects into a span of Ort___ pointers. -template -static void GetSpanFromArrayOfConstObjects(const OrtArrayOfConstObjects* ort_array, - /*out*/ gsl::span& span) { - const OrtApi& ort_api = Ort::GetApi(); - - size_t size = 0; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetSize(ort_array, &size)); - - const void* const* raw_data = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetData(ort_array, &raw_data)); - - auto data = reinterpret_cast(raw_data); - span = gsl::span(data, size); -} - -// Helper to release a C API Ort object at the end of its scope. -// Useful when not using the public C++ API. -// Example: -// { -// OrtTensorTypeAndShapeInfo* info = nullptr; -// DeferOrtRelease defer_release(&info, c_api.ReleaseTensorTypeAndShapeInfo); -// ... -// } /* Release is called at end of scope*/ -template -struct DeferOrtRelease { - DeferOrtRelease(T** obj_ptr, std::function release_func) : obj_ptr_(obj_ptr), release_func_(release_func) {} - ~DeferOrtRelease() { - if (obj_ptr_ != nullptr && *obj_ptr_ != nullptr) { - release_func_(*obj_ptr_); - *obj_ptr_ = nullptr; - } - } - T** obj_ptr_ = nullptr; - std::function release_func_ = nullptr; -}; \ No newline at end of file From 938a3fe60d9bffbb2e5574ce108d7875ffb9f8da Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 10 Jul 2025 15:44:50 -0700 Subject: [PATCH 22/60] use GetSessionOptionsConfigEntries to get provider options --- .../tensorrt/tensorrt_execution_provider.cc | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index fb5a0bc9..6cc3d3f5 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -12,7 +12,6 @@ #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL #include "ort_graph_to_proto.h" -#include "ep_abi_utils.h" //#include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider.h" #include "cuda_allocator.h" @@ -1968,27 +1967,30 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // the session option configurations with the key prefix "ep..". 
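  // Illustrative example (the value shown is hypothetical, not taken from this patch): a session
  // option entry "ep.<lowercase ep name>.trt_engine_cache_enable" = "1" is collected by the loop
  // below and stored as the provider option "trt_engine_cache_enable" = "1", which is later passed
  // to TensorrtExecutionProviderInfo::FromProviderOptions().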
std::string key_prefix = "ep." + lowercase_ep_name + "."; - /* - // Get provider options as key-value pair strings + // Get all the provider options as session config from sesson ProviderOptions provider_options; - for (const auto& [key, value] : config_options_map) { - if (key.rfind(key_prefix, 0) == 0) { - provider_options[key.substr(key_prefix.size())] = value; + + // Get the provider options from all the config entries in session option + OrtKeyValuePairs* key_value_pairs = nullptr; + ort_api.GetSessionOptionsConfigEntries(&session_options, &key_value_pairs); + + const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + ort_api.GetKeyValuePairs(key_value_pairs, &keys, &values, &num_entries); + + for (size_t i = 0; i < num_entries; ++i) { + const char* key = keys[i]; + + // only gets ep provider options + if (strncmp(key, key_prefix.c_str(), key_prefix.size()) == 0) { + std::string key_str = key; + const char* value = values[i]; + provider_options[key_str.substr(key_prefix.size())] = value; } } - */ - // Get all the provider options as session config from sesson - ProviderOptions provider_options; - int has_session_config_entry = 0; - std::string provider_option = key_prefix + "trt_engine_cache_enable"; - auto status = ort_api.HasSessionConfigEntry(&session_options, provider_option.c_str(), & has_session_config_entry); - if (has_session_config_entry) { - char* value = nullptr; - size_t size = 0; - status = ort_api.GetSessionConfigEntry(&session_options, provider_option.c_str(), value, &size); - provider_options[provider_option.substr(key_prefix.size())] = value; - } + ort_api.ReleaseKeyValuePairs(key_value_pairs); // Provider options to TensorrtExecutionProviderInfo info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options); From 731ed7207ae4199125e0885646aff119c000ba86 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 10 Jul 2025 18:29:05 -0700 Subject: [PATCH 23/60] fix a bunch of compile errors --- .../tensorrt/cuda_allocator.h | 6 +- .../tensorrt/tensorrt_execution_provider.cc | 57 ++++++------------- .../tensorrt/tensorrt_execution_provider.h | 16 ++---- .../tensorrt_execution_provider_utils.h | 9 +++ .../tensorrt/tensorrt_provider_factory.cc | 15 +++-- .../tensorrt/tensorrt_provider_factory.h | 2 + .../tensorrt/utils/ort_graph_to_proto.h | 2 +- 7 files changed, 45 insertions(+), 62 deletions(-) diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h index 37e7f462..4b6a8565 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -13,12 +13,14 @@ constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned"; using DeviceId = int16_t; struct CUDAAllocator : OrtAllocator { - CUDAAllocator(DeviceId device_id, const char* name = CUDA_ALLOCATOR) { + CUDAAllocator(const OrtMemoryInfo* mem_info, const char* name = CUDA_ALLOCATOR) { OrtAllocator::version = ORT_API_VERSION; OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + mem_info_ = mem_info; + device_id_ = device_id; const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); @@ -44,7 +46,7 @@ struct CUDAAllocator : OrtAllocator { void SetDevice(bool throw_when_fail) const; DeviceId device_id_; - OrtMemoryInfo* 
mem_info_ = nullptr; + const OrtMemoryInfo* mem_info_ = nullptr; }; struct CUDAPinnedAllocator : OrtAllocator { diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 6cc3d3f5..d2593a23 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -1156,7 +1156,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this weight_stripped_engine_refit_ = true; } - std::unique_ptr serialized_engine = nullptr; + std::unique_ptr serialized_engine; if (!has_dynamic_shape) { std::string timing_cache_path = ""; @@ -1258,7 +1258,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } serialized_engine = - std::make_unique(trt_builder->buildSerializedNetwork(*trt_network, *trt_config)); + std::unique_ptr(trt_builder->buildSerializedNetwork(*trt_network, *trt_config)); if (serialized_engine == nullptr) { std::string err_msg = "TensorRT EP failed to create engine from network for fused node: " + fused_node_name; @@ -1390,32 +1390,9 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this input_shape_ranges_[fused_node_name] = input_implicit_shape_ranges; profiles_.emplace(fused_node_name, std::move(trt_profiles)); - /* - // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty - // engine. TRT EP will serialize the model at inference time due to engine can be updated and the updated engine - // should be included in the model. However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize - // it here. - if (dump_ep_context_model_ && has_dynamic_shape) { - // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir) - .append(cache_file_name.string()) - .string(); - } - std::string compute_capability_hw_compat = compute_capability_; - if (engine_cache_enable_ && engine_hw_compatible_) { - compute_capability_hw_compat = "80+"; - } - model_proto_.reset(CreateCtxModel(graph_body_viewer, ep_cache_context_attr_, nullptr, 0, ep_context_embed_mode_, - compute_capability_hw_compat, model_path_, GetLogger())); - if (ep_context_embed_mode_ == 0) { - DumpCtxModel(model_proto_.get(), ctx_model_path_); - } - } - */ - std::unique_ptr ep_ctx_node_helper = std::make_unique(graph, fused_node); + // Create EP Context nodes + std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, graph, fused_node); if (dump_ep_context_model_) { std::string compute_capability_hw_compat = compute_capability_; if (engine_cache_enable_ && engine_hw_compatible_) { @@ -1490,6 +1467,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this engine_hw_compatible_, sync_stream_after_enqueue_}; + ep->compute_states_[fused_node_name] = std::move(compute_state); + // Update the OrtNodeComputeInfo associated with the graph. 
auto ep_node_compute_info = std::make_unique(*ep); *node_compute_info = ep_node_compute_info.release(); @@ -1554,10 +1533,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this auto supported_control_flow_op = [&](const OrtNode* node) { OrtStatus* status = nullptr; size_t num_subgraphs = 0; - RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs), ort_api); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs)); std::vector node_subgraphs(num_subgraphs); - RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr), ort_api); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr)); // Iterate the node's subgraphs @@ -1566,7 +1545,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this // Get number of subgraph's nodes size_t num_subgraph_nodes = 0; - RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes), ort_api); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); // TRT EP should consider the empty subgraph is fully supported by TRT. if (num_subgraph_nodes == 0) { @@ -1926,13 +1905,11 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( /// TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, const std::string& name, - const OrtHardwareDevice& device, const OrtSessionOptions& session_options, const OrtLogger& logger) : ApiPtrs{static_cast(factory)}, factory_(factory), name_{name}, - hardware_device_{device}, session_options_{session_options}, logger_{logger} { @@ -2176,7 +2153,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa * Please refer to ParserProfileShapes() for more details) * */ - bool status = true; + // bool status = true; // if (status) { // status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_); // if (!status) { @@ -2266,14 +2243,14 @@ OrtStatus* TRTEpNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, O TensorrtExecutionProvider& ep = node_compute_info->ep; std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); - auto state_it = ep.GetComputeStates().find(fused_node_name); - if (state_it == ep.GetComputeStates().end()) { + auto state_it = ep.compute_states_.find(fused_node_name); + if (state_it == ep.compute_states_.end()) { std::string message = "Unable to TensorRT EP's compute state for fused node with name " + fused_node_name; return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); } - TensorrtComputeState& compute_state = *state_it->second; - *compute_state = &compute_state; + TensorrtComputeState& trt_ep_compute_state = *state_it->second; + *compute_state = &trt_ep_compute_state; return nullptr; } @@ -2335,7 +2312,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* bool context_update = false; std::unordered_set input_names; - std::unordered_map dds_output_allocator_maps = ep.GetDDSOutputAllocators(); + std::unordered_map& dds_output_allocator_maps = ep.GetDDSOutputAllocators(); auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; // Get default OrtMemoryInfo from factory @@ -2911,7 +2888,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* 
compute_state) { (void)this_ptr; - TensorrtComputeState& compute_state = *reinterpret_cast(compute_state); - (void)compute_state; + TensorrtComputeState& trt_ep_compute_state = *reinterpret_cast(compute_state); + (void)trt_ep_compute_state; // Do nothing for here. } diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 9de811dd..6d30f59a 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -13,6 +13,7 @@ #include #include #include +#include #ifdef _WIN32 #define EXPORT_API __declspec(dllexport) @@ -231,16 +232,18 @@ static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; /// struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, const std::string& name, - const OrtHardwareDevice& device, const OrtSessionOptions& session_options, + const OrtSessionOptions& session_options, const OrtLogger& logger); ~TensorrtExecutionProvider(); TensorrtExecutionProviderFactory& factory_; std::string name_; - const OrtHardwareDevice& hardware_device_; const OrtSessionOptions& session_options_; const OrtLogger& logger_; + std::unordered_map> compute_states_; + std::unordered_map> compute_states_for_ep_context_; + SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, const OrtGraph* graph, bool* early_termination) const; @@ -262,12 +265,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log); - std::unordered_map>& GetComputeStates() { return compute_states_; } - - std::unordered_map>& GetComputeStatesForEPContext() { - return compute_states_; - } - void GetAllocator(OrtAllocator** alloc) const { *alloc = alloc_; } void SetAllocator(OrtAllocator* alloc) { alloc_ = alloc; } @@ -415,9 +412,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::unordered_map> profiles_; std::unordered_map dds_output_allocator_maps_; - std::unordered_map> compute_states_; - std::unordered_map> compute_states_for_ep_context; - // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture // cudnnHandle_t external_cudnn_handle_ = nullptr; // cublasHandle_t external_cublas_handle_ = nullptr; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index 0bdac3cb..49c11286 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -82,6 +82,15 @@ std::string ComposeString(Args&&... args) { } \ } while (0) +#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \ + return false; \ + } \ + } while (0) + // Helper to release Ort one or more objects obtained from the public C API at the end of their scope. 
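// Example usage (a minimal sketch following the example that lived in the removed ep_abi_utils.h;
// the object type here is only illustrative):
// {
//   OrtTensorTypeAndShapeInfo* info = nullptr;
//   DeferOrtRelease<OrtTensorTypeAndShapeInfo> defer_release(&info, Ort::GetApi().ReleaseTensorTypeAndShapeInfo);
//   ...
// } /* Release is called at end of scope */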
template struct DeferOrtRelease { diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index f8116a32..fde2f39d 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -31,7 +31,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e // Default GPU allocator OrtMemoryInfo OrtMemoryInfo* mem_info = nullptr; - auto* status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU", OrtMemoryInfoDeviceType_GPU, + auto* status = ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); assert(status == nullptr); // should never fail. @@ -40,7 +40,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e // CUDA PINNED allocator OrtMemoryInfo // HOST_ACCESSIBLE memory should use the non-CPU device type mem_info = nullptr; - status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU pinned", OrtMemoryInfoDeviceType_GPU, + status = ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); assert(status == nullptr); // should never fail. @@ -56,12 +56,12 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e data_transfer_impl_.reset(); // but we're CPU only so we return nullptr for the IDataTransfer. } -const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) { +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->ep_name_.c_str(); } -const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVendorImpl(const OrtEpFactory* this_ptr) { +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor_.c_str(); } @@ -72,7 +72,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* p_num_ep_devices) { + size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); @@ -133,8 +133,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, _In_ size_t num_devices, _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ OrtEp** ep) { + _In_ const OrtLogger* logger, _Out_ OrtEp** ep) noexcept { auto* factory = static_cast(this_ptr); *ep = nullptr; @@ -161,7 +160,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( return nullptr; } -void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { +void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { TensorrtExecutionProvider* trt_ep = static_cast(ep); delete trt_ep; } diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 
e4222e92..5d2d476c 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -1,3 +1,5 @@ +#pragma once + #include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider_data_transfer.h" diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h index 37665542..16b346c7 100644 --- a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h +++ b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h @@ -81,7 +81,7 @@ #define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ #include -#include "core/session/onnxruntime_cxx_api.h" +#include "onnxruntime_cxx_api.h" #include "onnx/onnx_pb.h" namespace OrtEpUtils { From 30e0f91191402afb2a8eef9724780b69ddc1985a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 13 Jul 2025 23:22:10 -0700 Subject: [PATCH 24/60] update memory info and data transfer in TRT EP's factor to accommodate mutiple GPU devices --- .../tensorrt/cuda_allocator.h | 17 +- .../tensorrt/tensorrt_execution_provider.cc | 6 +- .../tensorrt/tensorrt_execution_provider.h | 2 + ...nsorrt_execution_provider_data_transfer.cc | 19 ++- ...ensorrt_execution_provider_data_transfer.h | 10 +- .../tensorrt/tensorrt_provider_factory.cc | 151 ++++++++++++------ .../tensorrt/tensorrt_provider_factory.h | 34 +++- 7 files changed, 158 insertions(+), 81 deletions(-) diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h index 4b6a8565..eb5ac144 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -13,22 +13,13 @@ constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned"; using DeviceId = int16_t; struct CUDAAllocator : OrtAllocator { - CUDAAllocator(const OrtMemoryInfo* mem_info, const char* name = CUDA_ALLOCATOR) { + CUDAAllocator(const OrtMemoryInfo* mem_info, DeviceId device_id) : mem_info_(mem_info), device_id_(device_id) { OrtAllocator::version = ORT_API_VERSION; - OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { + return static_cast(this_)->Alloc(size); + }; OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; - - mem_info_ = mem_info; - - device_id_ = device_id; - - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - api->CreateMemoryInfo(name, - OrtAllocatorType::OrtDeviceAllocator, - static_cast(device_id), - OrtMemType::OrtMemTypeDefault, - &mem_info_); } // TODO: Handle destructor //~CUDAAllocator(); diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index d2593a23..4b126004 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -654,7 +654,7 @@ OrtStatusPtr BindContextOutput(Ort::KernelContext& ctx, } OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, - OrtMemoryInfo* /*mem_info*/, + const OrtMemoryInfo* /*mem_info*/, DDSOutputAllocatorMap& allocator_map, char const* output_name, size_t output_index, @@ -1416,6 +1416,7 @@ OrtStatus* 
TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this tactics = GetTacticSourceFromString(tactic_sources_); } *compute_state = { + static_cast(device_id_), fused_node_name, builder_.get(), &parsers_[fused_node_name], @@ -2281,6 +2282,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input + uint16_t device_id = trt_state->device_id; auto max_workspace_size = trt_state->max_workspace_size; auto trt_builder = trt_state->builder; auto trt_engine = trt_state->engine->get(); @@ -2317,7 +2319,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // Get default OrtMemoryInfo from factory // Get allocator from OrtKernelContext - OrtMemoryInfo* mem_info = ep.factory_.GetDefaultMemInfo(); + const OrtMemoryInfo* mem_info = ep.factory_.GetDefaultGpuMemInfoForDeviceId(device_id); OrtAllocator* alloc = nullptr; ep.GetAllocator(&alloc); if (alloc == nullptr) { diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 6d30f59a..f2ae4d45 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -153,6 +153,7 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { }; struct TensorrtComputeState { + uint32_t device_id; std::string fused_node_name; nvinfer1::IBuilder* builder; tensorrt_ptr::unique_pointer* parser = nullptr; @@ -207,6 +208,7 @@ struct TensorrtComputeState { // Minimum information to construct kernel function state for direct engine load code path struct TensorrtComputeStateForEPContext { + uint32_t device_id; std::string fused_node_name; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc index 82f9941e..67b868d2 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -9,14 +9,21 @@ void CUDA_RETURN_IF_ERROR(cudaError_t res); /*static*/ -bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, - const OrtMemoryDevice* src_memory_device, - const OrtMemoryDevice* dst_memory_device) noexcept { +bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { auto& impl = *static_cast(this_ptr); - bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info); - bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info); - return src_is_our_device || dst_is_our_device; + auto it = std::find_if(impl.cuda_gpu_mem_devices_.begin(), impl.cuda_gpu_mem_devices_.end(), + [&impl, &src_memory_device, &dst_memory_device](const OrtMemoryDevice* memory_device) { + bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, memory_device); + bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, memory_device); + return src_is_our_device || dst_is_our_device; + }); + + if (it != impl.cuda_gpu_mem_devices_.end()) { + return true; + } + return false; } // function to copy one or more tensors. 
diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h index a72ff453..2e5ac808 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h @@ -6,9 +6,9 @@ #include "tensorrt_execution_provider_utils.h" struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { - TRTEpDataTransfer(ApiPtrs api_ptrs, const OrtMemoryDevice* device_mem_info_, - const OrtMemoryDevice* shared_mem_info_ = nullptr) - : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_}, shared_mem_info{shared_mem_info_} { + TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector device_mem_infos, + std::vector shared_mem_infos) + : ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} { CanCopy = CanCopyImpl; CopyTensors = CopyTensorsImpl; Release = ReleaseImpl; @@ -25,6 +25,6 @@ struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept; private: - const OrtMemoryDevice* device_mem_info; - const OrtMemoryDevice* shared_mem_info; + std::vector cuda_gpu_mem_devices_; + std::vector cuda_pinned_mem_devices_; }; \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index fde2f39d..56517d80 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -28,32 +28,6 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e ReleaseAllocator = ReleaseAllocatorImpl; CreateDataTransfer = CreateDataTransferImpl; - - // Default GPU allocator OrtMemoryInfo - OrtMemoryInfo* mem_info = nullptr; - auto* status = ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, - /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); - assert(status == nullptr); // should never fail. - default_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - - // CUDA PINNED allocator OrtMemoryInfo - // HOST_ACCESSIBLE memory should use the non-CPU device type - mem_info = nullptr; - status = ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, - /*vendor*/ 0x10DE, /* device_id */ 0, OrtDeviceMemoryType_HOST_ACCESSIBLE, - /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); - assert(status == nullptr); // should never fail. - host_accessible_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - - // Create gpu data transfer - data_transfer_impl_ = std::make_unique( - apis, - ep_api.MemoryInfo_GetMemoryDevice(default_gpu_memory_info_.get()), // device memory - ep_api.MemoryInfo_GetMemoryDevice(host_accessible_gpu_memory_info_.get()) // shared memory - ); - - data_transfer_impl_.reset(); // but we're CPU only so we return nullptr for the IDataTransfer. 
} const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { @@ -76,6 +50,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); + std::vector cuda_gpu_mem_devices; + std::vector cuda_pinned_mem_devices; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { // C API const OrtHardwareDevice& device = *devices[i]; @@ -88,7 +65,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp // The ep options can be provided here as default values. // Users can also call SessionOptionsAppendExecutionProvider_V2 C API with provided ep options to override. - factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); // random example using made up values + factory->ort_api.AddKeyValuePair(ep_metadata, "gpu_type", "data center"); // random example using made up values factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3"); // OrtEpDevice copies ep_metadata and ep_options. @@ -103,25 +80,60 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp return status; } - // register the allocator info required by the EP. - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_gpu_memory_info_.get())); - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->host_accessible_gpu_memory_info_.get())); + uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(&device); + uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device); + + // CUDA allocator OrtMemoryInfo + OrtMemoryInfo* mem_info = nullptr; + status = factory->ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); + + assert(status == nullptr); // should never fail. + MemoryInfoUniquePtr cuda_gpu_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo); + + // CUDA PINNED allocator OrtMemoryInfo + // HOST_ACCESSIBLE memory should use the non-CPU device type. + mem_info = nullptr; + status = factory->ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); + + assert(status == nullptr); // should never fail. + MemoryInfoUniquePtr cuda_pinned_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo); + + // Register the allocator info required by TRT EP. + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_memory_info.get())); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_memory_info.get())); + + // Get memory device from memory info for gpu data transfer + cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_memory_info.get())); + cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_memory_info.get())); + + factory->SetDefaultGpuMemInfo(std::move(cuda_gpu_memory_info), device_id); + factory->SetHostAccessibleMemInfo(std::move(cuda_pinned_memory_info), device_id); ep_devices[num_ep_devices++] = ep_device; } - // C++ API equivalent. Throws on error. 
- //{ - // Ort::ConstHardwareDevice device(devices[i]); - // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // Ort::KeyValuePairs ep_metadata; - // Ort::KeyValuePairs ep_options; - // ep_metadata.Add("version", "0.1"); - // ep_options.Add("trt_builder_optimization_level", "3"); - // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; - // ep_devices[num_ep_devices++] = ep_device.release(); - // } - //} + // Create gpu data transfer + auto data_transfer_impl = std::make_unique( + static_cast(*factory), + cuda_gpu_mem_devices, // device memory + cuda_pinned_mem_devices // shared memory + ); + + + // C++ API equivalent. Throws on error. + //{ + // Ort::ConstHardwareDevice device(devices[i]); + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // Ort::KeyValuePairs ep_metadata; + // Ort::KeyValuePairs ep_options; + // ep_metadata.Add("version", "0.1"); + // ep_options.Add("trt_builder_optimization_level", "3"); + // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; + // ep_devices[num_ep_devices++] = ep_device.release(); + // } + //} } return nullptr; @@ -181,11 +193,14 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl( // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based // matching should work. - if (memory_info == factory.default_gpu_memory_info_.get()) { + + uint32_t device_id = 0; + + if (factory.GetDeviceIdForDefaultGpuMemInfo(memory_info, &device_id)) { // create a CUDA allocator - auto cuda_allocator = std::make_unique(memory_info); + auto cuda_allocator = std::make_unique(memory_info, static_cast(device_id)); *allocator = cuda_allocator.release(); - } else if (memory_info == factory.host_accessible_gpu_memory_info_.get()) { + } else if (factory.GetDeviceIdForHostAccessibleMemInfo(memory_info, &device_id)) { // create a CUDA PINNED allocator auto cuda_pinned_allocator = std::make_unique(memory_info); *allocator = cuda_pinned_allocator.release(); @@ -212,8 +227,50 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl return nullptr; } -OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultMemInfo() const { - return default_gpu_memory_info_.get(); +bool TensorrtExecutionProviderFactory::GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const { + auto iter = cuda_gpu_memory_info_to_device_id_map_.find(mem_info); + if (iter != cuda_gpu_memory_info_to_device_id_map_.end()) { + *device_id = iter->second; + return true; + } + return false; +} + +const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const { + auto iter = device_id_to_cuda_gpu_memory_info_map_.find(device_id); + if (iter != device_id_to_cuda_gpu_memory_info_map_.end()) { + return iter->second; + } + return nullptr; +} + +void TensorrtExecutionProviderFactory::SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) { + cuda_gpu_memory_info_to_device_id_map_[mem_info.get()] = device_id; + device_id_to_cuda_gpu_memory_info_map_[device_id] = mem_info.get(); + cuda_gpu_memory_infos_.push_back(std::move(mem_info)); +} + +bool TensorrtExecutionProviderFactory::GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const { + auto iter = cuda_pinned_memory_info_to_device_id_map_.find(mem_info); + if (iter != 
cuda_pinned_memory_info_to_device_id_map_.end()) { + *device_id = iter->second; + return true; + } + return false; +} + +const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const { + auto iter = device_id_to_cuda_pinned_memory_info_map_.find(device_id); + if (iter != device_id_to_cuda_pinned_memory_info_map_.end()) { + return iter->second; + } + return nullptr; +} + +void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) { + cuda_pinned_memory_info_to_device_id_map_[mem_info.get()] = device_id; + device_id_to_cuda_pinned_memory_info_map_[device_id] = mem_info.get(); + cuda_pinned_memory_infos_.push_back(std::move(mem_info)); } // To make symbols visible on macOS/iOS diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 5d2d476c..a8c52882 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -3,13 +3,18 @@ #include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider_data_transfer.h" +using MemoryInfoUniquePtr = std::unique_ptr>; + /// /// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices. /// struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { public: TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis); - OrtMemoryInfo* GetDefaultMemInfo() const; + + const OrtMemoryInfo* GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const; + + const OrtMemoryInfo* GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const; private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -37,17 +42,30 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept; + bool GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const; + + void SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id); + + bool GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const; + + void SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id); + const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name - // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. - using MemoryInfoUniquePtr = std::unique_ptr>; - //MemoryInfoUniquePtr cpu_memory_info_; + // OrtMemoryInfo for allocators and data transfer. + + // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo instance required for that. + // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device. 
+ std::unordered_map cuda_gpu_memory_info_to_device_id_map_; // OrtMemoryInfo -> device id + std::unordered_map cuda_pinned_memory_info_to_device_id_map_; + std::unordered_map device_id_to_cuda_gpu_memory_info_map_; // device id -> OrtMemoryInfo + std::unordered_map device_id_to_cuda_pinned_memory_info_map_; + std::vector cuda_gpu_memory_infos_; + std::vector cuda_pinned_memory_infos_; - // GPU memory and pinned/shared memory are required for data transfer, these are the - // OrtMemoryInfo instance required for that. - MemoryInfoUniquePtr default_gpu_memory_info_; - MemoryInfoUniquePtr host_accessible_gpu_memory_info_; + // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. + // MemoryInfoUniquePtr cpu_memory_info_; std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; \ No newline at end of file From f443a332adf4fe6111f17edb54c1b6a4c58d1c65 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 13 Jul 2025 23:28:17 -0700 Subject: [PATCH 25/60] update cuda/pinned allocator to make compiler happy --- .../tensorrt/cuda_allocator.h | 13 ++----------- .../tensorrt/tensorrt_provider_factory.cc | 2 +- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h index eb5ac144..1a1dfaa6 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -7,9 +7,6 @@ #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" -constexpr const char* CUDA_ALLOCATOR = "Cuda"; -constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned"; - using DeviceId = int16_t; struct CUDAAllocator : OrtAllocator { @@ -41,17 +38,11 @@ struct CUDAAllocator : OrtAllocator { }; struct CUDAPinnedAllocator : OrtAllocator { - CUDAPinnedAllocator(const char* name = CUDA_PINNED_ALLOCATOR) { + CUDAPinnedAllocator(const OrtMemoryInfo* mem_info) : mem_info_(mem_info) { OrtAllocator::version = ORT_API_VERSION; OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - api->CreateMemoryInfo(name, - OrtAllocatorType::OrtDeviceAllocator, - 0 /* CPU device always with id 0 */, - OrtMemType::OrtMemTypeDefault, - &mem_info_); } // TODO: Handle destructor //~CUDAPinnedAllocator(); @@ -67,5 +58,5 @@ struct CUDAPinnedAllocator : OrtAllocator { CUDAPinnedAllocator& operator=(const CUDAPinnedAllocator&) = delete; DeviceId device_id_ = 0; - OrtMemoryInfo* mem_info_ = nullptr; + const OrtMemoryInfo* mem_info_ = nullptr; }; diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 56517d80..40725450 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -198,7 +198,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl( if (factory.GetDeviceIdForDefaultGpuMemInfo(memory_info, &device_id)) { // create a CUDA allocator - auto cuda_allocator = std::make_unique(memory_info, static_cast(device_id)); + auto cuda_allocator = std::make_unique(memory_info, static_cast(device_id)); *allocator = 
cuda_allocator.release(); } else if (factory.GetDeviceIdForHostAccessibleMemInfo(memory_info, &device_id)) { // create a CUDA PINNED allocator From 95dd71eba23f9dff8cb6a7519d34ab7460272209 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 14:13:20 -0700 Subject: [PATCH 26/60] add GetVersionImpl in factory --- .../tensorrt/tensorrt_provider_factory.cc | 6 ++++++ .../tensorrt/tensorrt_provider_factory.h | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 40725450..a2b43c4b 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -18,6 +18,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -40,6 +41,11 @@ const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVendorImpl(const O return factory->vendor_.c_str(); } +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_version_.c_str(); +} + OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl( OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index a8c52882..17d60eb5 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -1,6 +1,6 @@ #pragma once -#include "tensorrt_execution_provider_utils.h" +#include "ep_utils.h" #include "tensorrt_execution_provider_data_transfer.h" using MemoryInfoUniquePtr = std::unique_ptr>; @@ -21,6 +21,8 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, @@ -52,6 +54,7 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name + const std::string ep_version_{"0.1.0"}; // EP version // OrtMemoryInfo for allocators and data transfer. 
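[Editor's note: a minimal sketch of the per-device OrtMemoryInfo bookkeeping the factory gains in the patches above (SetDefaultGpuMemInfo / GetDeviceIdForDefaultGpuMemInfo / GetDefaultGpuMemInfoForDeviceId). A hypothetical MemInfo type stands in for OrtMemoryInfo and plain std::unique_ptr replaces MemoryInfoUniquePtr's custom deleter; only the two-way map plus owning vector pattern is illustrated.]

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

struct MemInfo {};  // stand-in for OrtMemoryInfo

class MemInfoRegistry {
 public:
  // Mirrors SetDefaultGpuMemInfo: keep ownership in a vector and index the raw
  // pointer in both directions (pointer -> device id, device id -> pointer).
  void Set(std::unique_ptr<MemInfo> info, uint32_t device_id) {
    info_to_id_[info.get()] = device_id;
    id_to_info_[device_id] = info.get();
    owned_.push_back(std::move(info));
  }

  // Mirrors GetDeviceIdForDefaultGpuMemInfo: pointer-based lookup works because
  // the memory info handed back by ORT comes straight from the OrtEpDevice.
  bool GetDeviceId(const MemInfo* info, uint32_t* device_id) const {
    auto it = info_to_id_.find(info);
    if (it == info_to_id_.end()) return false;
    *device_id = it->second;
    return true;
  }

  // Mirrors GetDefaultGpuMemInfoForDeviceId.
  const MemInfo* GetInfo(uint32_t device_id) const {
    auto it = id_to_info_.find(device_id);
    return it == id_to_info_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<const MemInfo*, uint32_t> info_to_id_;
  std::unordered_map<uint32_t, const MemInfo*> id_to_info_;
  std::vector<std::unique_ptr<MemInfo>> owned_;
};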
From 35b0cf1efbcfbde2a791849c2d6f314844aa322c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 14:17:47 -0700 Subject: [PATCH 27/60] update data transfer initialization in TRT EP --- .../tensorrt/tensorrt_provider_factory.cc | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index a2b43c4b..d0077c24 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -119,14 +119,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp ep_devices[num_ep_devices++] = ep_device; } - - // Create gpu data transfer - auto data_transfer_impl = std::make_unique( - static_cast(*factory), - cuda_gpu_mem_devices, // device memory - cuda_pinned_mem_devices // shared memory - ); - // C++ API equivalent. Throws on error. //{ @@ -142,6 +134,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp //} } + // Create gpu data transfer + auto data_transfer_impl = std::make_unique(static_cast(*factory), + cuda_gpu_mem_devices, // device memory + cuda_pinned_mem_devices // shared memory + ); + return nullptr; } From a65908fe6effe61e1fd802013c15488f3d6ac8f5 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 16:06:59 -0700 Subject: [PATCH 28/60] Fix compile errors/issues --- .../tensorrt/onnx_ctx_model_helper.cc | 19 +++- .../tensorrt/onnx_ctx_model_helper.h | 2 +- .../tensorrt/tensorrt_execution_provider.cc | 18 +++- .../tensorrt/tensorrt_execution_provider.def | 5 + .../tensorrt/tensorrt_execution_provider.lds | 7 ++ ...ensorrt_execution_provider_data_transfer.h | 2 +- .../tensorrt_execution_provider_info.cc | 1 + .../tensorrt_execution_provider_info.h | 1 - .../tensorrt_execution_provider_utils.h | 97 +------------------ .../tensorrt/tensorrt_provider_factory.cc | 8 +- .../tensorrt/utils/provider_options_utils.h | 2 +- 11 files changed, 55 insertions(+), 107 deletions(-) create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider.def create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index 0a479ebc..b8e3838e 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -5,7 +5,7 @@ #include #include -#include "tensorrt_execution_provider_utils.h" +#include "ep_utils.h" #include "onnx_ctx_model_helper.h" extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); @@ -109,3 +109,20 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca return nullptr; } + +/* + * Get the weight-refitted engine cache path from a weight-stripped engine cache path + * + * Weight-stripped engine: + * An engine with weights stripped and its size is smaller than a regular engine. + * The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine + * + * Weight-refitted engine: + * An engine whose weights have been refitted; it is simply a regular engine. 
+ * The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine + */ +std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { + std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); + std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; + return refitted_engine_cache_path; +} \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h index 1b1d2891..4d77fede 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -3,8 +3,8 @@ #pragma once -#include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider.h" +#include "ep_utils.h" #include "nv_includes.h" #include diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 4b126004..938b9e13 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -12,12 +12,13 @@ #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL #include "ort_graph_to_proto.h" -//#include "tensorrt_execution_provider_utils.h" +#include "tensorrt_execution_provider_utils.h" #include "tensorrt_execution_provider.h" #include "cuda_allocator.h" #include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" #include "cuda/unary_elementwise_ops_impl.h" +#include "ep_utils.h" #ifdef _WIN32 #include @@ -31,6 +32,10 @@ #define LIBFUNC(lib, fn) dlsym((lib), (fn)) #endif +const OrtApi* g_ort_api = nullptr; +const OrtEpApi* g_ep_api = nullptr; +const OrtModelEditorApi* g_model_editor_api = nullptr; + void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } @@ -1795,9 +1800,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ OrtStatus* status; if (EPContextNodeHelper::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { - RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, - input_map, output_map, - &node_compute_infos_result[fused_node_idx])); + //RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, + // input_map, output_map, + // &node_compute_infos_result[fused_node_idx])); } else { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map, output_map, &node_compute_infos_result[fused_node_idx], @@ -1899,6 +1904,8 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( #endif } +TensorrtExecutionProvider::~TensorrtExecutionProvider() = default; + /// /// /// Plugin TensorRT EP that implements OrtEp @@ -1908,7 +1915,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa const std::string& name, const OrtSessionOptions& session_options, const OrtLogger& logger) - : ApiPtrs{static_cast(factory)}, + : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized + ApiPtrs{static_cast(factory)}, factory_(factory), name_{name}, session_options_{session_options}, diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def new file mode 100644 index 00000000..ae83cb71 --- /dev/null +++ 
b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def @@ -0,0 +1,5 @@ +LIBRARY "TensorRTEp.dll" +EXPORTS + CreateEpFactories @1 + ReleaseEpFactory @2 + \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds new file mode 100644 index 00000000..a6d2ef09 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds @@ -0,0 +1,7 @@ +VERS_1.0.0 { + global: + CreateEpFactories; + ReleaseEpFactory; + local: + *; +}; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h index 2e5ac808..3dead944 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h @@ -3,7 +3,7 @@ #pragma once -#include "tensorrt_execution_provider_utils.h" +#include "ep_utils.h" struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector device_mem_infos, diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc index c9154bbf..c27d8095 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc @@ -6,6 +6,7 @@ #include "tensorrt_execution_provider_info.h" #include "provider_options_utils.h" #include "cuda/cuda_common.h" +#include "ep_utils.h" namespace tensorrt { namespace provider_option_names { diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h index 3dec464d..f2614721 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h @@ -3,7 +3,6 @@ #pragma once -#include "tensorrt_execution_provider_utils.h" #include "provider_options.h" #include diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index 49c11286..ac27904b 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -4,6 +4,7 @@ #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT +#include "ep_utils.h" #include "flatbuffers/idl.h" #include "ort_trt_int8_cal_table.fbs.h" #include "make_string.h" @@ -22,104 +23,8 @@ #include #include -struct ApiPtrs { - const OrtApi& ort_api; - const OrtEpApi& ep_api; - const OrtModelEditorApi& model_editor_api; -}; - -const OrtApi* g_ort_api = nullptr; -const OrtEpApi* g_ep_api = nullptr; -const OrtModelEditorApi* g_model_editor_api = nullptr; - -#define ENFORCE(condition, ...) \ - do { \ - if (!(condition)) { \ - throw std::runtime_error(MakeString(__VA_ARGS__)); \ - } \ - } while (false) - -#define THROW(...) \ - throw std::runtime_error(MakeString(__VA_ARGS__)); - -#define RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return _status; \ - } \ - } while (0) - -/* -template -std::string ComposeString(Args&&... args) { - std::ostringstream oss; - (oss << ... 
<< args); - return oss.str(); -}; -*/ - -#define RETURN_IF(cond, ...) \ - do { \ - if ((cond)) { \ - return Ort::GetApi().CreateStatus(ORT_EP_FAIL, MakeString(__VA_ARGS__).c_str()); \ - } \ - } while (0) - -#define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__) - -#define MAKE_STATUS(error_code, msg) \ - Ort::GetApi().CreateStatus(error_code, (msg)); - -#define THROW_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if (_status != nullptr) { \ - std::ostringstream oss; \ - oss << Ort::GetApi().GetErrorMessage(_status); \ - Ort::GetApi().ReleaseStatus(_status); \ - throw std::runtime_error(oss.str()); \ - } \ - } while (0) - -#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn) \ - do { \ - OrtStatus* status = (fn); \ - if (status != nullptr) { \ - std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \ - return false; \ - } \ - } while (0) - -// Helper to release Ort one or more objects obtained from the public C API at the end of their scope. -template -struct DeferOrtRelease { - DeferOrtRelease(T** object_ptr, std::function release_func) - : objects_(object_ptr), count_(1), release_func_(release_func) {} - - DeferOrtRelease(T** objects, size_t count, std::function release_func) - : objects_(objects), count_(count), release_func_(release_func) {} - - ~DeferOrtRelease() { - if (objects_ != nullptr && count_ > 0) { - for (size_t i = 0; i < count_; ++i) { - if (objects_[i] != nullptr) { - release_func_(objects_[i]); - objects_[i] = nullptr; - } - } - } - } - T** objects_ = nullptr; - size_t count_ = 0; - std::function release_func_ = nullptr; -}; - namespace fs = std::filesystem; -template -using AllocatorUniquePtr = std::unique_ptr>; - bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { size_t alloc_size = size; if (alignment == 0) { diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index d0077c24..2b8f9533 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -58,11 +58,16 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp std::vector cuda_gpu_mem_devices; std::vector cuda_pinned_mem_devices; + int GPU_cnt = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { // C API const OrtHardwareDevice& device = *devices[i]; if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + if (GPU_cnt > 0) { + continue; + } + GPU_cnt++; // These can be returned as nullptr if you have nothing to add. 
OrtKeyValuePairs* ep_metadata = nullptr; OrtKeyValuePairs* ep_options = nullptr; @@ -87,7 +92,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp } uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(&device); - uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device); + //uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device); + uint32_t device_id = 0; // CUDA allocator OrtMemoryInfo OrtMemoryInfo* mem_info = nullptr; diff --git a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h index f190f20c..9a02d272 100644 --- a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h +++ b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h @@ -10,7 +10,7 @@ #include #include "onnxruntime_c_api.h" -#include "../tensorrt_execution_provider_utils.h" +#include "ep_utils.h" #include "parse_string.h" #include "provider_options.h" From c77391fb6190e5e4d2021450bfa71f6bd23fb399 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 16:24:16 -0700 Subject: [PATCH 29/60] fix to use correct API --- .../tensorrt/tensorrt_execution_provider.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 938b9e13..a62e67a2 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -1717,7 +1717,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this int number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { - std::vector supported_nodes(group.first.size()); + std::vector supported_nodes; + supported_nodes.reserve(group.first.size()); + for (const auto& index : group.first) { const OrtNode* supported_node = nodes[index]; @@ -1782,7 +1784,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ // Gets number of node's outputs size_t num_node_outputs = 0; - RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_node_outputs)); + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_node_outputs)); std::vector node_outputs(num_node_outputs); RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, node_outputs.data(), node_outputs.size())); From c5363e6161e1ebfd9f08110b7648eeadfd6d270d Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 17:55:45 -0700 Subject: [PATCH 30/60] fix bug for gpu data transfer implementation --- .../tensorrt/tensorrt_execution_provider_data_transfer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc index 67b868d2..b8c74511 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -61,7 +61,7 @@ OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(void* this_ptr, RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensors[i], &dst_data)); size_t bytes = 0; - RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(reinterpret_cast(src_data), &bytes)); + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensors[i], &bytes)); // for 
the sync version of memcpy, launch to cuda default stream if (dst_device_type == OrtMemoryInfoDeviceType_GPU) { From 09138eecfeb032a0e50b0d9800f95ac8db57cd75 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 17:56:26 -0700 Subject: [PATCH 31/60] clean up --- plugin_execution_providers/tensorrt/cuda_allocator.h | 2 -- .../tensorrt/onnx_ctx_model_helper.cc | 1 + .../tensorrt/tensorrt_execution_provider.cc | 2 -- .../tensorrt/tensorrt_execution_provider.h | 4 ---- .../tensorrt/tensorrt_execution_provider_utils.h | 2 -- .../tensorrt/tensorrt_provider_factory.cc | 8 +++++--- .../tensorrt/tensorrt_provider_factory.h | 2 ++ 7 files changed, 8 insertions(+), 13 deletions(-) diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h index 1a1dfaa6..44557cad 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -4,8 +4,6 @@ #pragma once #include #include "onnxruntime_c_api.h" -#define ORT_API_MANUAL_INIT -#include "onnxruntime_cxx_api.h" using DeviceId = int16_t; diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index b8e3838e..53dcab78 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -21,6 +21,7 @@ bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); for (size_t i = 0; i < num_nodes; ++i) { auto node = nodes[i]; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index a62e67a2..49c8e41e 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -5,9 +5,7 @@ #include #include -#define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL #include "ort_graph_to_proto.h" diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index f2ae4d45..fac9f024 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -1,9 +1,5 @@ #pragma once -#define ORT_API_MANUAL_INIT -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - #include "tensorrt_provider_factory.h" #include "utils/provider_options.h" #include "tensorrt_execution_provider_info.h" diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index ac27904b..4685f386 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -1,8 +1,6 @@ #pragma once -#define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT #include "ep_utils.h" #include "flatbuffers/idl.h" diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 2b8f9533..f70e166f 100644 --- 
a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -1,6 +1,4 @@ -#define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT #include "tensorrt_provider_factory.h" #include "tensorrt_execution_provider.h" #include "cuda_allocator.h" @@ -145,7 +143,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp cuda_gpu_mem_devices, // device memory cuda_pinned_mem_devices // shared memory ); - + factory->SetGPUDataTransfer(std::move(data_transfer_impl)); return nullptr; } @@ -283,6 +281,10 @@ void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo(MemoryInfoUnique cuda_pinned_memory_infos_.push_back(std::move(mem_info)); } +void TensorrtExecutionProviderFactory::SetGPUDataTransfer(std::unique_ptr gpu_data_transfer) { + data_transfer_impl_ = std::move(gpu_data_transfer); +} + // To make symbols visible on macOS/iOS #ifdef __APPLE__ #define EXPORT_SYMBOL __attribute__((visibility("default"))) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 17d60eb5..b73550f9 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -52,6 +52,8 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { void SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id); + void SetGPUDataTransfer(std::unique_ptr gpu_data_transfer); + const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name const std::string ep_version_{"0.1.0"}; // EP version From a8dde45ddea1642b7b56a847be80fdc57dfcb04f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 18:17:17 -0700 Subject: [PATCH 32/60] remove unnecessary files --- .../tensorrt/utils/helper.cc | 59 ------------ .../tensorrt/utils/status.cc | 91 ------------------- 2 files changed, 150 deletions(-) delete mode 100644 plugin_execution_providers/tensorrt/utils/helper.cc delete mode 100644 plugin_execution_providers/tensorrt/utils/status.cc diff --git a/plugin_execution_providers/tensorrt/utils/helper.cc b/plugin_execution_providers/tensorrt/utils/helper.cc deleted file mode 100644 index 7a889c30..00000000 --- a/plugin_execution_providers/tensorrt/utils/helper.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include "common.h" - -#ifdef _WIN32 -#include -#include -#endif - -namespace onnxruntime { -#ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { - if (s.size() >= static_cast(std::numeric_limits::max())) - ORT_THROW("length overflow"); - - const int src_len = static_cast(s.size() + 1); - const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr); - assert(len > 0); - std::string ret(static_cast(len) - 1, '\0'); -#pragma warning(disable : 4189) - const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr); - assert(len == r); -#pragma warning(default : 4189) - return ret; -} - -std::wstring ToWideString(const std::string& s) { - if (s.size() >= static_cast(std::numeric_limits::max())) - ORT_THROW("length overflow"); - - const int src_len = static_cast(s.size() + 1); - const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0); - assert(len > 0); - std::wstring ret(static_cast(len) - 1, '\0'); -#pragma warning(disable : 4189) - const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len); - assert(len == r); -#pragma warning(default : 4189) - return ret; -} -#endif // #ifdef _WIN32 - -#ifdef ORT_NO_EXCEPTIONS -void PrintFinalMessage(const char* msg) { -#if defined(__ANDROID__) - __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg); -#else - // TODO, consider changing the output of the error message from std::cerr to logging when the - // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output - // might not be easily accessible on some systems such as mobile - // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS - std::cerr << msg << std::endl; -#endif -} -#endif // #ifdef ORT_NO_EXCEPTIONS - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/status.cc b/plugin_execution_providers/tensorrt/utils/status.cc deleted file mode 100644 index b3a89c8c..00000000 --- a/plugin_execution_providers/tensorrt/utils/status.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Modifications Copyright (c) Microsoft. 
- -#include "status.h" -#include "common.h" - -namespace onnxruntime { -namespace common { -Status::Status(StatusCategory category, int code, const std::string& msg) { - // state_ will be allocated here causing the status to be treated as a failure - ORT_ENFORCE(code != static_cast(common::OK)); - - state_ = std::make_unique(category, code, msg); -} - -Status::Status(StatusCategory category, int code, const char* msg) { - // state_ will be allocated here causing the status to be treated as a failure - ORT_ENFORCE(code != static_cast(common::OK)); - - state_ = std::make_unique(category, code, msg); -} - -Status::Status(StatusCategory category, int code) - : Status(category, code, "") { -} - -StatusCategory Status::Category() const noexcept { - return IsOK() ? common::NONE : state_->category; -} - -int Status::Code() const noexcept { - return IsOK() ? static_cast(common::OK) : state_->code; -} - -const std::string& Status::ErrorMessage() const noexcept { - return IsOK() ? EmptyString() : state_->msg; -} - -std::string Status::ToString() const { - if (state_ == nullptr) { - return std::string("OK"); - } - - std::string result; - - if (common::SYSTEM == state_->category) { - result += "SystemError"; - result += " : "; - result += std::to_string(errno); - } else if (common::ONNXRUNTIME == state_->category) { - result += "[ONNXRuntimeEPError]"; - result += " : "; - result += std::to_string(Code()); - result += " : "; - result += StatusCodeToString(static_cast(Code())); - result += " : "; - result += state_->msg; - } - - return result; -} - -// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial -// and should not have any destruction order issues via pragmas instead. -// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 26426) -#endif - -const std::string& Status::EmptyString() noexcept { - static std::string s_empty; - return s_empty; -} - -#ifdef _MSC_VER -#pragma warning(pop) -#endif - -} // namespace common -} // namespace onnxruntime From b9117546f0e4de03b93bca024e1ef6e4c0375fb1 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 18:17:53 -0700 Subject: [PATCH 33/60] Temporarily manually creates cudaStream to run --- .../tensorrt/tensorrt_execution_provider.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 49c8e41e..6e9fb790 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -2339,6 +2339,9 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); + //cudaStream_t stream; + cudaStreamCreate(&stream); + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even // if they share the same compute capacity Prepare cache name From 0c817acfa54b23de10173db6762328a46d243171 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 14 Jul 2025 18:18:50 -0700 Subject: [PATCH 34/60] Temporary make plugin TRT links against the protobuf, onnx, 
flatbuffers built from ORT repo --- .../tensorrt/CMakeLists.txt | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index 9456f7cf..402276c0 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -9,6 +9,9 @@ enable_language(CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT "/usr/local/cuda") find_package(CUDAToolkit REQUIRED) +# Use dynamic runtime /MD and /MDd +set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -Xcompiler=\"/MTd\"") + add_definitions(-DONNX_NAMESPACE=onnx) add_definitions(-DONNX_ML) add_definitions(-DNV_TENSORRT_MAJOR=10) @@ -71,15 +74,16 @@ FetchContent_MakeAvailable(flatbuffers) set(DEPS_PATH "${CMAKE_BINARY_DIR}/_deps") if (WIN32) # Windows - set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.lib") - #set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.dll") + set(PLATFORM "Windows") + set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/${CMAKE_BUILD_TYPE}/onnxruntime.lib") + set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps") set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib" "${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib" "${TENSORRT_HOME}/lib/nvonnxparser_10.lib") set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib" - "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" - "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") - + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") + if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(DEPS_LIBS ${DEPS_LIBS} "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib" @@ -89,6 +93,10 @@ if (WIN32) # Windows "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib" "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib") endif() + + set(TRT_EP_LIB_LINK_FLAG + "-DEF:${CMAKE_SOURCE_DIR}/tensorrt_execution_provider.def") + else() # Linux set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so") set(TRT_LIBS "${TENSORRT_HOME}/lib/libnvinfer.so" @@ -114,12 +122,16 @@ MESSAGE(STATUS "ORT lib : ${ORT_LIB}") MESSAGE(STATUS "TRT libs : ${TRT_LIBS}") MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}") -target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include" +set_property(TARGET TensorRTEp APPEND_STRING PROPERTY LINK_FLAGS + ${TRT_EP_LIB_LINK_FLAG}) + +target_include_directories(TensorRTEp PUBLIC #"${ORT_HOME}/include" + "${ORT_HOME}/include/onnxruntime/core/session" "./utils" "/usr/local/cuda/include" "${TENSORRT_HOME}/include" "${DEPS_PATH}/flatbuffers-src/include" - "${DEPS_PATH}/gsl-src/include" + "${DEPS_PATH}/gsl-src/include" # GSL is header-only "${DEPS_PATH}/onnx-src" "${DEPS_PATH}/onnx-build" "${DEPS_PATH}/protobuf-src/src" @@ -128,8 +140,5 @@ target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include" target_link_libraries(TensorRTEp PUBLIC ${ORT_LIB} ${TRT_LIBS} CUDA::cudart - protobuf - onnx - gsl - flatbuffers + ${DEPS_LIBS} ) From da729f93b218c41ebff4118b0f7cd88343a76808 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 15 Jul 2025 11:46:33 -0700 Subject: [PATCH 35/60] fix the issue of error LNK2038: mismatch detected for 'RuntimeLibrary' in CMake for Windows debug build --- .../tensorrt/CMakeLists.txt | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt 
index 402276c0..1db2a761 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -1,6 +1,6 @@ # usage: # cd build/ -# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.22.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") # cmake --build ./ --config Debug cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) @@ -9,8 +9,15 @@ enable_language(CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT "/usr/local/cuda") find_package(CUDAToolkit REQUIRED) -# Use dynamic runtime /MD and /MDd -set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -Xcompiler=\"/MTd\"") +# CMake config to force dynamic debug CRT globally for all dependencies. +# This is to address the issue of: +# libprotobufd.lib(common.obj) : error LNK2038: mismatch detected for 'RuntimeLibrary': value 'MTd_StaticDebug' doesn't match value 'MDd_DynamicDebug' in unary_elementwise_ops_impl.obj +if (WIN32) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebugDLL" CACHE STRING "" FORCE) # /MDd + set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime + endif() +endif() add_definitions(-DONNX_NAMESPACE=onnx) add_definitions(-DONNX_ML) @@ -42,6 +49,13 @@ FetchContent_Declare( GIT_TAG v21.12 # Use a specific tag or commit ) +if (WIN32) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + # Sometimes, protobuf ignores CMAKE_MSVC_RUNTIME_LIBRARY. 
To ensure it works: + set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "" FORCE) + endif() +endif() + FetchContent_MakeAvailable(protobuf) # Add ONNX @@ -53,6 +67,10 @@ FetchContent_Declare( FetchContent_MakeAvailable(onnx) +#set(ONNX_USE_LITE_PROTO OFF CACHE BOOL "" FORCE) +#set(ONNX_BUILD_TESTS OFF CACHE BOOL "" FORCE) +#set(ONNX_GEN_PB_TYPE_STUBS OFF CACHE BOOL "" FORCE) + # Add GSL FetchContent_Declare( gsl @@ -74,26 +92,25 @@ FetchContent_MakeAvailable(flatbuffers) set(DEPS_PATH "${CMAKE_BINARY_DIR}/_deps") if (WIN32) # Windows - set(PLATFORM "Windows") - set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/${CMAKE_BUILD_TYPE}/onnxruntime.lib") - set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps") + set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.lib") set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib" "${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib" "${TENSORRT_HOME}/lib/nvonnxparser_10.lib") - set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib" - "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" - "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(DEPS_LIBS ${DEPS_LIBS} - "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib" - "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib") + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib") else() set(DEPS_LIBS ${DEPS_LIBS} - "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib" - "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib") + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib") endif() + set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") + set(TRT_EP_LIB_LINK_FLAG "-DEF:${CMAKE_SOURCE_DIR}/tensorrt_execution_provider.def") @@ -125,8 +142,7 @@ MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}") set_property(TARGET TensorRTEp APPEND_STRING PROPERTY LINK_FLAGS ${TRT_EP_LIB_LINK_FLAG}) -target_include_directories(TensorRTEp PUBLIC #"${ORT_HOME}/include" - "${ORT_HOME}/include/onnxruntime/core/session" +target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include" "./utils" "/usr/local/cuda/include" "${TENSORRT_HOME}/include" @@ -137,8 +153,9 @@ target_include_directories(TensorRTEp PUBLIC #"${ORT_HOME}/include" "${DEPS_PATH}/protobuf-src/src" ) -target_link_libraries(TensorRTEp PUBLIC ${ORT_LIB} +target_link_libraries(TensorRTEp PUBLIC #${DEPS_LIBS} + protobuf::libprotobuf onnx flatbuffers + ${ORT_LIB} ${TRT_LIBS} CUDA::cudart - ${DEPS_LIBS} ) From 6fd38c300af60c360055bfe9a3ad108c1e5cf453 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 15 Jul 2025 14:54:33 -0700 Subject: [PATCH 36/60] refactor memory info stored in factory --- .../tensorrt/tensorrt_execution_provider.cc | 6 +- .../tensorrt/tensorrt_provider_factory.cc | 126 +++++++----------- .../tensorrt/tensorrt_provider_factory.h | 31 ++--- 3 files changed, 61 insertions(+), 102 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 6e9fb790..e28b61c1 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -2327,7 
+2327,11 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // Get default OrtMemoryInfo from factory // Get allocator from OrtKernelContext - const OrtMemoryInfo* mem_info = ep.factory_.GetDefaultGpuMemInfoForDeviceId(device_id); + const OrtMemoryInfo* mem_info = nullptr; + if (ep.factory_.device_id_to_cuda_gpu_memory_info_map.find(device_id) != + ep.factory_.device_id_to_cuda_gpu_memory_info_map.end()) { + mem_info = ep.factory_.device_id_to_cuda_gpu_memory_info_map[device_id]; + } OrtAllocator* alloc = nullptr; ep.GetAllocator(&alloc); if (alloc == nullptr) { diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index f70e166f..4cfbb98b 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -44,6 +44,32 @@ const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVersionImpl(const return factory->ep_version_.c_str(); } +OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_devices) { + cuda_gpu_memory_infos.reserve(num_devices); + cuda_pinned_memory_infos.reserve(num_devices); + + for (int device_id = 0; device_id < num_devices; ++device_id) { + OrtMemoryInfo* mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, + /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, + /* device_id */ device_id, OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); + + cuda_gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + + // HOST_ACCESSIBLE memory should use the non-CPU device type + mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, + /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, + /* device_id */ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); + + cuda_pinned_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + } + + return nullptr; +} + OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl( OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, @@ -54,18 +80,24 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); + int num_cuda_devices = 0; + cudaGetDeviceCount(&num_cuda_devices); + RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); + std::vector cuda_gpu_mem_devices; std::vector cuda_pinned_mem_devices; - int GPU_cnt = 0; + int32_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { // C API const OrtHardwareDevice& device = *devices[i]; if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - if (GPU_cnt > 0) { + + // workaround for duplicate devices when using remote desktop. + if (device_id > 0) { continue; } - GPU_cnt++; + // These can be returned as nullptr if you have nothing to add. 
OrtKeyValuePairs* ep_metadata = nullptr; OrtKeyValuePairs* ep_options = nullptr; @@ -89,39 +121,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp return status; } - uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(&device); - //uint32_t device_id = factory->ort_api.HardwareDevice_DeviceId(&device); - uint32_t device_id = 0; - - // CUDA allocator OrtMemoryInfo - OrtMemoryInfo* mem_info = nullptr; - status = factory->ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); - - assert(status == nullptr); // should never fail. - MemoryInfoUniquePtr cuda_gpu_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo); - - // CUDA PINNED allocator OrtMemoryInfo - // HOST_ACCESSIBLE memory should use the non-CPU device type. - mem_info = nullptr; - status = factory->ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, vendor_id, device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE, - /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); - - assert(status == nullptr); // should never fail. - MemoryInfoUniquePtr cuda_pinned_memory_info = MemoryInfoUniquePtr(mem_info, factory->ort_api.ReleaseMemoryInfo); + const OrtMemoryInfo* cuda_gpu_mem_info = factory->cuda_gpu_memory_infos[device_id].get(); + const OrtMemoryInfo* cuda_pinned_mem_info = factory->cuda_pinned_memory_infos[device_id].get(); // Register the allocator info required by TRT EP. - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_memory_info.get())); - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_memory_info.get())); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_mem_info)); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_mem_info)); // Get memory device from memory info for gpu data transfer - cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_memory_info.get())); - cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_memory_info.get())); - - factory->SetDefaultGpuMemInfo(std::move(cuda_gpu_memory_info), device_id); - factory->SetHostAccessibleMemInfo(std::move(cuda_pinned_memory_info), device_id); + cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info)); + cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info)); ep_devices[num_ep_devices++] = ep_device; + ++device_id; } // C++ API equivalent. Throws on error. @@ -202,13 +214,15 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl( // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based // matching should work. 
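
[Editor's aside] For orientation before the CreateAllocatorImpl hunk that follows: a minimal, non-authoritative sketch of the allocator selection this patch switches to, where the device id and memory type are read straight from the OrtMemoryInfo rather than looked up in the removed OrtMemoryInfo-to-device-id maps. The OrtEpApi accessors (MemoryInfo_GetMemoryDevice, MemoryDevice_GetDeviceId, MemoryDevice_GetMemoryType) are the ones used in the hunk; the CUDAAllocator/CUDAPinnedAllocator constructor arguments are assumed, since the template arguments were elided in the quoted diff.

    // Sketch only. `factory` is the TensorrtExecutionProviderFactory, `memory_info` is the
    // OrtMemoryInfo passed to CreateAllocatorImpl, `allocator` is the output parameter.
    const OrtMemoryDevice* mem_device = factory.ep_api.MemoryInfo_GetMemoryDevice(memory_info);
    const uint32_t device_id = factory.ep_api.MemoryDevice_GetDeviceId(mem_device);

    if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_DEFAULT) {
      // Device-local CUDA memory: remember which OrtMemoryInfo backs this device id so that
      // ComputeImpl can look it up later, then hand out the CUDA device allocator.
      factory.device_id_to_cuda_gpu_memory_info_map[device_id] = memory_info;
      *allocator = new CUDAAllocator(memory_info, static_cast<int>(device_id));  // ctor args assumed
    } else if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) {
      // Host-accessible (pinned) memory used for asynchronous host/device copies.
      *allocator = new CUDAPinnedAllocator(memory_info);  // ctor args assumed
    }

The point of the refactor is that the OrtMemoryInfo itself already encodes device id and memory type, so the factory only needs the single device-id-to-OrtMemoryInfo map consumed in ComputeImpl above.
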
- uint32_t device_id = 0; + const OrtMemoryDevice* mem_device = factory.ep_api.MemoryInfo_GetMemoryDevice(memory_info); + uint32_t device_id = factory.ep_api.MemoryDevice_GetDeviceId(mem_device); - if (factory.GetDeviceIdForDefaultGpuMemInfo(memory_info, &device_id)) { + if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_DEFAULT) { // create a CUDA allocator auto cuda_allocator = std::make_unique(memory_info, static_cast(device_id)); + factory.device_id_to_cuda_gpu_memory_info_map[device_id] = memory_info; *allocator = cuda_allocator.release(); - } else if (factory.GetDeviceIdForHostAccessibleMemInfo(memory_info, &device_id)) { + } else if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { // create a CUDA PINNED allocator auto cuda_pinned_allocator = std::make_unique(memory_info); *allocator = cuda_pinned_allocator.release(); @@ -235,52 +249,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl return nullptr; } -bool TensorrtExecutionProviderFactory::GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const { - auto iter = cuda_gpu_memory_info_to_device_id_map_.find(mem_info); - if (iter != cuda_gpu_memory_info_to_device_id_map_.end()) { - *device_id = iter->second; - return true; - } - return false; -} - -const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const { - auto iter = device_id_to_cuda_gpu_memory_info_map_.find(device_id); - if (iter != device_id_to_cuda_gpu_memory_info_map_.end()) { - return iter->second; - } - return nullptr; -} - -void TensorrtExecutionProviderFactory::SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) { - cuda_gpu_memory_info_to_device_id_map_[mem_info.get()] = device_id; - device_id_to_cuda_gpu_memory_info_map_[device_id] = mem_info.get(); - cuda_gpu_memory_infos_.push_back(std::move(mem_info)); -} - -bool TensorrtExecutionProviderFactory::GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const { - auto iter = cuda_pinned_memory_info_to_device_id_map_.find(mem_info); - if (iter != cuda_pinned_memory_info_to_device_id_map_.end()) { - *device_id = iter->second; - return true; - } - return false; -} - -const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const { - auto iter = device_id_to_cuda_pinned_memory_info_map_.find(device_id); - if (iter != device_id_to_cuda_pinned_memory_info_map_.end()) { - return iter->second; - } - return nullptr; -} - -void TensorrtExecutionProviderFactory::SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id) { - cuda_pinned_memory_info_to_device_id_map_[mem_info.get()] = device_id; - device_id_to_cuda_pinned_memory_info_map_[device_id] = mem_info.get(); - cuda_pinned_memory_infos_.push_back(std::move(mem_info)); -} - void TensorrtExecutionProviderFactory::SetGPUDataTransfer(std::unique_ptr gpu_data_transfer) { data_transfer_impl_ = std::move(gpu_data_transfer); } diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index b73550f9..c9a43931 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -16,6 +16,15 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { const OrtMemoryInfo* 
GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const; + OrtStatus* CreateMemoryInfoForDevices(int num_devices); + + // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo + // instance required for that. + // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device. + std::vector cuda_gpu_memory_infos; + std::vector cuda_pinned_memory_infos; + std::unordered_map device_id_to_cuda_gpu_memory_info_map; // device id -> OrtMemoryInfo + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -44,33 +53,11 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept; - bool GetDeviceIdForDefaultGpuMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const; - - void SetDefaultGpuMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id); - - bool GetDeviceIdForHostAccessibleMemInfo(const OrtMemoryInfo* mem_info, uint32_t* device_id) const; - - void SetHostAccessibleMemInfo(MemoryInfoUniquePtr mem_info, uint32_t device_id); - void SetGPUDataTransfer(std::unique_ptr gpu_data_transfer); const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name const std::string ep_version_{"0.1.0"}; // EP version - // OrtMemoryInfo for allocators and data transfer. - - // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo instance required for that. - // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device. - std::unordered_map cuda_gpu_memory_info_to_device_id_map_; // OrtMemoryInfo -> device id - std::unordered_map cuda_pinned_memory_info_to_device_id_map_; - std::unordered_map device_id_to_cuda_gpu_memory_info_map_; // device id -> OrtMemoryInfo - std::unordered_map device_id_to_cuda_pinned_memory_info_map_; - std::vector cuda_gpu_memory_infos_; - std::vector cuda_pinned_memory_infos_; - - // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. 
- // MemoryInfoUniquePtr cpu_memory_info_; - std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; \ No newline at end of file From 7467c65f920384bcd5922557bfb6fd189313a9a9 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 16 Jul 2025 15:29:54 -0700 Subject: [PATCH 37/60] update as onnxruntime_ep_c_api.h changes --- .../tensorrt/tensorrt_execution_provider.cc | 7 ++++--- .../tensorrt/tensorrt_execution_provider.h | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index e28b61c1..8f54c07f 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -1481,7 +1481,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } -OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) { +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -1751,7 +1752,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) { + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -2228,7 +2229,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa } void ORT_API_CALL TensorrtExecutionProvider::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) { + size_t num_node_compute_infos) noexcept { (void)this_ptr; for (size_t i = 0; i < num_node_compute_infos; i++) { delete node_compute_infos[i]; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index fac9f024..35999869 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -313,13 +313,13 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info); + OrtEpGraphSupportInfo* graph_support_info) noexcept; static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos); + size_t num_node_compute_infos) noexcept; OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); From da0f9c64c8c2bd985830b523b34c2fbad60ad3a9 Mon Sep 17 00:00:00 
2001 From: Chi Lo Date: Wed, 23 Jul 2025 10:43:15 -0700 Subject: [PATCH 38/60] Add support for dump and run EP Context model --- .../tensorrt/onnx_ctx_model_helper.cc | 158 +- .../tensorrt/onnx_ctx_model_helper.h | 46 + .../tensorrt/tensorrt_execution_provider.cc | 1265 +++++++++++------ .../tensorrt/tensorrt_execution_provider.h | 16 + 4 files changed, 1055 insertions(+), 430 deletions(-) diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index 53dcab78..87d8d1cd 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -7,6 +7,7 @@ #include "ep_utils.h" #include "onnx_ctx_model_helper.h" +#include "onnx/onnx_pb.h" extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); @@ -28,7 +29,7 @@ bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o const char* op_type = nullptr; RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type)); - if (node != nullptr && op_type == "EPContext") { + if (node != nullptr && std::string(op_type) == "EPContext") { return true; } } @@ -85,7 +86,7 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca std::array attributes = {}; DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); - RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[0])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, sizeof(int64_t), ORT_OP_ATTR_INT, &attributes[0])); std::string engine_data_str = ""; if (embed_mode) { @@ -93,13 +94,13 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca engine_data_str.assign(engine_data, size); } RETURN_IF_ERROR( - ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1])); + ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), engine_data_str.size(), ORT_OP_ATTR_STRING, &attributes[1])); } else { - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), engine_cache_path.size(), ORT_OP_ATTR_STRING, &attributes[1])); } - ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[2]); + ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), compute_capability.size(), ORT_OP_ATTR_STRING, &attributes[2]); ort_api.CreateOpAttr("onnx_model_filename", std::filesystem::path(onnx_model_path).filename().string().c_str(), 1, ORT_OP_ATTR_STRING, &attributes[3]); @@ -111,6 +112,153 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca return nullptr; } +OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { + /* + if (!ValidateEPCtxNode(graph)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node"); + } + */ + + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); + + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); + + auto node = nodes[0]; + + size_t num_node_attributes = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(node, &num_node_attributes)); + + /* + std::vector node_attributes(num_node_attributes); + 
RETURN_IF_ERROR(ort_api.Node_GetAttributes(node, node_attributes.data(), node_attributes.size())); + */ + + const OrtOpAttr* node_attr = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "embed_mode", &node_attr)); + const int64_t embed_mode = reinterpret_cast(node_attr)->i(); + + // Only make path checks if model not provided as byte buffer + //bool make_secure_path_checks = !GetModelPath(graph_viewer).empty(); + bool make_secure_path_checks = false; + + if (embed_mode) { + // Get engine from byte stream. + node_attr = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr)); + const std::string& context_binary = reinterpret_cast(node_attr)->s(); + + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), + static_cast(context_binary.length()))); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; + if (!(*trt_engine_)) { + return ort_api.CreateStatus(ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); + } + + /* + if (weight_stripped_engine_refit_) { + const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); + std::string placeholder; + auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + placeholder, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + (*trt_engine_).get(), + false, // serialize refitted engine to disk + detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + */ + } else { + // Get engine from cache file. + node_attr = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr)); + std::string cache_path = reinterpret_cast(node_attr)->s(); + + /* + // For security purpose, in the case of running context model, TRT EP won't allow + // engine cache path to be the relative path like "../file_path" or the absolute path. + // It only allows the engine cache to be in the same directory or sub directory of the context model. + if (IsAbsolutePath(cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path); + } + if (IsRelativePathToParentPath(cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. 
For security purpose, it's not allowed to point outside the directory."); + } + + // The engine cache and context model (current model) should be in the same directory + std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); + auto engine_cache_path = ctx_model_dir.append(cache_path); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + + // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled + if (!weight_stripped_engine_refit_) { + weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); + } + + // If the serialized refitted engine is present, use it directly without refitting the engine again + if (weight_stripped_engine_refit_) { + const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); + if (std::filesystem::exists(refitted_engine_cache_path)) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; + engine_cache_path = refitted_engine_cache_path.string(); + weight_stripped_engine_refit_ = false; + } + } + */ + + std::filesystem::path engine_cache_path(cache_path); + if (!std::filesystem::exists(engine_cache_path)) { + std::string error_msg = + "TensorRT EP can't find engine cache: " + engine_cache_path.string() + + ". Please make sure engine cache is in the same directory or sub-directory of context model."; + return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*trt_engine_)) { + std::string error_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string(); + return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); + + /* + if (weight_stripped_engine_refit_) { + const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); + std::string weight_stripped_engine_cache = engine_cache_path.string(); + auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + weight_stripped_engine_cache, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + (*trt_engine_).get(), + true, // serialize refitted engine to disk + detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + */ + } + return nullptr; +} + /* * Get the weight-refitted engine cache path from a weight-stripped engine cache path * diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h index 4d77fede..0c0d0f2a 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -34,3 +34,49 @@ class EPContextNodeHelper : public ApiPtrs { const OrtGraph* graph_ = nullptr; const OrtNode* fused_node_ = nullptr; }; 
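
[Editor's aside] Before the new EPContextNodeReader class is introduced below, it helps to restate the contract it consumes (a summary of this patch, not a specification): an EPContext node written by CreateEPContextNode carries an `embed_mode` int attribute (1 means the serialized engine bytes are stored inline in the `ep_cache_context` string attribute, 0 means `ep_cache_context` holds an engine-cache path expected to live next to the context model), plus `hardware_architecture` and `onnx_model_filename` string attributes. In the embedded case the reader essentially hands the bytes to the TensorRT runtime; the helper function name below is hypothetical, while deserializeCudaEngine is the standard TensorRT API used throughout this patch.

    // Minimal sketch of the embed_mode == 1 path, assuming a valid nvinfer1::IRuntime and
    // the raw bytes read from the "ep_cache_context" attribute.
    #include <memory>
    #include <string>
    #include "NvInfer.h"

    std::unique_ptr<nvinfer1::ICudaEngine> DeserializeEmbeddedEngine(nvinfer1::IRuntime& runtime,
                                                                     const std::string& engine_bytes) {
      // deserializeCudaEngine returns nullptr on failure; the caller maps that to an ORT_EP_FAIL status.
      return std::unique_ptr<nvinfer1::ICudaEngine>(
          runtime.deserializeCudaEngine(engine_bytes.data(), engine_bytes.size()));
    }

The non-embedded path differs only in that the bytes come from the engine-cache file resolved relative to the context model, with the path checks shown (commented out for now) in the hunk above.
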
+ +class EPContextNodeReader : public ApiPtrs { + public: + EPContextNodeReader(TensorrtExecutionProvider& ep, + std::unique_ptr* trt_engine, + nvinfer1::IRuntime* trt_runtime, + std::string ep_context_model_path, + std::string compute_capability, + bool weight_stripped_engine_refit, + std::string onnx_model_folder_path, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, + bool detailed_build_log) + : ApiPtrs{static_cast(ep)}, + trt_engine_(trt_engine), + trt_runtime_(trt_runtime), + ep_context_model_path_(ep_context_model_path), + compute_capability_(compute_capability), + weight_stripped_engine_refit_(weight_stripped_engine_refit), + onnx_model_folder_path_(onnx_model_folder_path), + onnx_model_bytestream_(onnx_model_bytestream), + onnx_model_bytestream_size_(onnx_model_bytestream_size), + onnx_external_data_bytestream_(onnx_external_data_bytestream), + onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size), + detailed_build_log_(detailed_build_log) { + } + + //bool ValidateEPCtxNode(const OrtGraph& graph); + + OrtStatus* GetEpContextFromGraph(const OrtGraph& graph); + + private: + std::unique_ptr* trt_engine_; + nvinfer1::IRuntime* trt_runtime_; + std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory + std::string compute_capability_; + bool weight_stripped_engine_refit_; + std::string onnx_model_folder_path_; + const void* onnx_model_bytestream_; + size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_; + size_t onnx_external_data_bytestream_size_; + bool detailed_build_log_; +}; // TRTCacheModelHandler diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 8f54c07f..5ba38893 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -721,13 +721,279 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect return nodes_list_output; } +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + TensorrtExecutionProvider* ep = static_cast(this_ptr); + const OrtApi& ort_api = ep->ort_api; + + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + // Get all the nodes from the graph + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; + bool new_subgraph = true; + + std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; + + // Get pre-excluded op list from provider options + auto get_exclude_ops_set = [&](std::string node_list_to_exclude) -> std::set { + std::set set; + if (!node_list_to_exclude.empty()) { + std::stringstream node_list(node_list_to_exclude); + std::string node; + while (std::getline(node_list, node, ',')) { + set.insert(node); + } + } + return set; + }; + + // auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); + auto exclude_ops_set = get_exclude_ops_set(""); + + /* Iterate all the nodes and exclude the node if: + * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. + * 2. Its op type is in the exclusion list. 
+ */ + for (size_t index = 0; index < nodes.size(); index++) { + const OrtNode* node = nodes[index]; + bool supported_node = true; + + /* If current node is control flow op, we take different approach based on following four cases: + * + * (1) control flow op is supported by TRT, and its subgraphs are all supported by TRT. Assign this node to TRT. + * (2) control flow op is supported by TRT, but not all its subgraphs supported by TRT. Don't assign this node to TRT. + * (3) control flow op is not supported by TRT, but its subgraphs all supported by TRT. Don't assign this node to TRT. + * (4) control flow op is not supported by TRT, and not all its subgraphs supported by TRT. Don't assign this node to TRT. + * + * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. + */ + const char* op_type = nullptr; + RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); + + if (control_flow_op_set.find(op_type) != control_flow_op_set.end()) { + auto supported_control_flow_op = [&](const OrtNode* node) { + OrtStatus* status = nullptr; + size_t num_subgraphs = 0; + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs)); + + std::vector node_subgraphs(num_subgraphs); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr)); + + + // Iterate the node's subgraphs + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* subgraph = node_subgraphs[subgraph_idx]; + + // Get number of subgraph's nodes + size_t num_subgraph_nodes = 0; + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); + + // TRT EP should consider the empty subgraph is fully supported by TRT. 
+ if (num_subgraph_nodes == 0) { + continue; + } + + /* + if (!ep->AllNodesAssignedToSpecificEP(*(subgraph->CreateGraphViewer()), kTensorrtExecutionProvider)) { + // if not all its subgraphs are supported, we need to exclude this control flow op + return false; + } + */ + } + return true; + }; + supported_node = supported_control_flow_op(node); + } + + // Exclude any ops, if applicable + if (exclude_ops_set.find(op_type) != exclude_ops_set.end()) { + supported_node = false; + } + + if (supported_node) { + if (new_subgraph) { + parser_nodes_vector.emplace_back(); + // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser + parser_nodes_vector.back().second = false; + new_subgraph = false; + } + parser_nodes_vector.back().first.emplace_back(index); + } else { + new_subgraph = true; + } + } + + + // Use this local definitions for now + // TODO: Use provider option + int max_partition_iterations = 1000; + int min_subgraph_size = 1; + + bool early_termination = false; + supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, max_partition_iterations, graph, &early_termination); + if (early_termination) { + supported_nodes_vector.clear(); + } + + // Temporarily make all nodes supported + supported_nodes_vector = parser_nodes_vector; + + // Remove subgraphs if its size is less than the predefined minimal size + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + const size_t subgraph_size = it->first.size(); + if (subgraph_size < min_subgraph_size) { + supported_nodes_vector.erase(it--); + } + } + + // Detect and remove cycles from supported node list + /* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */ + + // Consolidate supported node list + /* + if (supported_nodes_vector.size() > 1) { + nodes_vector.clear(); + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end()); + } + } + SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}}; + if (p->DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph"; + supported_nodes_vector = consolidated_supported_nodes_vector; + } + } + */ + + // Handle the case where the graph is subgraph of control flow op. + // The purpose is to make control flow op as well as its subgraphs run on TRT. + // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. + /* + if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { + bool all_subgraphs_are_supported = true; + + // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. + // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. 
+ const OrtNode* parent_node = nullptr; + graph_api_->OrtGraph_GetParenNode(graph, &parent_node); + const char* parent_node_op_type = nullptr; + graph_api_->OrtNode_GetOpType(parent_node, &parent_node_op_type); + if (strcmp(parent_node_op_type, "If") == 0) { + all_subgraphs_are_supported = false; + SubGraphCollection_t subgraph_supported_nodes_vector; + const OrtGraphViewer** subgraphs = nullptr; + size_t subgraph_count = 0; + graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count); + for (size_t i = 0; i < subgraph_count; i++) { + bool same_graph = false; + graph_api_->OrtGraph_IsSameGraph(graph, subgraphs[i], &same_graph); + if (same_graph) { + continue; + } + int number_of_ort_subgraph_nodes = 0; + graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes); + std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); + std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); + SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; + bool subgraph_early_termination = false; + + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. + if (number_of_ort_subgraph_nodes == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. + else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. + // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) + else if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "")) { + all_subgraphs_are_supported = false; + break; + } + + // Another subgraph of "If" control flow has not yet been parsed by GetCapability. + subgraph_supported_nodes_vector = p->GetSupportedList(parser_subgraph_nodes_vector, 0, p->max_partition_iterations_, subgraphs[i], &subgraph_early_termination); + all_subgraphs_are_supported = p->IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); + break; + } + graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); + } + + if (all_subgraphs_are_supported) { + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + for (const auto& index : group.first) { + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->node_index_len = 1; + sub_graph->node_index = new size_t[sub_graph->node_index_len]; + sub_graph->node_index[0] = nodes_index[index]; + cache.push_back(sub_graph.release()); + } + } + } + *cnt = cache.size(); + *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt]; + for (size_t i = 0; i < *cnt; i++) { + (*indexed_sub_graph)[i] = cache[i]; + } + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + return; + } + } + */ + + int number_of_trt_nodes = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + std::vector supported_nodes; + supported_nodes.reserve(group.first.size()); + + for (const auto& index : group.first) { + const OrtNode* supported_node = nodes[index]; + + supported_nodes.push_back(supported_node); + } + + // Create (optional) fusion options for the supported nodes to fuse. 
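
[Editor's aside] As a reading aid for the registration step the hunk continues with right after this point: once the supported groups are known, GetCapabilityImpl hands each group back to ORT through the OrtEpApi so that ORT fuses those nodes into a single unit for this EP. A compact, non-authoritative restatement of that call pattern, using only names from this patch (`ep`, `graph_support_info`, `supported_nodes` as a std::vector of const OrtNode* collected from one group):

    // Sketch: report one group of supported OrtNode* pointers back to ORT for fusion.
    OrtNodeFusionOptions fusion_options = {};
    fusion_options.ort_version_supported = ORT_API_VERSION;
    RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
                                                                 supported_nodes.data(),
                                                                 supported_nodes.size(),
                                                                 &fusion_options));

Each group registered this way later arrives at CompileImpl as one fused node, which is why the per-group bookkeeping above keeps node indices together.
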
+ OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), + supported_nodes.size(), &node_fusion_options)); + number_of_trt_nodes += static_cast(group.first.size()); + } + } + + const size_t number_of_subgraphs = supported_nodes_vector.size(); + if (number_of_trt_nodes == 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider"; + } else if (number_of_trt_nodes == nodes.size()) { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; + } + + return nullptr; +} + OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - /* out */OrtNodeComputeInfo** node_compute_info, - /* out */OrtNode** ep_context_node) { + /* out */ OrtNodeComputeInfo** node_compute_info, + /* out */ OrtNode** ep_context_node) { TensorrtExecutionProvider* ep = static_cast(this_ptr); /* @@ -807,7 +1073,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; layer->setPrecision(nvinfer1::DataType::kFLOAT); next_layer->setPrecision(nvinfer1::DataType::kFLOAT); layer->setOutputType(0, nvinfer1::DataType::kFLOAT); @@ -950,7 +1216,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this for (; it != end; ++it) { msg << "," << it->first; } - //return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str()); + // return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str()); } else { for (auto trt_profile : trt_profiles) { trt_config->addOptimizationProfile(trt_profile); @@ -974,8 +1240,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(pop) #endif fp16_enable_ = false; - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but " - // "platform doesn't support fast native fp16/bf16"; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but " + // "platform doesn't support fast native fp16/bf16"; } } @@ -989,8 +1255,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(pop) #endif int8_enable_ = false; - //LOGS_DEFAULT(WARNING) - // << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; + // LOGS_DEFAULT(WARNING) + // << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; } } @@ -1016,12 +1282,12 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (fp16_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); trt_node_name_with_precision += "_fp16"; - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] 
FP16 mode is enabled"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; } if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); trt_node_name_with_precision += "_int8"; - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; } #if defined(_MSC_VER) #pragma warning(pop) @@ -1031,16 +1297,16 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 int number_of_dla_core = trt_builder->getNbDLACores(); if (number_of_dla_core == 0) { - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; dla_enable_ = false; } else { if (dla_core_ >= number_of_dla_core) { - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ - // << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core - // << ". Use DLA core 0 instead."; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ + // << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core + // << ". Use DLA core 0 instead."; dla_core_ = 0; } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); trt_config->setDLACore(dla_core_); @@ -1052,7 +1318,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // enable sparse weights if (sparsity_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; } #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 if (build_heuristics_enable_) { @@ -1065,11 +1331,11 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 if (build_heuristics_enable_) { if (builder_optimization_level_ == 2) { - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " - // "level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " + // "level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."; } else { - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set " - // "builder optimization level as 2 to enable builder heuristics."; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. 
Please set " + // "builder optimization level as 2 to enable builder heuristics."; } } #endif @@ -1078,13 +1344,13 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // switch optimizaion level if (builder_optimization_level_ != 3) { trt_config->setBuilderOptimizationLevel(builder_optimization_level_); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; } // limit auxiliary streams if (auxiliary_streams_ >= 0) { trt_config->setMaxAuxStreams(auxiliary_streams_); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; } #else if (builder_optimization_level_ != 3) { @@ -1098,9 +1364,9 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (weight_stripped_engine_enable_) { #if NV_TENSORRT_MAJOR >= 10 trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; #else LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; #endif @@ -1111,7 +1377,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this nvinfer1::TacticSources tactics = trt_config->getTacticSources(); tactics |= GetTacticSourceFromString(tactic_sources_); trt_config->setTacticSources(tactics); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; } // Build TRT engine (if needed) and load TRT engine if: @@ -1139,7 +1405,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (engine_cache_enable_ && engine_hw_compatible_) { trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); cache_hw_compat = "_sm80+"; - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; } #endif @@ -1179,9 +1445,9 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); if (engine_update) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; } else { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; } } @@ -1194,7 +1460,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this engine_file.read((char*)engine_buf.get(), engine_size); trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + // LOGS_DEFAULT(VERBOSE) << 
"[TensorRT EP] DeSerialized " + engine_cache_path; if (trt_engine == nullptr) { std::string err_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); @@ -1216,7 +1482,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // Deserialize engine trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; if (trt_engine == nullptr) { std::string err_msg = "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); @@ -1250,7 +1516,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); if (detailed_build_log_) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; } } @@ -1275,16 +1541,16 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); - //LOGS_DEFAULT(INFO) - // << "TensorRT engine build for " << trt_node_name_with_precision << " took: " - // << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() - // << "ms" << std::endl; + // LOGS_DEFAULT(INFO) + // << "TensorRT engine build for " << trt_node_name_with_precision << " took: " + // << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() + // << "ms" << std::endl; } if (engine_cache_enable_) { // Serialize engine profile if it has explicit profiles if (has_explicit_profile) { SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; } if (engine_decryption_enable_) { @@ -1296,15 +1562,15 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this std::string err_msg = "TensorRT EP call to engine encryption library failed"; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; } else { - //LOGS_DEFAULT(WARNING) - // << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + // LOGS_DEFAULT(WARNING) + // << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; } } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; } } // serialize and save timing cache @@ -1317,434 +1583,309 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); if (detailed_build_log_) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; } } } } if (weight_stripped_engine_refit_) { - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build"; char* onnx = string_buf.data(); size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, onnx, onnx_size, trt_engine.get(), true /* serialize refitted engine to disk */, detailed_build_log_); if (status != nullptr) { - return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); - } - } - - // Build context - // Note: Creating an execution context from an engine is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_memory_sharing_enable_) { - // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. - max_ctx_mem_size_ = 0; - context_memory_ = nullptr; -#if NV_TENSORRT_MAJOR < 10 - trt_context = - std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#else - trt_context = std::unique_ptr( - trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif - } else { - trt_context = std::unique_ptr(trt_engine->createExecutionContext()); - } - if (!trt_context) { - std::string err_msg = "TensorRT EP could not build execution context for fused node: " + fused_node_name; - return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); - } - } - - // Create input to index map - // TRT network input -> ORT fused_node input index - for (int i = 0; i < num_inputs; ++i) { - auto input = trt_network->getInput(i); - const std::string& input_name = input->getName(); - const auto& iter = input_map.find(input_name); - if (iter != input_map.end()) { - input_indexes[input_name] = iter->second; - } - } - - // Create output to index and type maps - // TRT network output -> ORT fused_node output index - const auto& graph_output = model_proto.graph().output(); - for (int i = 0; i < num_outputs; ++i) { - const std::string& output_name = trt_network->getOutput(i)->getName(); - const auto& iter = output_map.find(output_name); - if (iter != output_map.end()) { - output_indexes[output_name] = iter->second; - } - const auto& tensor_type = graph_output[i].type().tensor_type(); - output_types[output_name] = tensor_type.elem_type(); - } - - // Save TRT engine, other TRT objects and input/output info to map - parsers_.emplace(fused_node_name, std::move(trt_parser)); - engines_.emplace(fused_node_name, std::move(trt_engine)); - contexts_.emplace(fused_node_name, std::move(trt_context)); - networks_.emplace(fused_node_name, 
std::move(trt_network)); - input_info_[fused_node_name].push_back(input_indexes); - output_info_[fused_node_name].push_back(output_indexes); - output_info_[fused_node_name].push_back(output_types); - input_shape_ranges_[fused_node_name] = input_implicit_shape_ranges; - profiles_.emplace(fused_node_name, std::move(trt_profiles)); - - - // Create EP Context nodes - std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, graph, fused_node); - if (dump_ep_context_model_) { - std::string compute_capability_hw_compat = compute_capability_; - if (engine_cache_enable_ && engine_hw_compatible_) { - compute_capability_hw_compat = "80+"; - } - - ep_ctx_node_helper->CreateEPContextNode(engine_cache_path, - reinterpret_cast(serialized_engine->data()), - serialized_engine->size(), - ep_context_embed_mode_, - compute_capability_hw_compat, - model_path_, - ep_context_node); - } - - std::unique_ptr compute_state = std::make_unique(); - - // translate tactic sources string to nvinfer1::TacticSources - nvinfer1::TacticSources tactics = 0; - if (!tactic_sources_.empty()) { - tactics = GetTacticSourceFromString(tactic_sources_); - } - *compute_state = { - static_cast(device_id_), - fused_node_name, - builder_.get(), - &parsers_[fused_node_name], - &engines_[fused_node_name], - &contexts_[fused_node_name], - &networks_[fused_node_name], - input_info_[fused_node_name], - output_info_[fused_node_name], - input_shape_ranges_[fused_node_name], - &tensorrt_mu_, - compute_capability_, - max_workspace_size_, - fp16_enable_, - int8_enable_, - int8_calibration_cache_available_, - dla_enable_, - dla_core_, - trt_node_name_with_precision, - engine_cache_enable_, - cache_path_, - runtime_.get(), - profiles_[fused_node_name], - context_memory_sharing_enable_, - &max_ctx_mem_size_, - &context_memory_, - dynamic_range_map, - engine_decryption_enable_, - engine_decryption_, - engine_encryption_, - timing_cache_enable_, - global_cache_path_, - force_timing_cache_match_, - detailed_build_log_, - build_heuristics_enable_, - sparsity_enable_, - builder_optimization_level_, - auxiliary_streams_, - !tactic_sources_.empty(), - tactics, - cuda_graph_enable_, - weight_stripped_engine_enable_, - weight_stripped_engine_refit_, - model_path_, - onnx_model_folder_path_, - onnx_model_bytestream_, - onnx_model_bytestream_size_, - cache_prefix_, - cache_suffix, - engine_hw_compatible_, - sync_stream_after_enqueue_}; - - ep->compute_states_[fused_node_name] = std::move(compute_state); - - // Update the OrtNodeComputeInfo associated with the graph. 
- auto ep_node_compute_info = std::make_unique(*ep); - *node_compute_info = ep_node_compute_info.release(); - - return nullptr; -} - - -OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept { - TensorrtExecutionProvider* ep = static_cast(this_ptr); - const OrtApi& ort_api = ep->ort_api; - - size_t num_nodes = 0; - RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); - - // Get all the nodes from the graph - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); - - SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; - bool new_subgraph = true; - - std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; - - // Get pre-excluded op list from provider options - auto get_exclude_ops_set = [&](std::string node_list_to_exclude) -> std::set { - std::set set; - if (!node_list_to_exclude.empty()) { - std::stringstream node_list(node_list_to_exclude); - std::string node; - while (std::getline(node_list, node, ',')) { - set.insert(node); - } - } - return set; - }; - - // auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); - auto exclude_ops_set = get_exclude_ops_set(""); - - /* Iterate all the nodes and exclude the node if: - * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. - * 2. Its op type is in the exclusion list. - */ - for (size_t index = 0; index < nodes.size(); index++) { - const OrtNode* node = nodes[index]; - bool supported_node = true; - - /* If current node is control flow op, we take different approach based on following four cases: - * - * (1) control flow op is supported by TRT, and its subgraphs are all supported by TRT. Assign this node to TRT. - * (2) control flow op is supported by TRT, but not all its subgraphs supported by TRT. Don't assign this node to TRT. - * (3) control flow op is not supported by TRT, but its subgraphs all supported by TRT. Don't assign this node to TRT. - * (4) control flow op is not supported by TRT, and not all its subgraphs supported by TRT. Don't assign this node to TRT. - * - * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. - */ - const char* op_type = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); - - if (control_flow_op_set.find(op_type) != control_flow_op_set.end()) { - auto supported_control_flow_op = [&](const OrtNode* node) { - OrtStatus* status = nullptr; - size_t num_subgraphs = 0; - RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs)); - - std::vector node_subgraphs(num_subgraphs); - RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr)); - - - // Iterate the node's subgraphs - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* subgraph = node_subgraphs[subgraph_idx]; - - // Get number of subgraph's nodes - size_t num_subgraph_nodes = 0; - RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); - - // TRT EP should consider the empty subgraph is fully supported by TRT. 
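
// A minimal sketch of the four-case rule described in the comment above, assuming the surrounding
// file's includes and two hypothetical helpers: OpIsTrtEligible() standing in for the TRT parser
// check, and AllSubgraphNodesAssignedToTrt() standing in for the AllNodesAssignedToSpecificEP()
// check referenced elsewhere in this provider.
static bool ShouldAssignControlFlowNodeToTrt(const OrtApi& ort_api, const OrtNode* node) {
  if (!OpIsTrtEligible(node)) {
    return false;  // cases (3) and (4): the control flow op itself is not supported
  }

  size_t num_subgraphs = 0;
  if (ort_api.Node_GetNumSubgraphs(node, &num_subgraphs) != nullptr) {
    return false;  // treat an API failure as "not supported"
  }

  std::vector<const OrtGraph*> subgraphs(num_subgraphs);
  if (ort_api.Node_GetSubgraphs(node, subgraphs.data(), subgraphs.size(), nullptr) != nullptr) {
    return false;
  }

  for (const OrtGraph* subgraph : subgraphs) {
    size_t num_subgraph_nodes = 0;
    if (ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes) != nullptr) {
      return false;
    }
    if (num_subgraph_nodes == 0) {
      continue;  // an empty subgraph body counts as fully supported
    }
    if (!AllSubgraphNodesAssignedToTrt(subgraph)) {
      return false;  // case (2): op is supported but a subgraph is only partially supported
    }
  }

  return true;  // case (1): the op and all of its subgraph bodies are TRT eligible
}
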
- if (num_subgraph_nodes == 0) { - continue; - } - - /* - if (!ep->AllNodesAssignedToSpecificEP(*(subgraph->CreateGraphViewer()), kTensorrtExecutionProvider)) { - // if not all its subgraphs are supported, we need to exclude this control flow op - return false; - } - */ - } - return true; - }; - supported_node = supported_control_flow_op(node); + return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); + } } - // Exclude any ops, if applicable - if (exclude_ops_set.find(op_type) != exclude_ops_set.end()) { - supported_node = false; + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { + // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. + max_ctx_mem_size_ = 0; + context_memory_ = nullptr; +#if NV_TENSORRT_MAJOR < 10 + trt_context = + std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr( + trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + std::string err_msg = "TensorRT EP could not build execution context for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } + } - if (supported_node) { - if (new_subgraph) { - parser_nodes_vector.emplace_back(); - // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser - parser_nodes_vector.back().second = false; - new_subgraph = false; - } - parser_nodes_vector.back().first.emplace_back(index); - } else { - new_subgraph = true; + // Create input to index map + // TRT network input -> ORT fused_node input index + for (int i = 0; i < num_inputs; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + const auto& iter = input_map.find(input_name); + if (iter != input_map.end()) { + input_indexes[input_name] = iter->second; } } + // Create output to index and type maps + // TRT network output -> ORT fused_node output index + const auto& graph_output = model_proto.graph().output(); + for (int i = 0; i < num_outputs; ++i) { + const std::string& output_name = trt_network->getOutput(i)->getName(); + const auto& iter = output_map.find(output_name); + if (iter != output_map.end()) { + output_indexes[output_name] = iter->second; + } + const auto& tensor_type = graph_output[i].type().tensor_type(); + output_types[output_name] = tensor_type.elem_type(); + } - // Use this local definitions for now - // TODO: Use provider option - int max_partition_iterations = 1000; - int min_subgraph_size = 1; + // Save TRT engine, other TRT objects and input/output info to map + parsers_.emplace(fused_node_name, std::move(trt_parser)); + engines_.emplace(fused_node_name, std::move(trt_engine)); + contexts_.emplace(fused_node_name, std::move(trt_context)); + networks_.emplace(fused_node_name, std::move(trt_network)); + input_info_[fused_node_name].push_back(input_indexes); + output_info_[fused_node_name].push_back(output_indexes); + output_info_[fused_node_name].push_back(output_types); + input_shape_ranges_[fused_node_name] = input_implicit_shape_ranges; + profiles_.emplace(fused_node_name, std::move(trt_profiles)); - bool early_termination = false; - supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, 
max_partition_iterations, graph, &early_termination); - if (early_termination) { - supported_nodes_vector.clear(); - } + // Create EP Context nodes + std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, graph, fused_node); + if (dump_ep_context_model_) { + std::string compute_capability_hw_compat = compute_capability_; + if (engine_cache_enable_ && engine_hw_compatible_) { + compute_capability_hw_compat = "80+"; + } - // Temporarily make all nodes supported - supported_nodes_vector = parser_nodes_vector; + char* serialized_engine_pointer = nullptr; + size_t serialized_engine_size = 0; - // Remove subgraphs if its size is less than the predefined minimal size - for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { - const size_t subgraph_size = it->first.size(); - if (subgraph_size < min_subgraph_size) { - supported_nodes_vector.erase(it--); + if (serialized_engine) { + serialized_engine_pointer = reinterpret_cast(serialized_engine->data()); + serialized_engine_size = serialized_engine->size(); + } else if (!serialized_engine && ep_context_embed_mode_ && engine_cache_enable_) { + serialized_engine = std::unique_ptr(trt_engine->serialize()); + serialized_engine_pointer = reinterpret_cast(serialized_engine->data()); + serialized_engine_size = serialized_engine->size(); } + + ep_ctx_node_helper->CreateEPContextNode(engine_cache_path, + serialized_engine_pointer, + serialized_engine_size, + ep_context_embed_mode_, + compute_capability_hw_compat, + model_path_, + ep_context_node); } - // Detect and remove cycles from supported node list - /* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */ + std::unique_ptr compute_state = std::make_unique(); - // Consolidate supported node list - /* - if (supported_nodes_vector.size() > 1) { - nodes_vector.clear(); - for (const auto& group : supported_nodes_vector) { - if (!group.first.empty()) { - nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end()); - } - } - SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}}; - if (p->DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; - } else { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph"; - supported_nodes_vector = consolidated_supported_nodes_vector; - } + // translate tactic sources string to nvinfer1::TacticSources + nvinfer1::TacticSources tactics = 0; + if (!tactic_sources_.empty()) { + tactics = GetTacticSourceFromString(tactic_sources_); } - */ + *compute_state = { + static_cast(device_id_), + fused_node_name, + builder_.get(), + &parsers_[fused_node_name], + &engines_[fused_node_name], + &contexts_[fused_node_name], + &networks_[fused_node_name], + input_info_[fused_node_name], + output_info_[fused_node_name], + input_shape_ranges_[fused_node_name], + &tensorrt_mu_, + compute_capability_, + max_workspace_size_, + fp16_enable_, + int8_enable_, + int8_calibration_cache_available_, + dla_enable_, + dla_core_, + trt_node_name_with_precision, + engine_cache_enable_, + cache_path_, + runtime_.get(), + profiles_[fused_node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &context_memory_, + dynamic_range_map, + engine_decryption_enable_, + engine_decryption_, + engine_encryption_, + timing_cache_enable_, + global_cache_path_, + force_timing_cache_match_, + 
detailed_build_log_, + build_heuristics_enable_, + sparsity_enable_, + builder_optimization_level_, + auxiliary_streams_, + !tactic_sources_.empty(), + tactics, + cuda_graph_enable_, + weight_stripped_engine_enable_, + weight_stripped_engine_refit_, + model_path_, + onnx_model_folder_path_, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + cache_prefix_, + cache_suffix, + engine_hw_compatible_, + sync_stream_after_enqueue_}; - // Handle the case where the graph is subgraph of control flow op. - // The purpose is to make control flow op as well as its subgraphs run on TRT. - // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. - /* - if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { - bool all_subgraphs_are_supported = true; + ep->compute_states_[fused_node_name] = std::move(compute_state); - // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. - // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. - const OrtNode* parent_node = nullptr; - graph_api_->OrtGraph_GetParenNode(graph, &parent_node); - const char* parent_node_op_type = nullptr; - graph_api_->OrtNode_GetOpType(parent_node, &parent_node_op_type); - if (strcmp(parent_node_op_type, "If") == 0) { - all_subgraphs_are_supported = false; - SubGraphCollection_t subgraph_supported_nodes_vector; - const OrtGraphViewer** subgraphs = nullptr; - size_t subgraph_count = 0; - graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count); - for (size_t i = 0; i < subgraph_count; i++) { - bool same_graph = false; - graph_api_->OrtGraph_IsSameGraph(graph, subgraphs[i], &same_graph); - if (same_graph) { - continue; - } - int number_of_ort_subgraph_nodes = 0; - graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes); - std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); - std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); - SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; - bool subgraph_early_termination = false; + // Update the OrtNodeComputeInfo associated with the graph. + auto ep_node_compute_info = std::make_unique(*ep); + *node_compute_info = ep_node_compute_info.release(); - // Another subgraph of "If" control flow op has no nodes. - // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. - if (number_of_ort_subgraph_nodes == 0) { - all_subgraphs_are_supported = true; - break; - } - // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. - else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) { - all_subgraphs_are_supported = true; - break; - } - // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. - // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) - else if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "")) { - all_subgraphs_are_supported = false; - break; - } + return nullptr; +} - // Another subgraph of "If" control flow has not yet been parsed by GetCapability. 
- subgraph_supported_nodes_vector = p->GetSupportedList(parser_subgraph_nodes_vector, 0, p->max_partition_iterations_, subgraphs[i], &subgraph_early_termination); - all_subgraphs_are_supported = p->IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); - break; - } - graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); - } +OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graph, + const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_info) { - if (all_subgraphs_are_supported) { - for (const auto& group : supported_nodes_vector) { - if (!group.first.empty()) { - for (const auto& index : group.first) { - std::unique_ptr sub_graph = std::make_unique(); - sub_graph->node_index_len = 1; - sub_graph->node_index = new size_t[sub_graph->node_index_len]; - sub_graph->node_index[0] = nodes_index[index]; - cache.push_back(sub_graph.release()); - } - } + TensorrtExecutionProvider* ep = static_cast(this_ptr); + + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &name)); + std::string fused_node_name = name; + + std::unique_ptr trt_engine; + std::unique_ptr trt_context; + std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index + std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index + std::unordered_map output_types; // TRT engine output name -> ORT output tensor type + + // Get engine binary data and deserialize it + std::unique_ptr ep_context_node_reader = std::make_unique(*ep, + &trt_engine, + runtime_.get(), + model_path_, + compute_capability_, + weight_stripped_engine_enable_, + onnx_model_folder_path_, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + detailed_build_log_); + RETURN_IF_ERROR(ep_context_node_reader->GetEpContextFromGraph(*graph)); + + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { + // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. 
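
// A minimal sketch of how this deferred allocation is satisfied later at compute time, assuming an
// AllocatorUniquePtr<void> (the template arguments are elided in the patch text) and this
// provider's MakeUniquePtrFromOrtAllocator helper; getDeviceMemorySize()/setDeviceMemory() are
// standard nvinfer1 APIs. One buffer, grown to the largest requirement seen so far, is reused by
// every execution context created with user-managed memory (the function name is illustrative).
static void AttachSharedContextMemory(nvinfer1::ICudaEngine& engine,
                                      nvinfer1::IExecutionContext& context,
                                      OrtAllocator* alloc,
                                      size_t& max_ctx_mem_size,
                                      AllocatorUniquePtr<void>& context_memory) {
  const size_t needed = engine.getDeviceMemorySize();  // scratch bytes this engine needs
  if (needed > max_ctx_mem_size) {
    max_ctx_mem_size = needed;
    context_memory = MakeUniquePtrFromOrtAllocator<void>(alloc, max_ctx_mem_size, true);
  }
  context.setDeviceMemory(context_memory.get());  // all sharing contexts reuse this one buffer
}
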
+ max_ctx_mem_size_ = 0; + context_memory_ = nullptr; +#if NV_TENSORRT_MAJOR < 10 + trt_context = + std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr( + trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + std::string err_msg = "TensorRT EP could not build execution context for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // Create input/output to index maps + // TRT engine input -> ORT fused_node input index + // TRT engine output -> ORT fused_node output index + for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + const auto& iter = input_map.find(name); + if (iter != input_map.end()) { + input_indexes[name] = iter->second; } - *cnt = cache.size(); - *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt]; - for (size_t i = 0; i < *cnt; i++) { - (*indexed_sub_graph)[i] = cache[i]; + } else { + const auto& iter = output_map.find(name); + if (iter != output_map.end()) { + output_indexes[name] = iter->second; } - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; - return; } } - */ - int number_of_trt_nodes = 0; - for (const auto& group : supported_nodes_vector) { - if (!group.first.empty()) { - std::vector supported_nodes; - supported_nodes.reserve(group.first.size()); + // Create output to type map + size_t num_graph_outputs = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(graph, &num_graph_outputs)); - for (const auto& index : group.first) { - const OrtNode* supported_node = nodes[index]; + std::vector graph_outputs(num_graph_outputs); + RETURN_IF_ERROR(ort_api.Graph_GetOutputs(graph, graph_outputs.data(), graph_outputs.size())); - supported_nodes.push_back(supported_node); - } + for (size_t i = 0; i < graph_outputs.size(); i++) { + const OrtValueInfo* value_info = graph_outputs[i]; - // Create (optional) fusion options for the supported nodes to fuse. 
- OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; + const char* value_info_name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_info_name)); - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), - supported_nodes.size(), &node_fusion_options)); - number_of_trt_nodes += static_cast(group.first.size()); - } - } + const OrtTypeInfo* type_info = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(value_info, &type_info)); - const size_t number_of_subgraphs = supported_nodes_vector.size(); - if (number_of_trt_nodes == 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider"; - } else if (number_of_trt_nodes == nodes.size()) { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; - } else { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; + const OrtTensorTypeAndShapeInfo* type_shape = nullptr; + RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(type_info, &type_shape)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + + output_types[value_info_name] = elem_type; } + // Save TRT engine, TRT context and input/output info to map + engines_.emplace(fused_node_name, std::move(trt_engine)); + contexts_.emplace(fused_node_name, std::move(trt_context)); + input_info_[fused_node_name].push_back(input_indexes); + output_info_[fused_node_name].push_back(output_indexes); + output_info_[fused_node_name].push_back(output_types); + + std::unique_ptr compute_state = std::make_unique(); + + *compute_state = { + static_cast(device_id_), + fused_node_name, + &engines_[fused_node_name], + &contexts_[fused_node_name], + input_info_[fused_node_name], + output_info_[fused_node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &context_memory_, + &tensorrt_mu_, + sync_stream_after_enqueue_}; + + ep->compute_states_for_ep_context_[fused_node_name] = std::move(compute_state); + + // Update the OrtNodeComputeInfo associated with the graph. 
+ auto ep_node_compute_info = std::make_unique(*ep); + *node_compute_info = ep_node_compute_info.release(); + return nullptr; + } OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, @@ -1801,9 +1942,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ OrtStatus* status; if (EPContextNodeHelper::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { - //RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, - // input_map, output_map, - // &node_compute_infos_result[fused_node_idx])); + RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, + input_map, output_map, + &node_compute_infos_result[fused_node_idx])); } else { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map, output_map, &node_compute_infos_result[fused_node_idx], @@ -2017,6 +2158,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa force_timing_cache_match_ = info_.force_timing_cache; detailed_build_log_ = info_.detailed_build_log; dump_ep_context_model_ = info_.dump_ep_context_model; + dump_ep_context_model_ = true; ep_context_file_path_ = info_.ep_context_file_path; ep_context_embed_mode_ = info_.ep_context_embed_mode; enable_engine_cache_for_ep_context_model(); @@ -2910,3 +3052,276 @@ void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* (void)trt_ep_compute_state; // Do nothing for here. } + +TRTEpEpContextNodeComputeInfo::TRTEpEpContextNodeComputeInfo(TensorrtExecutionProvider& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* TRTEpEpContextNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); + auto state_it = ep.compute_states_for_ep_context_.find(fused_node_name); + if (state_it == ep.compute_states_for_ep_context_.end()) { + std::string message = "Unable to TensorRT EP's compute state for fused node with name " + fused_node_name; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + TensorrtComputeStateForEPContext& trt_ep_compute_state = *state_it->second; + *compute_state = &trt_ep_compute_state; + return nullptr; +} + +OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + TensorrtComputeStateForEPContext* trt_state = reinterpret_cast(compute_state); + Ort::KernelContext ctx(kernel_context); + + // The whole compute_function should be considered the critical section. 
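
// BindContextInput/BindContextOutput used later in this function are defined elsewhere in this
// provider. A minimal sketch of the TensorRT side of that binding, assuming a fixed-shape input
// and device buffers already obtained from the ORT kernel context (the real helpers also handle
// shape tensors, dynamic shapes and DDS outputs; the function name is illustrative):
static bool BindFixedShapeIO(nvinfer1::IExecutionContext& context,
                             const char* input_name, void* input_data,
                             const nvinfer1::Dims& input_dims,
                             const char* output_name, void* output_data) {
  if (!context.setInputShape(input_name, input_dims)) {  // concrete dims for this run
    return false;
  }
  // Tell TRT which device addresses to read from and write to when enqueueV3() runs.
  return context.setTensorAddress(input_name, input_data) &&
         context.setTensorAddress(output_name, output_data);
}
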
+ // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + uint16_t device_id = trt_state->device_id; + auto fused_node_name = trt_state->fused_node_name; + std::unordered_map& dds_output_allocator_maps = ep.GetDDSOutputAllocators(); + auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + auto context_memory = trt_state->context_memory; + auto sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; + int num_outputs = static_cast(output_indexes.size()); + std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run + std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input + + // Get default OrtMemoryInfo from factory + // Get allocator from OrtKernelContext + const OrtMemoryInfo* mem_info = nullptr; + if (ep.factory_.device_id_to_cuda_gpu_memory_info_map.find(device_id) != + ep.factory_.device_id_to_cuda_gpu_memory_info_map.end()) { + mem_info = ep.factory_.device_id_to_cuda_gpu_memory_info_map[device_id]; + } + OrtAllocator* alloc = nullptr; + ep.GetAllocator(&alloc); + if (alloc == nullptr) { + Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, mem_info, &alloc)); + ep.SetAllocator(alloc); + } + + void* cuda_stream; + Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // cudaStream_t stream; + cudaStreamCreate(&stream); + + // Check before using trt_engine + if (trt_engine == nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "No engine is found."); + } + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, + shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextInput failed."); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using 
OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + auto status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, + output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextOutput failed."); + } + } + + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + *context_memory = MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true); + } + trt_context->setDeviceMemory((*context_memory).get()); + } + + /* + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_.SetStream(stream); + CaptureBegin(0); + } + */ + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + std::string err_msg = "TensorRT EP execution context enqueue failed."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this + * function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent + * issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per + * stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling + * InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each + * thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple + * streams which is not suggested. But, since the whole compute_func() is protected by the lock and if + * cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all + * operations to prevent the concurrent issue mentioned above. However, if cuda graph is enabled, TRT EP won't call + * cudaStreamSynchronize() since it's not allowed during graph capture. 
+ */ + if (sync_stream_after_enqueue) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + } + + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindKernelOutput failed."); + } + } else { + auto& output_tensor = output_tensors[i]; +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } +#endif + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } + } + } + + /* + // End CUDA graph capture. + // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream + // mentioned in graph capture above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and + // ExecuteGraph() per inference_session.cc. It's safe to start/end CUDA graph capture in compute_func() here since + // cuda graph object is maintained by a per thread basis. + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(0); + // CUDA work issued to a capturing stream doesn�t actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(ReplayGraph(0)); + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + */ + + return nullptr; + +} + +void TRTEpEpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + TensorrtComputeStateForEPContext& trt_ep_compute_state = *reinterpret_cast(compute_state); + (void)trt_ep_compute_state; + // Do nothing for here. 
+} diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 35999869..e573e576 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -212,7 +212,9 @@ struct TensorrtComputeStateForEPContext { std::vector> output_info; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; + AllocatorUniquePtr* context_memory = nullptr; std::mutex* tensorrt_mu_ptr = nullptr; + bool sync_stream_after_enqueue = true; }; using ShapeRangesMap = std::unordered_map>>>; @@ -345,6 +347,8 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_ = nullptr; + size_t onnx_external_data_bytestream_size_ = 0; bool build_heuristics_enable_ = false; bool sparsity_enable_ = false; int builder_optimization_level_ = 3; @@ -447,3 +451,15 @@ struct TRTEpNodeComputeInfo : OrtNodeComputeInfo { TensorrtExecutionProvider& ep; }; + +struct TRTEpEpContextNodeComputeInfo : OrtNodeComputeInfo { + explicit TRTEpEpContextNodeComputeInfo(TensorrtExecutionProvider& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + TensorrtExecutionProvider& ep; +}; From ccf20da1d76b6a18de73a58463a65af5f8718561 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 23 Jul 2025 10:44:23 -0700 Subject: [PATCH 39/60] update and sync with latest ep c api --- ...nsorrt_execution_provider_data_transfer.cc | 11 +++--- ...ensorrt_execution_provider_data_transfer.h | 15 ++++---- .../tensorrt/tensorrt_provider_factory.cc | 34 +++++++++---------- .../tensorrt/tensorrt_provider_factory.h | 12 +++---- 4 files changed, 37 insertions(+), 35 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc index b8c74511..b111f2e4 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -9,9 +9,10 @@ void CUDA_RETURN_IF_ERROR(cudaError_t res); /*static*/ -bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device, +bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, const OrtMemoryDevice* dst_memory_device) noexcept { - auto& impl = *static_cast(this_ptr); + auto& impl = *static_cast(this_ptr); auto it = std::find_if(impl.cuda_gpu_mem_devices_.begin(), impl.cuda_gpu_mem_devices_.end(), [&impl, &src_memory_device, &dst_memory_device](const OrtMemoryDevice* memory_device) { @@ -29,7 +30,7 @@ bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(void* this_ptr, const OrtMemory // function to copy one or more tensors. // implementation can optionally use async copy if a stream is available for the input. 
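
// A minimal sketch of the per-tensor copy that CopyTensorsImpl performs, assuming the source and
// destination device types were already read via MemoryDevice_GetDeviceType as in the hunk below
// (the real implementation also distinguishes pinned host memory; the function name is
// illustrative):
static cudaError_t CopyOneTensor(void* dst, const void* src, size_t bytes,
                                 bool src_on_gpu, bool dst_on_gpu, cudaStream_t stream) {
  cudaMemcpyKind kind = cudaMemcpyHostToHost;
  if (src_on_gpu && dst_on_gpu) {
    kind = cudaMemcpyDeviceToDevice;
  } else if (src_on_gpu) {
    kind = cudaMemcpyDeviceToHost;
  } else if (dst_on_gpu) {
    kind = cudaMemcpyHostToDevice;
  }
  if (stream != nullptr) {
    return cudaMemcpyAsync(dst, src, bytes, kind, stream);  // async copy on the caller's stream
  }
  return cudaMemcpy(dst, src, bytes, kind);  // no stream available: synchronous copy
}
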
/*static*/ -OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(void* this_ptr, +OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr, OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, @@ -97,10 +98,10 @@ OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(void* this_ptr, } /*static*/ -void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(void* this_ptr) noexcept { +void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept { // In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) // // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here - delete static_cast(this_ptr); + delete this_ptr; } diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h index 3dead944..f9d4cd87 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h @@ -4,27 +4,28 @@ #pragma once #include "ep_utils.h" +#include "onnxruntime_c_api.h" struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { - TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector device_mem_infos, - std::vector shared_mem_infos) + TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector& device_mem_infos, + std::vector& shared_mem_infos) : ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} { CanCopy = CanCopyImpl; CopyTensors = CopyTensorsImpl; Release = ReleaseImpl; } - static bool ORT_API_CALL CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src_memory_device, + static bool ORT_API_CALL CanCopyImpl(const OrtDataTransferImpl* this_ptr, const OrtMemoryDevice* src_memory_device, const OrtMemoryDevice* dst_memory_device) noexcept; // function to copy one or more tensors. // implementation can optionally use async copy if a stream is available for the input. 
- static OrtStatus* ORT_API_CALL CopyTensorsImpl(void* this_ptr, const OrtValue** src_tensors_ptr, + static OrtStatus* ORT_API_CALL CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr, OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, size_t num_tensors) noexcept; - static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; private: - std::vector cuda_gpu_mem_devices_; - std::vector cuda_pinned_mem_devices_; + std::vector& cuda_gpu_mem_devices_; + std::vector& cuda_pinned_mem_devices_; }; \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 4cfbb98b..9918ac3b 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -27,6 +27,8 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e ReleaseAllocator = ReleaseAllocatorImpl; CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; } const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { @@ -80,24 +82,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); + // Create two memory infos per device. + // The memory info is required to create allocator and gpu data transfer. int num_cuda_devices = 0; cudaGetDeviceCount(&num_cuda_devices); RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); - std::vector cuda_gpu_mem_devices; - std::vector cuda_pinned_mem_devices; int32_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { // C API const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - - // workaround for duplicate devices when using remote desktop. - if (device_id > 0) { - continue; - } + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { // These can be returned as nullptr if you have nothing to add. 
OrtKeyValuePairs* ep_metadata = nullptr; OrtKeyValuePairs* ep_options = nullptr; @@ -129,8 +126,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_mem_info)); // Get memory device from memory info for gpu data transfer - cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info)); - cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info)); + factory->cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info)); + factory->cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info)); ep_devices[num_ep_devices++] = ep_device; ++device_id; @@ -152,10 +149,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp // Create gpu data transfer auto data_transfer_impl = std::make_unique(static_cast(*factory), - cuda_gpu_mem_devices, // device memory - cuda_pinned_mem_devices // shared memory + factory->cuda_gpu_mem_devices, // device memory + factory->cuda_pinned_mem_devices // shared memory ); - factory->SetGPUDataTransfer(std::move(data_transfer_impl)); + + factory->data_transfer_impl = std::move(data_transfer_impl); + return nullptr; } @@ -244,13 +243,13 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept { auto& factory = *static_cast(this_ptr); - *data_transfer = factory.data_transfer_impl_.get(); + *data_transfer = factory.data_transfer_impl.get(); return nullptr; } -void TensorrtExecutionProviderFactory::SetGPUDataTransfer(std::unique_ptr gpu_data_transfer) { - data_transfer_impl_ = std::move(gpu_data_transfer); +bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; } // To make symbols visible on macOS/iOS @@ -265,6 +264,7 @@ extern "C" { // Public symbols // EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base, + const OrtLogger*, OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); const OrtEpApi* ort_ep_api = ort_api->GetEpApi(); @@ -285,7 +285,7 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const } EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { - delete factory; + delete static_cast(factory); return nullptr; } diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index c9a43931..96ee3ba6 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -12,10 +12,6 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { public: TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis); - const OrtMemoryInfo* GetDefaultGpuMemInfoForDeviceId(uint32_t device_id) const; - - const OrtMemoryInfo* GetHostAccessibleMemInfoForDeviceId(uint32_t device_id) const; - OrtStatus* CreateMemoryInfoForDevices(int num_devices); // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo @@ -25,6 +21,10 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { std::vector 
cuda_pinned_memory_infos; std::unordered_map device_id_to_cuda_gpu_memory_info_map; // device id -> OrtMemoryInfo + std::vector cuda_gpu_mem_devices; + std::vector cuda_pinned_mem_devices; + std::unique_ptr data_transfer_impl; // data transfer implementation for this factory + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -53,11 +53,11 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept; + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept; + void SetGPUDataTransfer(std::unique_ptr gpu_data_transfer); const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name const std::string ep_version_{"0.1.0"}; // EP version - - std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; \ No newline at end of file From cca956d9eb26b786cba128fa2a435029f26c4521 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 24 Jul 2025 08:13:08 -0700 Subject: [PATCH 40/60] remove delete resource in TRTEpDataTransfer::ReleaseImpl --- .../tensorrt/tensorrt_execution_provider_data_transfer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc index b111f2e4..44d996a6 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -103,5 +103,6 @@ void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) // // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here - delete this_ptr; + //delete static_cast(this_ptr); + ; } From 404cd4eb35a34ac0b47fff6975fc8b98558f4867 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 08:24:34 -0700 Subject: [PATCH 41/60] update cmake file to force dynamic release CRT globally for all dependencies if it's release build --- .../tensorrt/CMakeLists.txt | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index 1db2a761..4c7ab788 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -9,7 +9,7 @@ enable_language(CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT "/usr/local/cuda") find_package(CUDAToolkit REQUIRED) -# CMake config to force dynamic debug CRT globally for all dependencies. +# CMake config to force dynamic debug CRT or dynamic release CRT globally for all dependencies. 
# This is to address the issue of: # libprotobufd.lib(common.obj) : error LNK2038: mismatch detected for 'RuntimeLibrary': value 'MTd_StaticDebug' doesn't match value 'MDd_DynamicDebug' in unary_elementwise_ops_impl.obj if (WIN32) @@ -17,6 +17,11 @@ if (WIN32) set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebugDLL" CACHE STRING "" FORCE) # /MDd set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime endif() + + if(CMAKE_BUILD_TYPE STREQUAL "Release") + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDLL" CACHE STRING "" FORCE) + set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime + endif() endif() add_definitions(-DONNX_NAMESPACE=onnx) @@ -50,10 +55,8 @@ FetchContent_Declare( ) if (WIN32) - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - # Sometimes, protobuf ignores CMAKE_MSVC_RUNTIME_LIBRARY. To ensure it works: - set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "" FORCE) - endif() + # Sometimes, protobuf ignores CMAKE_MSVC_RUNTIME_LIBRARY. To ensure it works: + set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "" FORCE) endif() FetchContent_MakeAvailable(protobuf) @@ -67,10 +70,6 @@ FetchContent_Declare( FetchContent_MakeAvailable(onnx) -#set(ONNX_USE_LITE_PROTO OFF CACHE BOOL "" FORCE) -#set(ONNX_BUILD_TESTS OFF CACHE BOOL "" FORCE) -#set(ONNX_GEN_PB_TYPE_STUBS OFF CACHE BOOL "" FORCE) - # Add GSL FetchContent_Declare( gsl From c58130b8663b2b6965040fce6d2a289a4cd0c722 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 11 Aug 2025 13:52:48 -0700 Subject: [PATCH 42/60] use updated Value_GetMemoryDevice API --- .../tensorrt/tensorrt_execution_provider_data_transfer.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc index 44d996a6..6d177634 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -3,6 +3,7 @@ #include "tensorrt_execution_provider_data_transfer.h" +#include #include #include @@ -46,8 +47,8 @@ OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* const OrtMemoryDevice* src_device = nullptr; const OrtMemoryDevice* dst_device = nullptr; - RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensors[i], &src_device)); - RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensors[i], &dst_device)); + src_device = impl.ep_api.Value_GetMemoryDevice(src_tensors[i]); + dst_device = impl.ep_api.Value_GetMemoryDevice(dst_tensors[i]); OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); From 5828e10447061b016579db5ffaa7e444768c7e3f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 11 Aug 2025 13:53:22 -0700 Subject: [PATCH 43/60] update ort to graph util --- .../tensorrt/utils/ort_graph_to_proto.h | 188 ++++++++++++++++-- 1 file changed, 169 insertions(+), 19 deletions(-) diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h index 16b346c7..da63f632 100644 --- a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h +++ b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. +// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied +// to other projects. + /* SUMMARY: Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider @@ -75,6 +78,44 @@ // graph_proto stores large initializers in an external file } ``` + + EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec): + + This example stores initializers externally. However, instead of storing the initializers in a separate + file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's + location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the + initializer's data in memory (instead of an offset into a file). + + Because this is not standard ONNX, such a onnx::GraphProto should not be saved as an ONNX file. + However, it allows custom tools that operate directly on a onnx::GraphProto to get the initializer data + if it has already been loaded into memory. + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + auto handle_initializer_data = [](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + (void)value_info; + (void)bytes; + + offset = reinterpret_cast(data); + location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer. + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto has initializers that look like they are stored in an external file, + // but they are actually pointing to the data in memory. + } + ``` */ #ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ @@ -191,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -325,15 +366,20 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, for (const OrtOpAttr* ort_attr : ort_attrs) { OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; - if (!status.IsOK()) { - // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. + Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { + // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. // Can use Node_GetSubgraphs to get subgraphs. continue; } + if (!attr_type_status.IsOK()) { + // Unsupported attribute type. 
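
// A minimal sketch of how a same-process consumer of such a GraphProto can recover an initializer
// stored with the "_MEM_ADDR_" convention described above (the "offset" entry holds a raw pointer,
// so this only works while the producer's memory is still alive; TryGetInMemoryInitializer is a
// hypothetical helper name):
static const void* TryGetInMemoryInitializer(const onnx::TensorProto& tensor_proto) {
  if (tensor_proto.data_location() != onnx::TensorProto_DataLocation_EXTERNAL) {
    return nullptr;  // stored inline in raw_data
  }

  std::string location;
  long long offset = 0;
  for (const auto& entry : tensor_proto.external_data()) {
    if (entry.key() == "location") {
      location = entry.value();
    } else if (entry.key() == "offset") {
      offset = std::stoll(entry.value());
    }
  }

  if (location != "_MEM_ADDR_") {
    return nullptr;  // genuine external data: resolve against the file on disk instead
  }

  // The offset was written as the pointer value itself, so cast it back.
  return reinterpret_cast<const void*>(static_cast<uintptr_t>(offset));
}
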
+ return attr_type_status; + } + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); } } @@ -456,11 +502,14 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, auto* ext_data_entries = tensor_proto->mutable_external_data(); onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); location_entry->set_key("location"); location_entry->set_value(ext_location); offset_entry->set_key("offset"); offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); } else { // User wants to store data inline the TensorProto's raw_data tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); @@ -578,20 +627,24 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); - onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks + // like a scalar value. + if (!ort_dims.empty()) { + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); - for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { - onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); - if (ort_dims[dim_idx] >= 0) { - dim_proto->set_dim_value(ort_dims[dim_idx]); - } else { - const std::string& dim_param = ort_dim_syms[dim_idx]; + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; - // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, - // which represents an unknown dimension. - if (!dim_param.empty()) { - dim_proto->set_dim_param(dim_param); + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. + if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } } } } @@ -599,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -665,11 +718,11 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; std::string* str = attr_proto.mutable_s(); - str->resize(total_attr_bytes, '\0'); + str->resize(total_attr_bytes); ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, &total_attr_bytes)); - str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. 
+ str->resize(total_attr_bytes); break; } case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { @@ -705,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr break; } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + + Ort::Value tensor(ort_value); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + size_t element_size = 0; + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + element_size = sizeof(float); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + element_size = sizeof(uint8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + element_size = sizeof(int8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + element_size = sizeof(uint16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + element_size = sizeof(int16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + element_size = sizeof(int32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + element_size = sizeof(int64_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + element_size = sizeof(bool); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + element_size = sizeof(double); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + element_size = sizeof(uint32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + element_size = sizeof(uint64_t); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + auto shape = type_shape_info.GetShape(); + + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } + + size_t element_count = type_shape_info.GetElementCount(); + size_t data_bytes = element_count * element_size; + const void* data = tensor.GetTensorData(); + + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); + + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; + } default: { std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); return Ort::Status(err_msg.c_str(), ORT_FAIL); From 832a7f46e8b7756b746ca942c20ffa5b9f57fd71 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 11 Aug 2025 13:54:52 -0700 Subject: [PATCH 44/60] Add EP API Stream support --- 
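Note (placed after the "---" separator so it stays out of the commit message): this patch wires CUDA streams and events into the EP API stream hooks added here (CreateSyncStreamForDeviceImpl, TrtSyncStreamImpl, TrtSyncNotificationImpl) and flips IsStreamAwareImpl to true. For orientation only, the sketch below shows the plain CUDA event pattern that Activate/WaitOnDevice/WaitOnHost map onto in this patch; it is not part of the commit, and the producer/consumer/done names are illustrative.

```C++
#include <cuda_runtime.h>

int main() {
  cudaStream_t producer = nullptr, consumer = nullptr;
  cudaEvent_t done = nullptr;
  cudaStreamCreateWithFlags(&producer, cudaStreamNonBlocking);
  cudaStreamCreateWithFlags(&consumer, cudaStreamNonBlocking);
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);

  // ... enqueue work on `producer` (e.g. async copies or a TRT enqueue) ...

  cudaEventRecord(done, producer);         // Activate(): mark a point on the producing stream
  cudaStreamWaitEvent(consumer, done, 0);  // WaitOnDevice(): consumer stream orders after it without blocking the host
  cudaEventSynchronize(done);              // WaitOnHost(): host blocks until that point completes

  cudaEventDestroy(done);
  cudaStreamDestroy(consumer);
  cudaStreamDestroy(producer);
  return 0;
}
```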
.../tensorrt/cuda_allocator.h | 10 +- .../tensorrt/tensorrt_execution_provider.cc | 34 ++++- .../tensorrt/tensorrt_execution_provider.h | 11 +- ...sorrt_execution_provider_stream_support.cc | 119 ++++++++++++++++++ ...nsorrt_execution_provider_stream_support.h | 62 +++++++++ .../tensorrt/tensorrt_provider_factory.cc | 2 +- .../tensorrt/tensorrt_provider_factory.h | 2 - .../tensorrt/utils/cuda/cuda_call.h | 1 + .../tensorrt/utils/helper.ccc | 59 +++++++++ .../tensorrt/utils/status.ccc | 91 ++++++++++++++ 10 files changed, 377 insertions(+), 14 deletions(-) create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc create mode 100644 plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h create mode 100644 plugin_execution_providers/tensorrt/utils/helper.ccc create mode 100644 plugin_execution_providers/tensorrt/utils/status.ccc diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h index 44557cad..f62ecc89 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -10,11 +10,12 @@ using DeviceId = int16_t; struct CUDAAllocator : OrtAllocator { CUDAAllocator(const OrtMemoryInfo* mem_info, DeviceId device_id) : mem_info_(mem_info), device_id_(device_id) { OrtAllocator::version = ORT_API_VERSION; - OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { - return static_cast(this_)->Alloc(size); - }; + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + OrtAllocator::Reserve = nullptr; + OrtAllocator::GetStats = nullptr; + OrtAllocator::AllocOnStream = nullptr; // Allocate memory, handling usage across different Streams. Not used for TRT EP. 
} // TODO: Handle destructor //~CUDAAllocator(); @@ -41,6 +42,9 @@ struct CUDAPinnedAllocator : OrtAllocator { OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + OrtAllocator::Reserve = nullptr; + OrtAllocator::GetStats = nullptr; + OrtAllocator::AllocOnStream = nullptr; } // TODO: Handle destructor //~CUDAPinnedAllocator(); diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 5ba38893..e049ff2f 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -14,6 +14,7 @@ #include "tensorrt_execution_provider.h" #include "cuda_allocator.h" #include "onnx_ctx_model_helper.h" +#include "tensorrt_execution_provider_stream_support.h" #include "onnx/onnx_pb.h" #include "cuda/unary_elementwise_ops_impl.h" #include "ep_utils.h" @@ -1960,6 +1961,30 @@ const char* ORT_API_CALL TensorrtExecutionProvider::GetNameImpl(const OrtEp* thi return ep->name_.c_str(); } +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept { + // A per-session OrtSyncStreamImpl can be created here if the session options affect the implementation. + // Logging of any issues should use logger_ which is the session logger. + + TensorrtExecutionProvider* ep = static_cast(this_ptr); + + // we only create streams for the default device memory. + if (auto mem_type = ep->factory_.ep_api.MemoryDevice_GetMemoryType(memory_device); + mem_type != OrtDeviceMemoryType_DEFAULT) { + std::string error = "Invalid OrtMemoryDevice. Expected OrtDeviceMemoryType_DEFAULT(0). Got "; + error += std::to_string(mem_type); + return ep->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, error.c_str()); + } + + auto device_id = ep->factory_.ep_api.MemoryDevice_GetDeviceId(memory_device); + + auto sync_stream = std::make_unique(ep->factory_, ep, device_id, nullptr); + *stream = sync_stream.release(); + + return nullptr; +} + /** * Refit the weight-stripped engine */ @@ -2070,6 +2095,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa GetCapability = GetCapabilityImpl; Compile = CompileImpl; ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // Initialize the execution provider. 
auto status = ort_api.Logger_LogMessage(&logger_, @@ -2158,7 +2184,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa force_timing_cache_match_ = info_.force_timing_cache; detailed_build_log_ = info_.detailed_build_log; dump_ep_context_model_ = info_.dump_ep_context_model; - dump_ep_context_model_ = true; + //dump_ep_context_model_ = true; ep_context_file_path_ = info_.ep_context_file_path; ep_context_embed_mode_ = info_.ep_context_embed_mode; enable_engine_cache_for_ep_context_model(); @@ -2378,7 +2404,6 @@ void ORT_API_CALL TensorrtExecutionProvider::ReleaseNodeComputeInfosImpl(OrtEp* } } - // // Implementation of TRTEpNodeComputeInfo // @@ -2487,7 +2512,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* cudaStream_t stream = static_cast(cuda_stream); //cudaStream_t stream; - cudaStreamCreate(&stream); + //cudaStreamCreate(&stream); // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even @@ -3053,6 +3078,9 @@ void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* // Do nothing for here. } +// +// Implementation of TRTEpEpContextNodeComputeInfo +// TRTEpEpContextNodeComputeInfo::TRTEpEpContextNodeComputeInfo(TensorrtExecutionProvider& ep) : ep(ep) { ort_version_supported = ORT_API_VERSION; CreateState = CreateStateImpl; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index e573e576..7c8adca6 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -227,7 +227,7 @@ static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; /// /// -/// Plugin TensorRT EP +/// Plugin TensorRT EP implementing OrtEp. 
/// /// struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { @@ -311,6 +311,8 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::unordered_map trt_node_name_with_precision_; std::unordered_map> dynamic_range_map_; std::unordered_map cache_suffix_; + bool external_stream_ = false; + cudaStream_t stream_ = nullptr; private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; @@ -323,12 +325,11 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, size_t num_node_compute_infos) noexcept; - OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, - /*out*/ gsl::span ep_context_nodes); + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept; mutable TensorrtExecutionProviderInfo info_; - bool external_stream_ = false; - cudaStream_t stream_ = nullptr; int max_partition_iterations_ = 1000; size_t min_subgraph_size_ = 1; size_t max_workspace_size_ = 1 << 30; // 1GB diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc new file mode 100644 index 00000000..1fdd2e2e --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "tensorrt_execution_provider_stream_support.h" +#include "tensorrt_provider_factory.h" +#include "tensorrt_execution_provider.h" + +#include "cuda/cuda_common.h" +#include "cuda/cuda_call.h" + +// +// TrtSyncStreamImpl implementation +// + +TrtSyncStreamImpl::TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory, const OrtEp* ep, uint32_t device_id, const OrtKeyValuePairs* /*stream_options*/) + : ApiPtrs(factory), ep_{ep}, factory_{&factory} { + ort_version_supported = ORT_API_VERSION; + CreateNotification = CreateNotificationImpl; + GetHandle = GetHandleImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + + const TensorrtExecutionProvider* trt_ep = static_cast(ep_); + if (trt_ep->external_stream_) { + stream_ = trt_ep->stream_; + own_stream_ = false; + } else { + CUDA_CALL_THROW(cudaSetDevice(static_cast(device_id))); + cudaStream_t stream = nullptr; + CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + stream_ = stream; + own_stream_ = true; + } +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncStreamImpl::CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification) noexcept { + auto& impl = *static_cast(this_ptr); + + std::unique_ptr trt_sync_notification; + RETURN_IF_ERROR(TrtSyncNotificationImpl::Create(impl.stream_, impl, trt_sync_notification)); + + *notification = trt_sync_notification.release(); + return nullptr; +} + +/*static*/ +void* ORT_API_CALL TrtSyncStreamImpl::GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return static_cast(impl.stream_); +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncStreamImpl::FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + + // only flush when we own the stream, not external + if (impl.own_stream_) 
CUDA_CALL_THROW(cudaStreamSynchronize(static_cast(impl.stream_))); + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncStreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + return nullptr; +} + +// callback for EP library to release any internal state +/*static*/ +void ORT_API_CALL TrtSyncStreamImpl::ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +// +// Notification support +// + +/*static*/ +OrtStatus* TrtSyncNotificationImpl::Create(cudaStream_t stream, const ApiPtrs& apis, + std::unique_ptr& notification){ + auto trt_sync_notification = std::make_unique(stream, apis); + CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&trt_sync_notification->event_, cudaEventDisableTiming)); + + notification = std::move(trt_sync_notification); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventRecord(impl.event_, impl.stream_)); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept { + auto& impl = *static_cast(this_ptr); + void* handle = impl.ort_api.SyncStream_GetHandle(stream); + CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(static_cast(handle), impl.event_)); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(impl.event_)); + + return nullptr; +} + +/*static*/ +void ORT_API_CALL TrtSyncNotificationImpl::ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h new file mode 100644 index 00000000..34a72889 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "onnxruntime_c_api.h" +#include "tensorrt_provider_factory.h" +#include "ep_utils.h" + +#include + +// +// Class implementing Stream support for synchronization. +// +struct TrtSyncStreamImpl : public OrtSyncStreamImpl, public ApiPtrs { + TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory, + const OrtEp* ep, + uint32_t device_id, + const OrtKeyValuePairs* /*stream_options*/); + + private: + static OrtStatus* ORT_API_CALL CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** sync_notification) noexcept; + static void* ORT_API_CALL GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + + // EP instance if the stream is being created internally for inferencing. + // nullptr when the stream is created outside of an inference session for data copies. 
+ const OrtEp* ep_; + TensorrtExecutionProviderFactory* factory_{nullptr}; + + cudaStream_t stream_{nullptr}; + bool own_stream_{true}; +}; + +// +// Class implementing synchronization notification support. +// +struct TrtSyncNotificationImpl : public OrtSyncNotificationImpl, public ApiPtrs { + static OrtStatus* Create(cudaStream_t stream, const ApiPtrs& apis, + std::unique_ptr& notification); + + TrtSyncNotificationImpl(cudaStream_t stream, const ApiPtrs& apis) : stream_(stream), ApiPtrs(apis) { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + Release = ReleaseImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + } + + private: + static OrtStatus* ORT_API_CALL ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept; + static OrtStatus* ORT_API_CALL WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + + cudaStream_t& stream_; + cudaEvent_t event_; +}; diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 9918ac3b..15f26824 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -249,7 +249,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl } bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { - return false; + return true; } // To make symbols visible on macOS/iOS diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 96ee3ba6..a4d1efb4 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -55,8 +55,6 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept; - void SetGPUDataTransfer(std::unique_ptr gpu_data_transfer); - const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name const std::string ep_version_{"0.1.0"}; // EP version diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h index ada25ab7..d95e6a71 100644 --- a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h +++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h @@ -60,3 +60,4 @@ std::conditional_t CudaCall( //ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); #define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) +#define CUDA_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) diff --git a/plugin_execution_providers/tensorrt/utils/helper.ccc b/plugin_execution_providers/tensorrt/utils/helper.ccc new file mode 100644 index 00000000..7a889c30 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/helper.ccc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include "common.h" + +#ifdef _WIN32 +#include +#include +#endif + +namespace onnxruntime { +#ifdef _WIN32 +std::string ToUTF8String(const std::wstring& s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + ORT_THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr); + assert(len > 0); + std::string ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} + +std::wstring ToWideString(const std::string& s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + ORT_THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0); + assert(len > 0); + std::wstring ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} +#endif // #ifdef _WIN32 + +#ifdef ORT_NO_EXCEPTIONS +void PrintFinalMessage(const char* msg) { +#if defined(__ANDROID__) + __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg); +#else + // TODO, consider changing the output of the error message from std::cerr to logging when the + // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output + // might not be easily accessible on some systems such as mobile + // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS + std::cerr << msg << std::endl; +#endif +} +#endif // #ifdef ORT_NO_EXCEPTIONS + +} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/status.ccc b/plugin_execution_providers/tensorrt/utils/status.ccc new file mode 100644 index 00000000..b3a89c8c --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/status.ccc @@ -0,0 +1,91 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Modifications Copyright (c) Microsoft. 
+ +#include "status.h" +#include "common.h" + +namespace onnxruntime { +namespace common { +Status::Status(StatusCategory category, int code, const std::string& msg) { + // state_ will be allocated here causing the status to be treated as a failure + ORT_ENFORCE(code != static_cast(common::OK)); + + state_ = std::make_unique(category, code, msg); +} + +Status::Status(StatusCategory category, int code, const char* msg) { + // state_ will be allocated here causing the status to be treated as a failure + ORT_ENFORCE(code != static_cast(common::OK)); + + state_ = std::make_unique(category, code, msg); +} + +Status::Status(StatusCategory category, int code) + : Status(category, code, "") { +} + +StatusCategory Status::Category() const noexcept { + return IsOK() ? common::NONE : state_->category; +} + +int Status::Code() const noexcept { + return IsOK() ? static_cast(common::OK) : state_->code; +} + +const std::string& Status::ErrorMessage() const noexcept { + return IsOK() ? EmptyString() : state_->msg; +} + +std::string Status::ToString() const { + if (state_ == nullptr) { + return std::string("OK"); + } + + std::string result; + + if (common::SYSTEM == state_->category) { + result += "SystemError"; + result += " : "; + result += std::to_string(errno); + } else if (common::ONNXRUNTIME == state_->category) { + result += "[ONNXRuntimeEPError]"; + result += " : "; + result += std::to_string(Code()); + result += " : "; + result += StatusCodeToString(static_cast(Code())); + result += " : "; + result += state_->msg; + } + + return result; +} + +// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial +// and should not have any destruction order issues via pragmas instead. +// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 26426) +#endif + +const std::string& Status::EmptyString() noexcept { + static std::string s_empty; + return s_empty; +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +} // namespace common +} // namespace onnxruntime From edd4b344b0bb48e3421fe7b673482dca5285ef98 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 11 Aug 2025 13:59:39 -0700 Subject: [PATCH 45/60] Update CMakeLists.txt --- plugin_execution_providers/tensorrt/CMakeLists.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index 4c7ab788..289dcce9 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -5,8 +5,9 @@ cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) set(CMAKE_CXX_STANDARD 17) -enable_language(CUDA) -file(TO_CMAKE_PATH CUDAToolkit_ROOT "/usr/local/cuda") + +enable_language(CUDA) # via nvcc to get the CUDA tool kit +file(TO_CMAKE_PATH "/usr/local/cuda" CUDAToolkit_ROOT) find_package(CUDAToolkit REQUIRED) # CMake config to force dynamic debug CRT or dynamic release CRT globally for all dependencies. 
From 5f46b688b65ca89aaf2a08f0f3e1e8ed1560194e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 11 Aug 2025 21:42:57 -0700 Subject: [PATCH 46/60] fix mem leak for OrtAllocator --- .../tensorrt/tensorrt_execution_provider.cc | 42 ++++++++------- .../tensorrt/tensorrt_execution_provider.h | 12 ++--- .../tensorrt/tensorrt_provider_factory.cc | 54 ++++++++++++------- .../tensorrt/tensorrt_provider_factory.h | 10 ++-- 4 files changed, 68 insertions(+), 50 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index e049ff2f..386544ec 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -2071,11 +2071,15 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( #endif } -TensorrtExecutionProvider::~TensorrtExecutionProvider() = default; +TensorrtExecutionProvider::~TensorrtExecutionProvider() { + if (alloc_ != nullptr) { + ort_api.ReleaseAllocator(alloc_); + } +} /// /// -/// Plugin TensorRT EP that implements OrtEp +/// Plugin TensorRT EP implementing OrtEp /// /// TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, @@ -2494,18 +2498,17 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; // Get default OrtMemoryInfo from factory - // Get allocator from OrtKernelContext const OrtMemoryInfo* mem_info = nullptr; - if (ep.factory_.device_id_to_cuda_gpu_memory_info_map.find(device_id) != - ep.factory_.device_id_to_cuda_gpu_memory_info_map.end()) { - mem_info = ep.factory_.device_id_to_cuda_gpu_memory_info_map[device_id]; + if (ep.factory_.cuda_gpu_memory_infos.find(device_id) != + ep.factory_.cuda_gpu_memory_infos.end()) { + mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get(); } - OrtAllocator* alloc = nullptr; - ep.GetAllocator(&alloc); - if (alloc == nullptr) { - Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, mem_info, &alloc)); - ep.SetAllocator(alloc); + + // Get allocator from OrtKernelContext + if (ep.alloc_ == nullptr) { + Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, mem_info, &ep.alloc_)); } + OrtAllocator* alloc = ep.alloc_; void* cuda_stream; Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); @@ -3134,18 +3137,17 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_p std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input // Get default OrtMemoryInfo from factory - // Get allocator from OrtKernelContext const OrtMemoryInfo* mem_info = nullptr; - if (ep.factory_.device_id_to_cuda_gpu_memory_info_map.find(device_id) != - ep.factory_.device_id_to_cuda_gpu_memory_info_map.end()) { - mem_info = ep.factory_.device_id_to_cuda_gpu_memory_info_map[device_id]; + if (ep.factory_.cuda_gpu_memory_infos.find(device_id) != + ep.factory_.cuda_gpu_memory_infos.end()) { + mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get(); } - OrtAllocator* alloc = nullptr; - ep.GetAllocator(&alloc); - if (alloc == nullptr) { - Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, mem_info, &alloc)); - ep.SetAllocator(alloc); + + // Get allocator from OrtKernelContext + if (ep.alloc_ == nullptr) { + Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, 
mem_info, &ep.alloc_)); } + OrtAllocator* alloc = ep.alloc_; void* cuda_stream; Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 7c8adca6..406389c1 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -264,10 +264,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log); - - void GetAllocator(OrtAllocator** alloc) const { *alloc = alloc_; } - - void SetAllocator(OrtAllocator* alloc) { alloc_ = alloc; } std::unordered_map& GetDDSOutputAllocators() { return dds_output_allocator_maps_; @@ -314,6 +310,10 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { bool external_stream_ = false; cudaStream_t stream_ = nullptr; + // The OrtAllocator object will be get during ep compute time + // and should be kept for the lifetime of TRT EP object. + OrtAllocator* alloc_ = nullptr; + private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, @@ -375,10 +375,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::string cache_prefix_; bool engine_hw_compatible_ = false; - // The OrtAllocator object will be get during ep compute time - // and should be kept for the lifetime of TRT EP object. - OrtAllocator* alloc_ = nullptr; - // For create/dump EP context node model bool dump_ep_context_model_ = false; std::string ep_context_file_path_; diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 15f26824..43e69a56 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -57,7 +57,7 @@ OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_ /* device_id */ device_id, OrtDeviceMemoryType_DEFAULT, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); - cuda_gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + cuda_gpu_memory_infos[device_id] = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); // HOST_ACCESSIBLE memory should use the non-CPU device type mem_info = nullptr; @@ -66,7 +66,7 @@ OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_ /* device_id */ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); - cuda_pinned_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + cuda_pinned_memory_infos[device_id] = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); } return nullptr; @@ -196,35 +196,50 @@ void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* delete trt_ep; } -OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl( - OrtEpFactory* this_ptr, const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* /*allocator_options*/, - OrtAllocator** allocator) noexcept { +OrtStatus* ORT_API_CALL 
TensorrtExecutionProviderFactory::CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { auto& factory = *static_cast(this_ptr); - *allocator = nullptr; - // NOTE: The factory implementation can return a shared OrtAllocator* instead of creating a new instance on each call. - // To do this just make ReleaseAllocatorImpl a no-op. + // NOTE: The factory implementation is free to return a shared OrtAllocator* instance instead of creating a new + // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make + // ReleaseAllocatorImpl a no-op. - // NOTE: If OrtMemoryInfo has allocator type (call MemoryInfoGetType) of OrtArenaAllocator, an ORT BFCArena - // will be added to wrap the returned OrtAllocator. The EP is free to implement its own arena, and if it - // wants to do this the OrtMemoryInfo MUST be created with an allocator type of OrtDeviceAllocator. - - // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based - // matching should work. + // NOTE: EP should implement its own arena logic. ep_arena.cc/h is provided as a reference and we use it here for + // device memory. `allocator_options` can be used for arena configuration and there is a helper in ep_arena.h + // to convert from OrtKeyValuePairs to the same arena config settings that ORT uses. + // You are of course free to have completely different settings. const OrtMemoryDevice* mem_device = factory.ep_api.MemoryInfo_GetMemoryDevice(memory_info); uint32_t device_id = factory.ep_api.MemoryDevice_GetDeviceId(mem_device); if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_DEFAULT) { + // use the one that previously created + if (factory.cuda_gpu_allocators.find(device_id) != factory.cuda_gpu_allocators.end()) { + *allocator = factory.cuda_gpu_allocators[device_id].get(); + return nullptr; + } + // create a CUDA allocator auto cuda_allocator = std::make_unique(memory_info, static_cast(device_id)); - factory.device_id_to_cuda_gpu_memory_info_map[device_id] = memory_info; - *allocator = cuda_allocator.release(); + + *allocator = cuda_allocator.get(); + factory.cuda_gpu_allocators[device_id] = std::move(cuda_allocator); + } else if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // use the one that previously created + if (factory.cuda_pinned_allocators.find(device_id) != factory.cuda_pinned_allocators.end()) { + *allocator = factory.cuda_pinned_allocators[device_id].get(); + return nullptr; + } + // create a CUDA PINNED allocator auto cuda_pinned_allocator = std::make_unique(memory_info); - *allocator = cuda_pinned_allocator.release(); + + *allocator = cuda_pinned_allocator.get(); + factory.cuda_pinned_allocators[device_id] = std::move(cuda_pinned_allocator); + } else { return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. " @@ -236,7 +251,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl( void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { - delete static_cast(allocator); + // no-op. The allocators will be shared across sessions. 
+ // delete static_cast(allocator); } OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl( diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index a4d1efb4..1a3038fa 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -2,6 +2,7 @@ #include "ep_utils.h" #include "tensorrt_execution_provider_data_transfer.h" +#include "cuda_allocator.h" using MemoryInfoUniquePtr = std::unique_ptr>; @@ -17,9 +18,12 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo // instance required for that. // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device. - std::vector cuda_gpu_memory_infos; - std::vector cuda_pinned_memory_infos; - std::unordered_map device_id_to_cuda_gpu_memory_info_map; // device id -> OrtMemoryInfo + std::unordered_map cuda_gpu_memory_infos; // device id -> memory info + std::unordered_map cuda_pinned_memory_infos; + + // Keeps allocators per ep device in factory so they can be shared across sessions. + std::unordered_map> cuda_gpu_allocators; // device id -> allocator + std::unordered_map> cuda_pinned_allocators; std::vector cuda_gpu_mem_devices; std::vector cuda_pinned_mem_devices; From e81d3954c736c08bd4ebe7b8342c1e9f90b5fba9 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 19 Aug 2025 16:54:08 -0700 Subject: [PATCH 47/60] add missing header file --- .../tensorrt/ep_utils.h | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 plugin_execution_providers/tensorrt/ep_utils.h diff --git a/plugin_execution_providers/tensorrt/ep_utils.h b/plugin_execution_providers/tensorrt/ep_utils.h new file mode 100644 index 00000000..a2495e68 --- /dev/null +++ b/plugin_execution_providers/tensorrt/ep_utils.h @@ -0,0 +1,113 @@ +#pragma once + +#include "onnxruntime_cxx_api.h" + +#include "flatbuffers/idl.h" +#include "ort_trt_int8_cal_table.fbs.h" +#include "make_string.h" +// #include "core/providers/cuda/cuda_pch.h" +// #include "core/common/path_string.h" +// #include "core/framework/murmurhash3.h" + +//#include"nv_includes.h" +#include "gsl/narrow" + +#include +#include +#include +#include +#include +#include +#include + +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; + const OrtModelEditorApi& model_editor_api; +}; + +#define ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(MakeString(__VA_ARGS__)); \ + } \ + } while (false) + +#define THROW(...) \ + throw std::runtime_error(MakeString(__VA_ARGS__)); + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +/* +template +std::string ComposeString(Args&&... args) { + std::ostringstream oss; + (oss << ... << args); + return oss.str(); +}; +*/ + +#define RETURN_IF(cond, ...) \ + do { \ + if ((cond)) { \ + return Ort::GetApi().CreateStatus(ORT_EP_FAIL, MakeString(__VA_ARGS__).c_str()); \ + } \ + } while (0) + +#define RETURN_IF_NOT(condition, ...) 
RETURN_IF(!(condition), __VA_ARGS__) + +#define MAKE_STATUS(error_code, msg) \ + Ort::GetApi().CreateStatus(error_code, (msg)); + +#define THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if (_status != nullptr) { \ + std::ostringstream oss; \ + oss << Ort::GetApi().GetErrorMessage(_status); \ + Ort::GetApi().ReleaseStatus(_status); \ + throw std::runtime_error(oss.str()); \ + } \ + } while (0) + +#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \ + return false; \ + } \ + } while (0) + +// Helper to release Ort one or more objects obtained from the public C API at the end of their scope. +template +struct DeferOrtRelease { + DeferOrtRelease(T** object_ptr, std::function release_func) + : objects_(object_ptr), count_(1), release_func_(release_func) {} + + DeferOrtRelease(T** objects, size_t count, std::function release_func) + : objects_(objects), count_(count), release_func_(release_func) {} + + ~DeferOrtRelease() { + if (objects_ != nullptr && count_ > 0) { + for (size_t i = 0; i < count_; ++i) { + if (objects_[i] != nullptr) { + release_func_(objects_[i]); + objects_[i] = nullptr; + } + } + } + } + T** objects_ = nullptr; + size_t count_ = 0; + std::function release_func_ = nullptr; +}; + +template +using AllocatorUniquePtr = std::unique_ptr>; \ No newline at end of file From 1211cd647ea987ac8bd43246fc768d18afd7d184 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 19 Aug 2025 23:00:06 -0700 Subject: [PATCH 48/60] fix build issue on Linux --- plugin_execution_providers/tensorrt/CMakeLists.txt | 2 +- .../tensorrt/tensorrt_provider_factory.h | 2 +- plugin_execution_providers/tensorrt/{ => utils}/ep_utils.h | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename plugin_execution_providers/tensorrt/{ => utils}/ep_utils.h (98%) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index 289dcce9..ba3f759f 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -1,6 +1,6 @@ # usage: # cd build/ -# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 -DCMAKE_POSITION_INDEPENDENT_CODE=ON (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") # cmake --build ./ --config Debug cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 1a3038fa..1fd7176f 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -34,7 +34,7 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; - static const char* ORT_API_CALL 
TensorrtExecutionProviderFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, diff --git a/plugin_execution_providers/tensorrt/ep_utils.h b/plugin_execution_providers/tensorrt/utils/ep_utils.h similarity index 98% rename from plugin_execution_providers/tensorrt/ep_utils.h rename to plugin_execution_providers/tensorrt/utils/ep_utils.h index a2495e68..ded47204 100644 --- a/plugin_execution_providers/tensorrt/ep_utils.h +++ b/plugin_execution_providers/tensorrt/utils/ep_utils.h @@ -2,8 +2,8 @@ #include "onnxruntime_cxx_api.h" -#include "flatbuffers/idl.h" -#include "ort_trt_int8_cal_table.fbs.h" +//#include "flatbuffers/idl.h" +//#include "ort_trt_int8_cal_table.fbs.h" #include "make_string.h" // #include "core/providers/cuda/cuda_pch.h" // #include "core/common/path_string.h" From 0a8be0d8b0aa79ece2e878c1664973554fbf776a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 22 Aug 2025 14:40:50 -0700 Subject: [PATCH 49/60] lintrunner -a --- .../cuda/cu_inc/unary_elementwise_impl.cuh | 4 +- .../cuda/unary_elementwise_ops_impl.cu | 5 +- .../cuda/unary_elementwise_ops_impl.h | 4 +- .../tensorrt/cuda_allocator.h | 2 +- .../tensorrt/onnx_ctx_model_helper.cc | 9 +- .../tensorrt/onnx_ctx_model_helper.h | 9 +- .../tensorrt/tensorrt_execution_provider.cc | 76 +++--- .../tensorrt/tensorrt_execution_provider.h | 18 +- ...nsorrt_execution_provider_data_transfer.cc | 4 +- ...ensorrt_execution_provider_data_transfer.h | 2 +- .../tensorrt_execution_provider_info.cc | 222 +++++++++--------- .../tensorrt_execution_provider_info.h | 10 +- ...sorrt_execution_provider_stream_support.cc | 2 +- ...nsorrt_execution_provider_stream_support.h | 8 +- .../tensorrt_execution_provider_utils.h | 28 +-- .../tensorrt/tensorrt_provider_factory.cc | 66 +++--- .../tensorrt/tensorrt_provider_factory.h | 8 +- .../tensorrt/utils/common.h | 24 +- .../tensorrt/utils/cuda/cuda_call.h | 28 +-- .../tensorrt/utils/ep_utils.h | 26 +- .../tensorrt/utils/exceptions.h | 12 +- .../tensorrt/utils/path_string.h | 6 +- .../tensorrt/utils/provider_options_utils.h | 2 +- 23 files changed, 282 insertions(+), 293 deletions(-) diff --git a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh index 7b16741b..7d05c54b 100644 --- a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh +++ b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh @@ -35,7 +35,7 @@ __global__ void _UnaryElementWise( InT value[NumElementsPerThread]; CUDA_LONG id = start; - #pragma unroll +#pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { if (id < N) { value[i] = input_data[id]; @@ -44,7 +44,7 @@ __global__ void _UnaryElementWise( } id = start; - #pragma unroll +#pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { if (id < N) { output_data[id] = functor(value[i]); diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu index 0ceb9454..9d488752 100644 --- a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu @@ -9,7 +9,6 @@ #endif #include - namespace cuda { // the postfix of means the 
types supported by the op: @@ -72,7 +71,7 @@ struct OP_Cast { IMPL_CAST_IMPL(T, uint32_t) \ IMPL_CAST_IMPL(T, uint64_t) \ IMPL_CAST_IMPL(T, bool) \ - //IMPL_CAST_IMPL(T, BFloat16) + // IMPL_CAST_IMPL(T, BFloat16) IMPL_CAST_IMPL_FROM(half) IMPL_CAST_IMPL_FROM(float) @@ -86,6 +85,6 @@ IMPL_CAST_IMPL_FROM(uint16_t) IMPL_CAST_IMPL_FROM(uint32_t) IMPL_CAST_IMPL_FROM(uint64_t) IMPL_CAST_IMPL_FROM(bool) -//IMPL_CAST_IMPL_FROM(BFloat16) +// IMPL_CAST_IMPL_FROM(BFloat16) } // namespace cuda diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h index 184426a9..1bd241f7 100644 --- a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h @@ -27,7 +27,7 @@ namespace cuda { DECL_IMPL_CAST(T, uint32_t) \ DECL_IMPL_CAST(T, uint64_t) \ DECL_IMPL_CAST(T, bool) \ - //DECL_IMPL_CAST(T, BFloat16) + // DECL_IMPL_CAST(T, BFloat16) DECL_IMPL_CAST_FROM(half) DECL_IMPL_CAST_FROM(float) @@ -41,7 +41,7 @@ DECL_IMPL_CAST_FROM(uint16_t) DECL_IMPL_CAST_FROM(uint32_t) DECL_IMPL_CAST_FROM(uint64_t) DECL_IMPL_CAST_FROM(bool) -//DECL_IMPL_CAST_FROM(BFloat16) +// DECL_IMPL_CAST_FROM(BFloat16) template void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h index f62ecc89..7f765d50 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -15,7 +15,7 @@ struct CUDAAllocator : OrtAllocator { OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; OrtAllocator::Reserve = nullptr; OrtAllocator::GetStats = nullptr; - OrtAllocator::AllocOnStream = nullptr; // Allocate memory, handling usage across different Streams. Not used for TRT EP. + OrtAllocator::AllocOnStream = nullptr; // Allocate memory, handling usage across different Streams. Not used for TRT EP. } // TODO: Handle destructor //~CUDAAllocator(); diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index 87d8d1cd..76e5553d 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -46,7 +46,6 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca const std::string& compute_capability, const std::string& onnx_model_path, OrtNode** ep_context_node) { - // Helper to collect input or output names from an array of OrtValueInfo instances. 
auto collect_input_output_names = [&](gsl::span value_infos, std::vector& result) -> OrtStatus* { @@ -99,16 +98,14 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), engine_cache_path.size(), ORT_OP_ATTR_STRING, &attributes[1])); } - ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), compute_capability.size(), ORT_OP_ATTR_STRING, &attributes[2]); ort_api.CreateOpAttr("onnx_model_filename", std::filesystem::path(onnx_model_path).filename().string().c_str(), 1, ORT_OP_ATTR_STRING, &attributes[3]); - RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, input_names.data(), input_names.size(), output_names.data(), output_names.size(), attributes.data(), attributes.size(), ep_context_node)); - + return nullptr; } @@ -140,7 +137,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { const int64_t embed_mode = reinterpret_cast(node_attr)->i(); // Only make path checks if model not provided as byte buffer - //bool make_secure_path_checks = !GetModelPath(graph_viewer).empty(); + // bool make_secure_path_checks = !GetModelPath(graph_viewer).empty(); bool make_secure_path_checks = false; if (embed_mode) { @@ -151,7 +148,7 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); - //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; + // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; if (!(*trt_engine_)) { return ort_api.CreateStatus(ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); } diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h index 0c0d0f2a..75fa1b19 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -5,7 +5,7 @@ #include "tensorrt_execution_provider.h" #include "ep_utils.h" -#include "nv_includes.h" +// #include "nv_includes.h" #include #include @@ -21,14 +21,13 @@ class EPContextNodeHelper : public ApiPtrs { static bool GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api); - OrtStatus* CreateEPContextNode(const std::string& engine_cache_path, + OrtStatus* CreateEPContextNode(const std::string& engine_cache_path, char* engine_data, size_t size, const int64_t embed_mode, const std::string& compute_capability, const std::string& onnx_model_path, - OrtNode** ep_context_node - ); + OrtNode** ep_context_node); private: const OrtGraph* graph_ = nullptr; @@ -63,7 +62,7 @@ class EPContextNodeReader : public ApiPtrs { detailed_build_log_(detailed_build_log) { } - //bool ValidateEPCtxNode(const OrtGraph& graph); + // bool ValidateEPCtxNode(const OrtGraph& graph); OrtStatus* GetEpContextFromGraph(const OrtGraph& graph); diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 386544ec..84d6cd0c 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc 
@@ -723,7 +723,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept { + OrtEpGraphSupportInfo* graph_support_info) noexcept { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -784,7 +784,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this std::vector node_subgraphs(num_subgraphs); RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr)); - // Iterate the node's subgraphs for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { const OrtGraph* subgraph = node_subgraphs[subgraph_idx]; @@ -792,7 +791,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this // Get number of subgraph's nodes size_t num_subgraph_nodes = 0; RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); - + // TRT EP should consider the empty subgraph is fully supported by TRT. if (num_subgraph_nodes == 0) { continue; @@ -828,7 +827,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this } } - // Use this local definitions for now // TODO: Use provider option int max_partition_iterations = 1000; @@ -1763,7 +1761,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine std::unordered_map& input_map, std::unordered_map& output_map, OrtNodeComputeInfo** node_compute_info) { - TensorrtExecutionProvider* ep = static_cast(this_ptr); const char* name = nullptr; @@ -1867,17 +1864,17 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine std::unique_ptr compute_state = std::make_unique(); *compute_state = { - static_cast(device_id_), - fused_node_name, - &engines_[fused_node_name], - &contexts_[fused_node_name], - input_info_[fused_node_name], - output_info_[fused_node_name], - context_memory_sharing_enable_, - &max_ctx_mem_size_, - &context_memory_, - &tensorrt_mu_, - sync_stream_after_enqueue_}; + static_cast(device_id_), + fused_node_name, + &engines_[fused_node_name], + &contexts_[fused_node_name], + input_info_[fused_node_name], + output_info_[fused_node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &context_memory_, + &tensorrt_mu_, + sync_stream_after_enqueue_}; ep->compute_states_for_ep_context_[fused_node_name] = std::move(compute_state); @@ -1886,7 +1883,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine *node_compute_info = ep_node_compute_info.release(); return nullptr; - } OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, @@ -1895,10 +1891,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { - TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; - + gsl::span node_compute_infos_result(node_compute_infos, count); gsl::span ep_context_nodes_result(ep_context_nodes, count); @@ -1911,13 +1906,13 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ std::vector node_inputs(num_node_inputs); RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, node_inputs.data(), node_inputs.size())); - + // 
Builds map from input name to its index in input list std::unordered_map input_map; input_map.reserve(num_node_inputs); for (size_t i = 0; i < num_node_inputs; i++) { const OrtValueInfo* value_info = node_inputs[i]; - const char* name = nullptr; + const char* name = nullptr; RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name)); input_map.emplace(name, i); @@ -1940,7 +1935,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ output_map.emplace(name, i); } - + OrtStatus* status; if (EPContextNodeHelper::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, @@ -1952,7 +1947,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ &ep_context_nodes_result[fused_node_idx])); } } - + return nullptr; } @@ -2033,7 +2028,8 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( if (refit_from_file) { // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string(); if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { - std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " + std::string err_msg = + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " "weights contained in: " + onnx_model_path.string(); return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); @@ -2078,9 +2074,9 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } /// -/// +/// /// Plugin TensorRT EP implementing OrtEp -/// +/// /// TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, const std::string& name, @@ -2092,7 +2088,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa name_{name}, session_options_{session_options}, logger_{logger} { - // Implementation of OrtEp interfaces ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
GetName = GetNameImpl; @@ -2154,7 +2149,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options); info_.has_trt_options = true; device_id_ = info_.device_id; - //api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); + // api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -2188,7 +2183,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa force_timing_cache_match_ = info_.force_timing_cache; detailed_build_log_ = info_.detailed_build_log; dump_ep_context_model_ = info_.dump_ep_context_model; - //dump_ep_context_model_ = true; + // dump_ep_context_model_ = true; ep_context_file_path_ = info_.ep_context_file_path; ep_context_embed_mode_ = info_.ep_context_embed_mode; enable_engine_cache_for_ep_context_model(); @@ -2419,10 +2414,10 @@ TRTEpNodeComputeInfo::TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep) : ep(e } OrtStatus* TRTEpNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, - void** compute_state) { + void** compute_state) { auto* node_compute_info = static_cast(this_ptr); TensorrtExecutionProvider& ep = node_compute_info->ep; - + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); auto state_it = ep.compute_states_.find(fused_node_name); if (state_it == ep.compute_states_.end()) { @@ -2496,7 +2491,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::unordered_map& dds_output_allocator_maps = ep.GetDDSOutputAllocators(); auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; - + // Get default OrtMemoryInfo from factory const OrtMemoryInfo* mem_info = nullptr; if (ep.factory_.cuda_gpu_memory_infos.find(device_id) != @@ -2514,8 +2509,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); - //cudaStream_t stream; - //cudaStreamCreate(&stream); + // cudaStream_t stream; + // cudaStreamCreate(&stream); // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even @@ -2852,8 +2847,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* if (weight_stripped_engine_refit) { auto status = ep.RefitEngine(model_path, onnx_model_folder_path, engine_cache_path, false /* path check for security */, - onnx_model_bytestream, onnx_model_bytestream_size, trt_engine, - true /* serialize refitted engine to disk */, detailed_build_log); + onnx_model_bytestream, onnx_model_bytestream_size, trt_engine, + true /* serialize refitted engine to disk */, detailed_build_log); if (status != nullptr) { return ep.ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); } @@ -3062,7 +3057,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* if (cuda_graph_enable_ && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { CaptureEnd(0); - // CUDA work issued to a capturing stream doesn�t actually run 
on the GPU, + // CUDA work issued to a capturing stream doesn't actually run on the GPU, // so run the captured graph here to actually execute the work. ORT_RETURN_IF_ERROR(ReplayGraph(0)); } else { @@ -3092,7 +3087,7 @@ TRTEpEpContextNodeComputeInfo::TRTEpEpContextNodeComputeInfo(TensorrtExecutionPr } OrtStatus* TRTEpEpContextNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, - void** compute_state) { + void** compute_state) { auto* node_compute_info = static_cast(this_ptr); TensorrtExecutionProvider& ep = node_compute_info->ep; @@ -3109,7 +3104,7 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* th } OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, - OrtKernelContext* kernel_context) { + OrtKernelContext* kernel_context) { auto* node_compute_info = static_cast(this_ptr); TensorrtExecutionProvider& ep = node_compute_info->ep; @@ -3336,7 +3331,7 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_p if (cuda_graph_enable_ && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { CaptureEnd(0); - // CUDA work issued to a capturing stream doesn�t actually run on the GPU, + // CUDA work issued to a capturing stream doesn't actually run on the GPU, // so run the captured graph here to actually execute the work. ORT_RETURN_IF_ERROR(ReplayGraph(0)); } else { @@ -3346,7 +3341,6 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_p */ return nullptr; - } void TRTEpEpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 406389c1..fcd2b507 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -226,7 +226,7 @@ static const std::string k_ep_ctx_hardware_architecture = "hardware_architecture static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; /// -/// +/// /// Plugin TensorRT EP implementing OrtEp. 
/// /// @@ -260,11 +260,11 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { OrtNode** ep_context_node); OrtStatus* RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, - const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, - nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, - bool detailed_build_log); - + std::string& weight_stripped_engine_cath_path, bool path_check, + const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, + bool detailed_build_log); + std::unordered_map& GetDDSOutputAllocators() { return dds_output_allocator_maps_; } @@ -296,13 +296,13 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::unique_lock GetApiLock() const; /**Check the graph is the subgraph of control flow op*/ - //bool IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const; + // bool IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const; /**Check whether all the nodes of the graph are assigned to specific ep*/ - //bool AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const; + // bool AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const; /**Check whether all the nodes of subgraph are supported*/ - //bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; + // bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; std::unordered_map trt_node_name_with_precision_; std::unordered_map> dynamic_range_map_; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc index 6d177634..fe1bc675 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -90,7 +90,7 @@ OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); } else { // CPU -> CPU involves copy to/from pinned memory and a synchronize may be required first - //ORT_ENFORCE(dst_data != src_data); + // ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } } @@ -104,6 +104,6 @@ void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) // // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here - //delete static_cast(this_ptr); + // delete static_cast(this_ptr); ; } diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h index f9d4cd87..42c83007 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h @@ -8,7 +8,7 @@ struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector& device_mem_infos, - std::vector& shared_mem_infos) + std::vector& shared_mem_infos) : 
ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} { CanCopy = CanCopyImpl; CopyTensors = CopyTensorsImpl; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc index c27d8095..fece820a 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include //#incldue "core/providers/cuda/cuda_pch.h" +#include //#incldue "core/providers/cuda/cuda_pch.h" #include "tensorrt_execution_provider_info.h" #include "provider_options_utils.h" @@ -128,114 +128,114 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions return info; } -//ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtExecutionProviderInfo& info) { -// const ProviderOptions options{ -// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, -// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.max_partition_iterations)}, -// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, -// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, -// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)}, -// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, -// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, -// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, -// {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, -// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, -// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.dla_enable)}, -// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.dla_core)}, -// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, -// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, -// {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, -// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.weight_stripped_engine_enable)}, -// {tensorrt::provider_option_names::kOnnxModelFolderPath, MakeStringWithClassicLocale(info.onnx_model_folder_path)}, -// {tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)}, -// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, -// {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, -// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, -// // add new provider option here. 
-// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)}, -// {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)}, -// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)}, -// {tensorrt::provider_option_names::kTimingCachePath, MakeStringWithClassicLocale(info.timing_cache_path)}, -// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)}, -// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, -// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.build_heuristics_enable)}, -// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.sparsity_enable)}, -// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)}, -// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)}, -// {tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)}, -// {tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)}, -// {tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, -// {tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, -// {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, -// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, -// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)}, -// {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, -// {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, -// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, -// }; -// return options; -//} +// ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtExecutionProviderInfo& info) { +// const ProviderOptions options{ +// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, +// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.max_partition_iterations)}, +// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, +// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)}, +// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, +// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, +// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, +// {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, +// 
{tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, +// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.dla_enable)}, +// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.dla_core)}, +// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, +// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, +// {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, +// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.weight_stripped_engine_enable)}, +// {tensorrt::provider_option_names::kOnnxModelFolderPath, MakeStringWithClassicLocale(info.onnx_model_folder_path)}, +// {tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)}, +// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, +// {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, +// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, +// // add new provider option here. +// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)}, +// {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)}, +// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)}, +// {tensorrt::provider_option_names::kTimingCachePath, MakeStringWithClassicLocale(info.timing_cache_path)}, +// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)}, +// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, +// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.build_heuristics_enable)}, +// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.sparsity_enable)}, +// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)}, +// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)}, +// {tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)}, +// {tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)}, +// {tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, +// {tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, +// {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, +// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, +// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)}, +// {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, +// 
{tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, +// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, +// }; +// return options; +// } // -//ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensorRTProviderOptionsV2& info) { -// auto empty_if_null = [](const char* s) { return s != nullptr ? std::string{s} : std::string{}; }; -// const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); -// const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); -// const std::string kEngineCachePrefix_ = empty_if_null(info.trt_engine_cache_prefix); -// const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); -// const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); -// const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); -// const std::string kExtraPluginLibPaths_ = empty_if_null(info.trt_extra_plugin_lib_paths); -// const std::string kProfilesMinShapes_ = empty_if_null(info.trt_profile_min_shapes); -// const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); -// const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); -// const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); -// const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path); +// ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensorRTProviderOptionsV2& info) { +// auto empty_if_null = [](const char* s) { return s != nullptr ? std::string{s} : std::string{}; }; +// const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); +// const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); +// const std::string kEngineCachePrefix_ = empty_if_null(info.trt_engine_cache_prefix); +// const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); +// const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); +// const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); +// const std::string kExtraPluginLibPaths_ = empty_if_null(info.trt_extra_plugin_lib_paths); +// const std::string kProfilesMinShapes_ = empty_if_null(info.trt_profile_min_shapes); +// const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); +// const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); +// const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); +// const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path); // -// const ProviderOptions options{ -// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, -// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, -// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, -// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.trt_max_partition_iterations)}, -// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)}, -// {tensorrt::provider_option_names::kMaxWorkspaceSize, 
MakeStringWithClassicLocale(info.trt_max_workspace_size)}, -// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.trt_fp16_enable)}, -// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.trt_int8_enable)}, -// {tensorrt::provider_option_names::kInt8CalibTable, kInt8CalibTable_}, -// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.trt_int8_use_native_calibration_table)}, -// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.trt_dla_enable)}, -// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.trt_dla_core)}, -// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, -// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, -// {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, -// {tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_}, -// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.trt_weight_stripped_engine_enable)}, -// {tensorrt::provider_option_names::kOnnxModelFolderPath, kOnnxModelFolderPath_}, -// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, -// {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, -// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, -// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.trt_context_memory_sharing_enable)}, -// {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.trt_layer_norm_fp32_fallback)}, -// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.trt_timing_cache_enable)}, -// {tensorrt::provider_option_names::kTimingCachePath, kTimingCachePath_}, -// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.trt_force_timing_cache)}, -// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.trt_detailed_build_log)}, -// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.trt_build_heuristics_enable)}, -// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.trt_sparsity_enable)}, -// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.trt_builder_optimization_level)}, -// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.trt_auxiliary_streams)}, -// {tensorrt::provider_option_names::kTacticSources, kTacticSources_}, -// {tensorrt::provider_option_names::kExtraPluginLibPaths, kExtraPluginLibPaths_}, -// {tensorrt::provider_option_names::kProfilesMinShapes, kProfilesMinShapes_}, -// {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_}, -// {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_}, -// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, -// {tensorrt::provider_option_names::kEpContextFilePath, kEpContextFilePath_}, -// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)}, -// 
{tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)}, -// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)}, -// }; -// return options; -//} +// const ProviderOptions options{ +// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, +// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, +// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.trt_max_partition_iterations)}, +// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)}, +// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)}, +// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.trt_fp16_enable)}, +// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.trt_int8_enable)}, +// {tensorrt::provider_option_names::kInt8CalibTable, kInt8CalibTable_}, +// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.trt_int8_use_native_calibration_table)}, +// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.trt_dla_enable)}, +// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.trt_dla_core)}, +// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, +// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, +// {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, +// {tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_}, +// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.trt_weight_stripped_engine_enable)}, +// {tensorrt::provider_option_names::kOnnxModelFolderPath, kOnnxModelFolderPath_}, +// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, +// {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, +// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, +// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.trt_context_memory_sharing_enable)}, +// {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.trt_layer_norm_fp32_fallback)}, +// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.trt_timing_cache_enable)}, +// {tensorrt::provider_option_names::kTimingCachePath, kTimingCachePath_}, +// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.trt_force_timing_cache)}, +// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.trt_detailed_build_log)}, +// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.trt_build_heuristics_enable)}, +// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.trt_sparsity_enable)}, +// 
{tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.trt_builder_optimization_level)}, +// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.trt_auxiliary_streams)}, +// {tensorrt::provider_option_names::kTacticSources, kTacticSources_}, +// {tensorrt::provider_option_names::kExtraPluginLibPaths, kExtraPluginLibPaths_}, +// {tensorrt::provider_option_names::kProfilesMinShapes, kProfilesMinShapes_}, +// {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_}, +// {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_}, +// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, +// {tensorrt::provider_option_names::kEpContextFilePath, kEpContextFilePath_}, +// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)}, +// {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)}, +// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)}, +// }; +// return options; +// } // ///** // * Update OrtTensorRTProviderOptionsV2 instance with ProviderOptions (map of string-based key-value pairs) @@ -250,7 +250,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions // * // * Note: If there is strncpy involved, please remember to deallocate or simply call C API ReleaseTensorRTProviderOptions. // */ -//void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) { +// void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) { // if (provider_options == nullptr) { // return; // } @@ -262,11 +262,11 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions // return (const char*)nullptr; // } else { // dest = new char[str_size + 1]; -//#ifdef _MSC_VER +// #ifdef _MSC_VER // strncpy_s(dest, str_size + 1, s_in.c_str(), str_size); -//#else +// #else // strncpy(dest, s_in.c_str(), str_size); -//#endif +// #endif // dest[str_size] = '\0'; // return (const char*)dest; // } diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h index f2614721..e0596e42 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h @@ -55,9 +55,9 @@ struct TensorrtExecutionProviderInfo { bool engine_hw_compatible{false}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); -// static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); -// static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); -// static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy); -// -// std::vector custom_op_domain_list; + // static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); + // static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); + // static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy); + // + // std::vector custom_op_domain_list; }; diff --git 
a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc index 1fdd2e2e..59f3d1e1 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc @@ -78,7 +78,7 @@ void ORT_API_CALL TrtSyncStreamImpl::ReleaseImpl(_In_ OrtSyncStreamImpl* this_pt /*static*/ OrtStatus* TrtSyncNotificationImpl::Create(cudaStream_t stream, const ApiPtrs& apis, - std::unique_ptr& notification){ + std::unique_ptr& notification) { auto trt_sync_notification = std::make_unique(stream, apis); CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&trt_sync_notification->event_, cudaEventDisableTiming)); diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h index 34a72889..83a0fabc 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h @@ -13,10 +13,10 @@ // Class implementing Stream support for synchronization. // struct TrtSyncStreamImpl : public OrtSyncStreamImpl, public ApiPtrs { - TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory, - const OrtEp* ep, - uint32_t device_id, - const OrtKeyValuePairs* /*stream_options*/); + TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory, + const OrtEp* ep, + uint32_t device_id, + const OrtKeyValuePairs* /*stream_options*/); private: static OrtStatus* ORT_API_CALL CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h index 4685f386..84b35d6f 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -10,7 +10,7 @@ // #include "core/common/path_string.h" // #include "core/framework/murmurhash3.h" -#include"nv_includes.h" +#include "nv_includes.h" #include "gsl/narrow" #include @@ -36,7 +36,7 @@ bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignmen template AllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes, - bool use_reserve = false) { + bool use_reserve = false) { size_t alloc_size = count_or_bytes; // if T is not void, 'count_or_bytes' == number of items so allow for that if constexpr (!std::is_void::value) { @@ -92,7 +92,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_mapgetType() == nvinfer1::LayerType::kCONSTANT) { @@ -145,7 +145,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map= 10 default: - //LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name; + // LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name; return false; } max_weight = std::max(max_weight, std::abs(weight)); @@ -159,7 +159,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map(source); @@ -241,8 +241,8 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { inline std::vector loadTimingCacheFile(const std::string inFileName) { std::ifstream iFile(inFileName, std::ios::in | std::ios::binary); if (!iFile) { - 
//LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName - // << ". A new timing cache will be generated and written."; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName + // << ". A new timing cache will be generated and written."; return std::vector(); } iFile.seekg(0, std::ifstream::end); @@ -257,7 +257,7 @@ inline std::vector loadTimingCacheFile(const std::string inFileName) { inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::IHostMemory* blob) { std::ofstream oFile(outFileName, std::ios::out | std::ios::binary); if (!oFile) { - //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName; return; } oFile.write((char*)blob->data(), blob->size()); diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 43e69a56..7cda9126 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -73,12 +73,12 @@ OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_ } OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl( - OrtEpFactory* this_ptr, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept { + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); @@ -103,7 +103,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp // The ep options can be provided here as default values. // Users can also call SessionOptionsAppendExecutionProvider_V2 C API with provided ep options to override. - factory->ort_api.AddKeyValuePair(ep_metadata, "gpu_type", "data center"); // random example using made up values + factory->ort_api.AddKeyValuePair(ep_metadata, "gpu_type", "data center"); // random example using made up values factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3"); // OrtEpDevice copies ep_metadata and ep_options. @@ -132,26 +132,26 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp ep_devices[num_ep_devices++] = ep_device; ++device_id; } - - // C++ API equivalent. Throws on error. - //{ - // Ort::ConstHardwareDevice device(devices[i]); - // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // Ort::KeyValuePairs ep_metadata; - // Ort::KeyValuePairs ep_options; - // ep_metadata.Add("version", "0.1"); - // ep_options.Add("trt_builder_optimization_level", "3"); - // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; - // ep_devices[num_ep_devices++] = ep_device.release(); - // } - //} + + // C++ API equivalent. Throws on error. 
+ //{ + // Ort::ConstHardwareDevice device(devices[i]); + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // Ort::KeyValuePairs ep_metadata; + // Ort::KeyValuePairs ep_options; + // ep_metadata.Add("version", "0.1"); + // ep_options.Add("trt_builder_optimization_level", "3"); + // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; + // ep_devices[num_ep_devices++] = ep_device.release(); + // } + //} } - // Create gpu data transfer + // Create gpu data transfer auto data_transfer_impl = std::make_unique(static_cast(*factory), - factory->cuda_gpu_mem_devices, // device memory + factory->cuda_gpu_mem_devices, // device memory factory->cuda_pinned_mem_devices // shared memory - ); + ); factory->data_transfer_impl = std::move(data_transfer_impl); @@ -159,12 +159,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp } OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( - OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, - _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ OrtEp** ep) noexcept { + OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Out_ OrtEp** ep) noexcept { auto* factory = static_cast(this_ptr); *ep = nullptr; @@ -210,7 +210,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(Or // device memory. `allocator_options` can be used for arena configuration and there is a helper in ep_arena.h // to convert from OrtKeyValuePairs to the same arena config settings that ORT uses. // You are of course free to have completely different settings. 
- + const OrtMemoryDevice* mem_device = factory.ep_api.MemoryInfo_GetMemoryDevice(memory_info); uint32_t device_id = factory.ep_api.MemoryDevice_GetDeviceId(mem_device); @@ -226,7 +226,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(Or *allocator = cuda_allocator.get(); factory.cuda_gpu_allocators[device_id] = std::move(cuda_allocator); - + } else if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { // use the one that previously created if (factory.cuda_pinned_allocators.find(device_id) != factory.cuda_pinned_allocators.end()) { @@ -256,8 +256,8 @@ void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseAllocatorImpl(OrtEpFa } OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl( - OrtEpFactory* this_ptr, - OrtDataTransferImpl** data_transfer) noexcept { + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { auto& factory = *static_cast(this_ptr); *data_transfer = factory.data_transfer_impl.get(); diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 1fd7176f..df67ff1c 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -18,11 +18,11 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo // instance required for that. // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device. - std::unordered_map cuda_gpu_memory_infos; // device id -> memory info + std::unordered_map cuda_gpu_memory_infos; // device id -> memory info std::unordered_map cuda_pinned_memory_infos; // Keeps allocators per ep device in factory so they can be shared across sessions. 
- std::unordered_map> cuda_gpu_allocators; // device id -> allocator + std::unordered_map> cuda_gpu_allocators; // device id -> allocator std::unordered_map> cuda_pinned_allocators; std::vector cuda_gpu_mem_devices; @@ -59,7 +59,7 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept; - const std::string ep_name_; // EP name - const std::string vendor_{"Nvidia"}; // EP vendor name + const std::string ep_name_; // EP name + const std::string vendor_{"Nvidia"}; // EP vendor name const std::string ep_version_{"0.1.0"}; // EP version }; \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/utils/common.h b/plugin_execution_providers/tensorrt/utils/common.h index eaf000a5..72e9ffbf 100644 --- a/plugin_execution_providers/tensorrt/utils/common.h +++ b/plugin_execution_providers/tensorrt/utils/common.h @@ -117,20 +117,20 @@ namespace onnxruntime { ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \ ORT_DISALLOW_MOVE(TypeName) -#define ORT_RETURN_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - return _status; \ - } \ +#define ORT_RETURN_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + return _status; \ + } \ } while (0) -#define ORT_THROW_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - ORT_THROW(_status); \ - } \ +#define ORT_THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + ORT_THROW(_status); \ + } \ } while (0) // use this macro when cannot early return diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h index d95e6a71..9043eb1c 100644 --- a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h +++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h @@ -17,24 +17,24 @@ std::conditional_t CudaCall( ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { if (retCode != successCode) { try { -//#ifdef _WIN32 - //std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); - //if (hostname_str.empty()) { - //hostname_str = "?"; + // #ifdef _WIN32 + // std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); + // if (hostname_str.empty()) { + // hostname_str = "?"; //} - //const char* hostname = hostname_str.c_str(); -//#else - //char hostname[HOST_NAME_MAX]; - //if (gethostname(hostname, HOST_NAME_MAX) != 0) - //strcpy(hostname, "?"); -//#endif + // const char* hostname = hostname_str.c_str(); + // #else + // char hostname[HOST_NAME_MAX]; + // if (gethostname(hostname, HOST_NAME_MAX) != 0) + // strcpy(hostname, "?"); + // #endif int currentCudaDevice = -1; cudaGetDevice(¤tCudaDevice); cudaGetLastError(); // clear last CUDA error static char str[1024]; snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=? 
; file=%s ; line=%d ; expr=%s; %s", libName, (int)retCode, CudaErrString(retCode), currentCudaDevice, - //hostname, + // hostname, file, line, exprString, msg); if constexpr (THRW) { // throw an exception with the error info @@ -55,9 +55,9 @@ std::conditional_t CudaCall( } } -//template -//std::conditional_t CudaCall( - //ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); +// template +// std::conditional_t CudaCall( +// ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); #define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) #define CUDA_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) diff --git a/plugin_execution_providers/tensorrt/utils/ep_utils.h b/plugin_execution_providers/tensorrt/utils/ep_utils.h index ded47204..a1ba2bfe 100644 --- a/plugin_execution_providers/tensorrt/utils/ep_utils.h +++ b/plugin_execution_providers/tensorrt/utils/ep_utils.h @@ -2,14 +2,14 @@ #include "onnxruntime_cxx_api.h" -//#include "flatbuffers/idl.h" -//#include "ort_trt_int8_cal_table.fbs.h" +// #include "flatbuffers/idl.h" +// #include "ort_trt_int8_cal_table.fbs.h" #include "make_string.h" // #include "core/providers/cuda/cuda_pch.h" // #include "core/common/path_string.h" // #include "core/framework/murmurhash3.h" -//#include"nv_includes.h" +// #include"nv_includes.h" #include "gsl/narrow" #include @@ -26,11 +26,11 @@ struct ApiPtrs { const OrtModelEditorApi& model_editor_api; }; -#define ENFORCE(condition, ...) \ - do { \ - if (!(condition)) { \ - throw std::runtime_error(MakeString(__VA_ARGS__)); \ - } \ +#define ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(MakeString(__VA_ARGS__)); \ + } \ } while (false) #define THROW(...) \ @@ -53,11 +53,11 @@ std::string ComposeString(Args&&... args) { }; */ -#define RETURN_IF(cond, ...) \ - do { \ - if ((cond)) { \ +#define RETURN_IF(cond, ...) \ + do { \ + if ((cond)) { \ return Ort::GetApi().CreateStatus(ORT_EP_FAIL, MakeString(__VA_ARGS__).c_str()); \ - } \ + } \ } while (0) #define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__) @@ -83,7 +83,7 @@ std::string ComposeString(Args&&... args) { std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \ return false; \ } \ - } while (0) + } while (0) // Helper to release Ort one or more objects obtained from the public C API at the end of their scope. 
template diff --git a/plugin_execution_providers/tensorrt/utils/exceptions.h b/plugin_execution_providers/tensorrt/utils/exceptions.h index 19c1586a..4f166da3 100644 --- a/plugin_execution_providers/tensorrt/utils/exceptions.h +++ b/plugin_execution_providers/tensorrt/utils/exceptions.h @@ -11,24 +11,24 @@ #include #include "common.h" -//#include "code_location.h" +// #include "code_location.h" namespace onnxruntime { class NotImplementedException : public std::logic_error { public: - explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; - explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; + explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; + explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; }; class TypeMismatchException : public std::logic_error { public: - TypeMismatchException() noexcept : logic_error("Type mismatch"){}; + TypeMismatchException() noexcept : logic_error("Type mismatch") {}; }; class OnnxRuntimeException : public std::exception { public: - // code location is not provided for now + // code location is not provided for now /* OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept : OnnxRuntimeException(location, nullptr, msg) { @@ -83,7 +83,7 @@ class OnnxRuntimeException : public std::exception { } private: - //const CodeLocation location_; + // const CodeLocation location_; const std::vector stacktrace_; std::string what_; }; diff --git a/plugin_execution_providers/tensorrt/utils/path_string.h b/plugin_execution_providers/tensorrt/utils/path_string.h index fd638aa5..3856c7c2 100644 --- a/plugin_execution_providers/tensorrt/utils/path_string.h +++ b/plugin_execution_providers/tensorrt/utils/path_string.h @@ -22,10 +22,10 @@ #define ORT_TSTR_CONVERT_FROM_STRING(X) X #endif -//#include "core/common/common.h" -//#include "core/session/onnxruntime_c_api.h" +// #include "core/common/common.h" +// #include "core/session/onnxruntime_c_api.h" -//#include "common.h" +// #include "common.h" namespace onnxruntime { // char type for filesystem paths diff --git a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h index 9a02d272..4bf6b37c 100644 --- a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h +++ b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h @@ -150,7 +150,7 @@ class ProviderOptionsParser { RETURN_IF_NOT( (parse_status == nullptr), "Failed to parse provider option \"", name, "\": "); - //"Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); + //"Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); } return nullptr; From e4c24058b32af8a4d4cc94638500186b0bd46d29 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 29 Aug 2025 10:56:19 -0700 Subject: [PATCH 50/60] Update to use new API OpAttr_GetTensorAttributeAsOrtValue --- .../tensorrt/utils/ort_graph_to_proto.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h index da63f632..28ce4439 100644 --- 
a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h +++ b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h @@ -122,7 +122,7 @@ #define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ #include -#include "onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_cxx_api.h" #include "onnx/onnx_pb.h" namespace OrtEpUtils { @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -766,7 +766,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // TensorProto as an attribute value doesn't require a name. OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); Ort::Value tensor(ort_value); From 2472a158c2df450ba2e350e8d66a8f0cf067f436 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 8 Sep 2025 13:45:34 -0700 Subject: [PATCH 51/60] remove unnecessary files --- .../tensorrt/utils/helper.ccc | 59 ------------ .../tensorrt/utils/status.ccc | 91 ------------------- 2 files changed, 150 deletions(-) delete mode 100644 plugin_execution_providers/tensorrt/utils/helper.ccc delete mode 100644 plugin_execution_providers/tensorrt/utils/status.ccc diff --git a/plugin_execution_providers/tensorrt/utils/helper.ccc b/plugin_execution_providers/tensorrt/utils/helper.ccc deleted file mode 100644 index 7a889c30..00000000 --- a/plugin_execution_providers/tensorrt/utils/helper.ccc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include "common.h" - -#ifdef _WIN32 -#include -#include -#endif - -namespace onnxruntime { -#ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { - if (s.size() >= static_cast(std::numeric_limits::max())) - ORT_THROW("length overflow"); - - const int src_len = static_cast(s.size() + 1); - const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr); - assert(len > 0); - std::string ret(static_cast(len) - 1, '\0'); -#pragma warning(disable : 4189) - const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr); - assert(len == r); -#pragma warning(default : 4189) - return ret; -} - -std::wstring ToWideString(const std::string& s) { - if (s.size() >= static_cast(std::numeric_limits::max())) - ORT_THROW("length overflow"); - - const int src_len = static_cast(s.size() + 1); - const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0); - assert(len > 0); - std::wstring ret(static_cast(len) - 1, '\0'); -#pragma warning(disable : 4189) - const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len); - assert(len == r); -#pragma warning(default : 4189) - return ret; -} -#endif // #ifdef _WIN32 - -#ifdef ORT_NO_EXCEPTIONS -void PrintFinalMessage(const char* msg) { -#if defined(__ANDROID__) - __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg); -#else - // TODO, consider changing the output of the error message from std::cerr to logging when the - // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output - // might not be easily accessible on some systems such as mobile - // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS - std::cerr << msg << std::endl; -#endif -} -#endif // #ifdef ORT_NO_EXCEPTIONS - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/status.ccc b/plugin_execution_providers/tensorrt/utils/status.ccc deleted file mode 100644 index b3a89c8c..00000000 --- a/plugin_execution_providers/tensorrt/utils/status.ccc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Modifications Copyright (c) Microsoft. 
- -#include "status.h" -#include "common.h" - -namespace onnxruntime { -namespace common { -Status::Status(StatusCategory category, int code, const std::string& msg) { - // state_ will be allocated here causing the status to be treated as a failure - ORT_ENFORCE(code != static_cast(common::OK)); - - state_ = std::make_unique(category, code, msg); -} - -Status::Status(StatusCategory category, int code, const char* msg) { - // state_ will be allocated here causing the status to be treated as a failure - ORT_ENFORCE(code != static_cast(common::OK)); - - state_ = std::make_unique(category, code, msg); -} - -Status::Status(StatusCategory category, int code) - : Status(category, code, "") { -} - -StatusCategory Status::Category() const noexcept { - return IsOK() ? common::NONE : state_->category; -} - -int Status::Code() const noexcept { - return IsOK() ? static_cast(common::OK) : state_->code; -} - -const std::string& Status::ErrorMessage() const noexcept { - return IsOK() ? EmptyString() : state_->msg; -} - -std::string Status::ToString() const { - if (state_ == nullptr) { - return std::string("OK"); - } - - std::string result; - - if (common::SYSTEM == state_->category) { - result += "SystemError"; - result += " : "; - result += std::to_string(errno); - } else if (common::ONNXRUNTIME == state_->category) { - result += "[ONNXRuntimeEPError]"; - result += " : "; - result += std::to_string(Code()); - result += " : "; - result += StatusCodeToString(static_cast(Code())); - result += " : "; - result += state_->msg; - } - - return result; -} - -// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial -// and should not have any destruction order issues via pragmas instead. -// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 26426) -#endif - -const std::string& Status::EmptyString() noexcept { - static std::string s_empty; - return s_empty; -} - -#ifdef _MSC_VER -#pragma warning(pop) -#endif - -} // namespace common -} // namespace onnxruntime From ab8cd70b306d8784ba12921488a980b989924e56 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 10 Sep 2025 10:15:09 -0700 Subject: [PATCH 52/60] Add default logger for TRT logger --- .../tensorrt/onnx_ctx_model_helper.cc | 3 ++- .../tensorrt/tensorrt_execution_provider.cc | 12 +++++----- .../tensorrt/tensorrt_execution_provider.h | 22 ++++++++++++++----- .../tensorrt/tensorrt_provider_factory.cc | 8 +++---- .../tensorrt/tensorrt_provider_factory.h | 3 ++- 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index 76e5553d..4e803383 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -9,7 +9,8 @@ #include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" -extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); +extern TensorrtLogger& GetTensorrtLogger(bool verbose_log, const OrtLogger& ort_default_logger, + const OrtApi* ort_api); /* * Check whether the graph has the EP context node. 
diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 84d6cd0c..cfba6341 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -84,9 +84,11 @@ void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims con } } -TensorrtLogger& GetTensorrtLogger(bool verbose_log) { +TensorrtLogger& GetTensorrtLogger(bool verbose_log, + const OrtLogger& ort_default_logger, + const OrtApi* ort_api) { const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING; - static TensorrtLogger trt_logger(log_level); + static TensorrtLogger trt_logger(ort_default_logger, ort_api, log_level); if (log_level != trt_logger.get_level()) { trt_logger.set_level(verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING); } @@ -1041,7 +1043,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this model_proto.SerializeToOstream(&dump); } - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_, logger_, &ort_api); auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 @@ -2021,7 +2023,7 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( } // weight-stripped engine refit logic - TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log); + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log, logger_, &ort_api); auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); auto parser_refitter = std::unique_ptr(nvonnxparser::createParserRefitter(*refitter, trt_logger)); @@ -2378,7 +2380,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa { auto lock = GetApiLock(); - runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); + runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_, logger_, &ort_api))); } // EP Context setting diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index fcd2b507..a9757260 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -66,10 +66,14 @@ using DestroyFunc = void (*)(void*, void*); class TensorrtLogger : public nvinfer1::ILogger { nvinfer1::ILogger::Severity verbosity_; + const OrtLogger& ort_default_logger_; + const OrtApi* ort_api_ = nullptr; public: - TensorrtLogger(Severity verbosity = Severity::kWARNING) - : verbosity_(verbosity) {} + TensorrtLogger(const OrtLogger& ort_default_logger, + const OrtApi* ort_api, + Severity verbosity = Severity::kWARNING) + : ort_default_logger_{ort_default_logger}, ort_api_{ort_api}, verbosity_(verbosity) {} void log(Severity severity, const char* msg) noexcept override { if (severity <= verbosity_) { time_t rawtime = std::time(0); @@ -87,11 +91,19 @@ class TensorrtLogger : public nvinfer1::ILogger { : severity == Severity::kWARNING ? "WARNING" : severity == Severity::kINFO ? 
" INFO" : "UNKNOWN"); + OrtLoggingLevel ort_severity; if (severity <= Severity::kERROR) { - // LOGS_DEFAULT(ERROR) << "[" << buf << " " << sevstr << "] " << msg; - } else { - // LOGS_DEFAULT(WARNING) << "[" << buf << " " << sevstr << "] " << msg; + ort_severity = ORT_LOGGING_LEVEL_ERROR; } + else { + ort_severity = ORT_LOGGING_LEVEL_WARNING; + } + + std::string message = "[" + std::string(buf) + " " + std::string(sevstr) + "] " + std::string(msg); + + Ort::ThrowOnError(ort_api_->Logger_LogMessage(&ort_default_logger_, + ort_severity, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } void set_level(Severity verbosity) { diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 7cda9126..458cdc0a 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -11,8 +11,8 @@ #include #include -TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis) - : ApiPtrs(apis), ep_name_{ep_name} { +TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis) + : ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; @@ -280,14 +280,14 @@ extern "C" { // Public symbols // EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base, - const OrtLogger*, + const OrtLogger* default_logger, OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); const OrtEpApi* ort_ep_api = ort_api->GetEpApi(); const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); // Factory could use registration_name or define its own EP name. 
- std::unique_ptr factory = std::make_unique(registration_name, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); + std::unique_ptr factory = std::make_unique(registration_name, *default_logger, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index df67ff1c..9ace8dd9 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -11,7 +11,7 @@ using MemoryInfoUniquePtr = std::unique_ptr Date: Wed, 10 Sep 2025 14:08:52 -0700 Subject: [PATCH 53/60] Add default logger for TRT EP --- .../tensorrt/tensorrt_execution_provider.cc | 365 ++++++++++++++---- 1 file changed, 286 insertions(+), 79 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index cfba6341..070a3dde 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -978,11 +978,20 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this const size_t number_of_subgraphs = supported_nodes_vector.size(); if (number_of_trt_nodes == 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider"; + std::string message = "[TensorRT EP] No graph will run on TensorRT execution provider"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else if (number_of_trt_nodes == nodes.size()) { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + std::string message = "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; + std::string message = "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " + std::to_string(number_of_subgraphs); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } return nullptr; @@ -1074,7 +1083,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + std::string message = "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); layer->setPrecision(nvinfer1::DataType::kFLOAT); next_layer->setPrecision(nvinfer1::DataType::kFLOAT); layer->setOutputType(0, 
nvinfer1::DataType::kFLOAT); @@ -1241,8 +1253,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(pop) #endif fp16_enable_ = false; - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but " - // "platform doesn't support fast native fp16/bf16"; + std::string message = "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but platform doesn't support fast native fp16/bf16"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } @@ -1283,12 +1297,18 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (fp16_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); trt_node_name_with_precision += "_fp16"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + std::string message = "[TensorRT EP] FP16 mode is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); trt_node_name_with_precision += "_int8"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + std::string message = "[TensorRT EP] INT8 mode is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #if defined(_MSC_VER) #pragma warning(pop) @@ -1298,16 +1318,25 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 int number_of_dla_core = trt_builder->getNbDLACores(); if (number_of_dla_core == 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + std::string message = "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); dla_enable_ = false; } else { if (dla_core_ >= number_of_dla_core) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ - // << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core - // << ". Use DLA core 0 instead."; + std::string message = "[TensorRT EP] Try to use DLA core #" + std::to_string(dla_core_) + + std::string(", but it exceeds platform's maximum DLA core number ") + std::to_string(number_of_dla_core) + + std::string(". 
Use DLA core 0 instead."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); dla_core_ = 0; } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; + std::string message = "[TensorRT EP] use DLA core " + dla_core_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); trt_config->setDLACore(dla_core_); @@ -1319,7 +1348,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // enable sparse weights if (sparsity_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + std::string message = "[TensorRT EP] Sparse weights are allowed"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 if (build_heuristics_enable_) { @@ -1332,11 +1364,17 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 if (build_heuristics_enable_) { if (builder_optimization_level_ == 2) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " - // "level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."; + std::string message = "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " + + std::string("level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set " - // "builder optimization level as 2 to enable builder heuristics."; + std::string message = "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. 
Please set " + + std::string("builder optimization level as 2 to enable builder heuristics."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } #endif @@ -1345,13 +1383,19 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // switch optimizaion level if (builder_optimization_level_ != 3) { trt_config->setBuilderOptimizationLevel(builder_optimization_level_); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + std::string message = "[TensorRT EP] Builder optimization level is set to " + builder_optimization_level_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } // limit auxiliary streams if (auxiliary_streams_ >= 0) { trt_config->setMaxAuxStreams(auxiliary_streams_); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; + std::string message = "[TensorRT EP] Auxiliary streams are se to " + auxiliary_streams_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #else if (builder_optimization_level_ != 3) { @@ -1365,9 +1409,15 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (weight_stripped_engine_enable_) { #if NV_TENSORRT_MAJOR >= 10 trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + std::string message = "[TensorRT EP] STRIP_PLAN is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; + message = "[TensorRT EP] REFIT_IDENTICAL is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); #else LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; #endif @@ -1378,7 +1428,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this nvinfer1::TacticSources tactics = trt_config->getTacticSources(); tactics |= GetTacticSourceFromString(tactic_sources_); trt_config->setTacticSources(tactics); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; + std::string message = "[TensorRT EP] Tactic sources are limited using " + tactic_sources_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } // Build TRT engine (if needed) and load TRT engine if: @@ -1406,7 +1459,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (engine_cache_enable_ && engine_hw_compatible_) { trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); cache_hw_compat = "_sm80+"; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + std::string message = "[TensorRT EP] Hardware 
compatibility is enabled when loading and capturing engine cache."; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #endif @@ -1446,9 +1502,15 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); if (engine_update) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; + std::string message = "[TensorRT EP] Engine will be built"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; + std::string message = "[TensorRT EP] Engine won't be rebuilt"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } @@ -1461,7 +1523,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this engine_file.read((char*)engine_buf.get(), engine_size); trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + std::string message = "[TensorRT EP] DeSerialized " + engine_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); if (trt_engine == nullptr) { std::string err_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); @@ -1483,7 +1548,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // Deserialize engine trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + std::string message = "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); if (trt_engine == nullptr) { std::string err_msg = "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); @@ -1517,7 +1585,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); if (detailed_build_log_) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + std::string message = "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } @@ -1551,7 +1622,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // Serialize engine profile if it has explicit profiles if (has_explicit_profile) { SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + 
std::string message = "[TensorRT EP] Serialized " + profile_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } if (engine_decryption_enable_) { @@ -1563,7 +1637,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this std::string err_msg = "TensorRT EP call to engine encryption library failed"; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + std::string message = "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { // LOGS_DEFAULT(WARNING) // << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; @@ -1571,7 +1648,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; + std::string message = "[TensorRT EP] Serialized engine " + engine_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } // serialize and save timing cache @@ -1584,14 +1664,20 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); if (detailed_build_log_) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + std::string message = "[TensorRT EP] Serialized timing cache " + timing_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } } } if (weight_stripped_engine_refit_) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build"; + std::string message = "[TensorRT EP] Refit engine from main ONNX file after engine build"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); char* onnx = string_buf.data(); size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, @@ -1989,6 +2075,7 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( std::string onnx_model_filename, std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log) { + #if NV_TENSORRT_MAJOR >= 10 bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; std::filesystem::path onnx_model_path{onnx_model_folder_path}; @@ -2028,7 +2115,10 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( auto parser_refitter = std::unique_ptr(nvonnxparser::createParserRefitter(*refitter, trt_logger)); if (refit_from_file) { - // LOGS_DEFAULT(VERBOSE) << 
"[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string(); + std::string message = "[TensorRT EP] Refitting from file on disk: " + onnx_model_path.string(); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " @@ -2037,7 +2127,10 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } } else { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array"; + std::string message = "[TensorRT EP] Refitting from byte array"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " @@ -2046,7 +2139,10 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( } } if (refitter->refitCudaEngine()) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine."; + std::string message = "[TensorRT EP] Successfully refitted the weight-stripped engine."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { std::string err_msg = "TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + @@ -2060,7 +2156,10 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialize the refitted engine to " << refitted_engine_cache; + std::string message = "[TensorRT EP] Serialize the refitted engine to " + refitted_engine_cache; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } return nullptr; #else @@ -2226,19 +2325,31 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // Validate setting if (max_partition_iterations_ <= 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_partition_iterations must be a positive integer value. Set it to 1000"; + std::string message = "[TensorRT EP] TensorRT option trt_max_partition_iterations must be a positive integer value. Set it to 1000"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); max_partition_iterations_ = 1000; } if (min_subgraph_size_ <= 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_min_subgraph_size must be a positive integer value. Set it to 1"; + std::string message = "[TensorRT EP] TensorRT option trt_min_subgraph_size must be a positive integer value. 
Set it to 1"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); min_subgraph_size_ = 1; } if (max_workspace_size_ <= 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_workspace_size must be a positive integer value. Set it to 1073741824 (1GB)"; + std::string message = "[TensorRT EP] TensorRT option trt_max_workspace_size must be a positive integer value. Set it to 1073741824 (1GB)"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); max_workspace_size_ = 1 << 30; } if (dla_core_ < 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_dla_core must be a non-negative integer value. Set it to 0"; + std::string message = "[TensorRT EP] TensorRT option trt_dla_core must be a non-negative integer value. Set it to 0"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); dla_core_ = 0; } @@ -2277,14 +2388,23 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa if (engine_cache_enable_ && engine_hw_compatible_) { #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 if (std::stoi(compute_capability_) < 80) { - // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled as GPU arch < 80. "; + std::string message = "Engine hardware compatibility cannot be enabled as GPU arch < 80. "; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); engine_hw_compatible_ = false; } else if (std::stoi(compute_capability_) == 87) { - // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled on Jetson Orin. "; + std::string message = "Engine hardware compatibility cannot be enabled on Jetson Orin. "; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); engine_hw_compatible_ = false; } #else - // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled as TRT < 8.6. "; + std::string message = "Engine hardware compatibility cannot be enabled as TRT < 8.6. 
"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); engine_hw_compatible_ = false; #endif } @@ -2337,7 +2457,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_); // if (!status) { // profile_min_shapes_.clear(); - // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // std::string message = "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // } // } @@ -2345,7 +2468,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_); // if (!status) { // profile_max_shapes_.clear(); - // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // std::string message = "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // } // } @@ -2353,15 +2479,24 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_); // if (!status) { // profile_opt_shapes_.clear(); - // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // std::string message = "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // } // } // if (status) { // status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); // if (!status) { - // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Profile shapes validation failed. Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile."; - // // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you."; + // std::string message = "[TensorRT EP] Profile shapes validation failed. 
Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile."; + // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + // message = "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you."; + // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // profile_min_shapes_.clear(); // profile_max_shapes_.clear(); // profile_opt_shapes_.clear(); @@ -2530,8 +2665,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 if (engine_cache_enable && engine_hw_compatible) { cache_hw_compat = "_sm80+"; - // LOGS_DEFAULT(VERBOSE) - // << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + std::string message = "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #endif @@ -2562,7 +2699,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* if (engine_file && !trt_state->engine_decryption_enable && profile_file) { // Deserialize profile shape_ranges = DeserializeProfileV2(profile_file); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + std::string message = "[TensorRT EP] DeSerialized " + profile_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // Prepare buffer engine_file.seekg(0, std::ios::end); @@ -2581,14 +2721,20 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::string err_msg = "TensorRT EP Failed to Build Engine."; return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + message = "[TensorRT EP] DeSerialized " + engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); trt_engine = trt_state->engine->get(); context_update = true; } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { shape_ranges = DeserializeProfileV2(profile_file); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + std::string message = "[TensorRT EP] DeSerialized " + profile_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // Decrypt engine size_t engine_size = 0; if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { @@ -2610,7 +2756,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::string err_msg = "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path; return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - // 
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + message = "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); trt_engine = trt_state->engine->get(); context_update = true; } @@ -2671,18 +2820,27 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // Set precision if (trt_state->int8_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + std::string message = "[TensorRT EP] INT8 mode is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } if (trt_state->fp16_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + std::string message = "[TensorRT EP] FP16 mode is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #if defined(_MSC_VER) #pragma warning(pop) #endif // Set DLA (DLA can only run with FP16 or INT8) if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + std::string message = "[TensorRT EP] use DLA core " + trt_state->dla_core; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); trt_config->setDLACore(trt_state->dla_core); @@ -2691,42 +2849,69 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // enable sparse weights if (trt_state->sparsity_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + std::string message = "[TensorRT EP] Sparse weights are allowed"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 // enable builder heuristics if (trt_state->build_heuristics_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + std::string message = "[TensorRT EP] Builder heuristics are enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 // switch optimizaion level if (trt_state->builder_optimization_level != 3) { trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + std::string message = "[TensorRT EP] Builder optimization level is set to " + trt_state->builder_optimization_level; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + 
OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } // limit auxiliary streams if (trt_state->auxiliary_streams >= 0) { trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + std::string message = "[TensorRT EP] Auxiliary streams are se to " + trt_state->auxiliary_streams; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #else if (trt_state->builder_optimization_level != 3) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + std::string message = "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } if (trt_state->auxiliary_streams >= 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + std::string message = "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #endif if (weight_stripped_engine_enable) { #if NV_TENSORRT_MAJOR >= 10 trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + std::string message = "[TensorRT EP] STRIP_PLAN is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; + message = "[TensorRT EP] REFIT_IDENTICAL is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); #else - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; + std::string message = "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); #endif } // limit used tactic sources @@ -2734,7 +2919,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* nvinfer1::TacticSources tactics = trt_config->getTacticSources(); tactics |= trt_state->tactic_sources; trt_config->setTacticSources(tactics); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + std::string message = "[TensorRT EP] Tactic sources are limited using bitmask " + tactics; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } // Load timing cache from file. 
Create a fresh cache if the file doesn't exist @@ -2749,7 +2937,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } trt_config->setTimingCache(*timing_cache, force_timing_cache_match); if (detailed_build_log) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + std::string message = "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } @@ -2757,7 +2948,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* // Enable hardware compatility mode if assigned if (trt_state->engine_hw_compatible) { trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; + std::string message = "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #endif @@ -2783,10 +2977,11 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } if (detailed_build_log) { auto engine_build_stop = std::chrono::steady_clock::now(); - // LOGS_DEFAULT(INFO) - // << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " - // << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" - // << std::endl; + std::string message = "TensorRT engine build for " + trt_state->trt_node_name_with_precision + " took: " + + std::to_string(std::chrono::duration_cast(engine_build_stop - engine_build_start).count()) + "ms"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } if (!(*(trt_state->engine))) { @@ -2797,7 +2992,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* if (trt_state->engine_cache_enable) { // Serialize engine profile SerializeProfileV2(profile_cache_path, shape_ranges); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + std::string message = "[TensorRT EP] Serialized " + profile_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // Serialize engine if (trt_state->engine_decryption_enable) { @@ -2810,7 +3008,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::string err_msg = "TensorRT EP could not call engine encryption function encrypt"; return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + std::string message = "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { // LOGS_DEFAULT(WARNING) // << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; @@ -2818,7 +3019,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + std::string message = "[TensorRT EP] Serialized " + engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } @@ -2832,7 +3036,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); if (detailed_build_log) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + std::string message = "[TensorRT EP] Serialized timing cache " + timing_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } From c6ae7b6498205173d43c368920bec9fb8156fe60 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 10 Sep 2025 14:09:48 -0700 Subject: [PATCH 54/60] update include path in utility function header --- plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h index 28ce4439..6f07c67a 100644 --- a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h +++ b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h @@ -122,7 +122,7 @@ #define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ #include -#include "core/session/onnxruntime_cxx_api.h" +#include "onnxruntime_cxx_api.h" #include "onnx/onnx_pb.h" namespace OrtEpUtils { From 6b180a4814f7c5cf4d51d55c0132b1a0eb5eecba Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 10 Sep 2025 14:54:45 -0700 Subject: [PATCH 55/60] Add default logger for TRT EP (cont.) 
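
The remaining LOGS_DEFAULT(...) call sites in tensorrt_execution_provider.cc are routed through the ORT default logger, following the same pattern as the earlier commits in this series. A minimal sketch of that pattern, assuming the OrtApi/OrtLogger pair the EP already holds (the LogWithOrtLogger helper name is illustrative only and is not added by this patch):

    #include <string>
    #include "onnxruntime_cxx_api.h"

    // Build the message as a std::string first; numeric values go through std::to_string,
    // since "literal" + int is pointer arithmetic rather than string formatting.
    inline void LogWithOrtLogger(const OrtApi& ort_api, const OrtLogger& logger,
                                 OrtLoggingLevel level, const std::string& msg) {
      // Note: __LINE__ / __FUNCTION__ here report the helper, not the original call site.
      Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger, level, msg.c_str(),
                                                  ORT_FILE, __LINE__, __FUNCTION__));
    }

    // Example call, mirroring one of the converted sites:
    //   LogWithOrtLogger(ort_api, logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
    //                    "[TensorRT EP] shape size of this shape tensor is " + std::to_string(shape_size));
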
--- .../tensorrt/tensorrt_execution_provider.cc | 97 +++++++++++++------ 1 file changed, 68 insertions(+), 29 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 070a3dde..0314d46a 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -128,9 +128,13 @@ bool ApplyProfileShapesFromProviderOptions(std::vector>>& profile_min_shapes, std::unordered_map>>& profile_max_shapes, std::unordered_map>>& profile_opt_shapes, - ShapeRangesMap& input_explicit_shape_ranges) { + ShapeRangesMap& input_explicit_shape_ranges, + const OrtLogger* logger) { if (trt_profiles.size() == 0) { - // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Number of optimization profiles should be greater than 0, but it's 0."; + std::string message = "[TensorRT EP] Number of optimization profiles should be greater than 0, but it's 0."; + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); return false; } @@ -144,8 +148,11 @@ bool ApplyProfileShapesFromProviderOptions(std::vectorLogger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); for (size_t i = 0; i < trt_profiles.size(); i++) { nvinfer1::Dims dims = input->getDimensions(); @@ -158,7 +165,10 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(profile_min_shapes[input_name][i].size()); std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shape size of this shape tensor is " << shape_size; + std::string message = "[TensorRT EP] shape size of this shape tensor is " + std::to_string(shape_size); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); for (int j = 0; j < shape_size; j++) { auto min_value = profile_min_shapes[input_name][i][j]; @@ -167,9 +177,12 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(min_value); shapes_max[j] = static_cast(max_value); shapes_opt[j] = static_cast(opt_value); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_min.d[" << j << "] is " << shapes_min[j]; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_max.d[" << j << "] is " << shapes_max[j]; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_opt.d[" << j << "] is " << shapes_opt[j]; + std::string message = "[TensorRT EP] shapes_min.d[" + std::to_string(j) + std::string("] is ") + std::to_string(shapes_min[j]) + std::string("\n") + + std::string("[TensorRT EP] shapes_max.d[") + std::to_string(j) + std::string("] is ") + std::to_string(shapes_max[j]) + std::string("\n") + + std::string("[TensorRT EP] shapes_opt.d[") + std::to_string(j) + std::string("] is ") + std::to_string(shapes_opt[j]); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { std::vector> profile_vector(trt_profiles.size()); @@ -191,7 +204,10 @@ bool ApplyProfileShapesFromProviderOptions(std::vectorLogger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); for (int j = 0; j < 
nb_dims; j++) { if (dims.d[j] == -1) { @@ -201,9 +217,13 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(min_value); dims_max.d[j] = static_cast(max_value); dims_opt.d[j] = static_cast(opt_value); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_min.d[" << j << "] is " << dims_min.d[j]; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_max.d[" << j << "] is " << dims_max.d[j]; - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_opt.d[" << j << "] is " << dims_opt.d[j]; + + std::string message = "[TensorRT EP] dims_min.d[" + std::to_string(j) + std::string("] is ") + std::to_string(dims_min.d[j]) + std::string("\n") + + std::string("[TensorRT EP] dims_max.d[") + std::to_string(j) + std::string("] is ") + std::to_string(dims_max.d[j]) + std::string("\n") + + std::string("[TensorRT EP] dims_opt.d[") + std::to_string(j) + std::string("] is ") + std::to_string(dims_opt.d[j]); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { std::vector> profile_vector(trt_profiles.size()); @@ -1178,7 +1198,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this if (has_explicit_profile) { apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, - profile_opt_shapes_, input_explicit_shape_ranges); + profile_opt_shapes_, input_explicit_shape_ranges, &ep->logger_); } // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on @@ -1270,8 +1290,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(pop) #endif int8_enable_ = false; - // LOGS_DEFAULT(WARNING) - // << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; + std::string message = "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } @@ -1356,9 +1378,12 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 if (build_heuristics_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled." - << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder " - "optimization level as 2 to enable builder heuristics."; + std::string message = "[TensorRT EP] Builder heuristics are enabled." 
+ + std::string(" For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder ") + + std::string("optimization level as 2 to enable builder heuristics."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 @@ -1399,10 +1424,16 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } #else if (builder_optimization_level_ != 3) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + std::string message = "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } if (auxiliary_streams_ >= 0) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + std::string message = "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } #endif @@ -1419,7 +1450,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); #else - LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; + std::string message = "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); #endif } @@ -1613,10 +1647,11 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); - // LOGS_DEFAULT(INFO) - // << "TensorRT engine build for " << trt_node_name_with_precision << " took: " - // << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() - // << "ms" << std::endl; + std::string message = "TensorRT engine build for " + trt_node_name_with_precision + std::string(" took: ") + + std::to_string(std::chrono::duration_cast(engine_build_stop - engine_build_start).count()) + std::string("ms"); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } if (engine_cache_enable_) { // Serialize engine profile if it has explicit profiles @@ -1642,8 +1677,10 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { - // LOGS_DEFAULT(WARNING) - // << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + std::string message = "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); @@ -3013,8 +3050,10 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { - // LOGS_DEFAULT(WARNING) - // << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + std::string message = "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); From b3ac797958e4beb288ee106007af30874dd3aa4f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 16 Sep 2025 14:52:05 -0700 Subject: [PATCH 56/60] put code under namespace trt_ep --- .../tensorrt/cuda_allocator.cc | 4 ++++ .../tensorrt/cuda_allocator.h | 4 ++++ .../tensorrt/onnx_ctx_model_helper.cc | 4 +++- .../tensorrt/onnx_ctx_model_helper.h | 2 ++ .../tensorrt/tensorrt_execution_provider.cc | 3 +++ .../tensorrt/tensorrt_execution_provider.h | 13 +++++++------ .../tensorrt_execution_provider_data_transfer.cc | 3 +++ .../tensorrt_execution_provider_data_transfer.h | 5 ++++- .../tensorrt_execution_provider_stream_support.cc | 3 +++ .../tensorrt_execution_provider_stream_support.h | 2 ++ .../tensorrt/tensorrt_provider_factory.cc | 8 ++++++-- .../tensorrt/tensorrt_provider_factory.h | 5 ++++- 12 files changed, 45 insertions(+), 11 deletions(-) diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.cc b/plugin_execution_providers/tensorrt/cuda_allocator.cc index 058d96f4..9166e35e 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.cc +++ b/plugin_execution_providers/tensorrt/cuda_allocator.cc @@ -5,6 +5,8 @@ #include #include "cuda_allocator.h" +namespace trt_ep { + void CUDA_RETURN_IF_ERROR(cudaError_t res); void CUDAAllocator::CheckDevice(bool throw_when_fail) const { @@ -74,3 +76,5 @@ void CUDAPinnedAllocator::Free(void* p) { const OrtMemoryInfo* CUDAPinnedAllocator::Info() const { return mem_info_; } + +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h index 7f765d50..2000e828 100644 --- a/plugin_execution_providers/tensorrt/cuda_allocator.h +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -7,6 +7,8 @@ using DeviceId = int16_t; +namespace trt_ep { + struct CUDAAllocator : OrtAllocator { CUDAAllocator(const OrtMemoryInfo* mem_info, DeviceId device_id) : mem_info_(mem_info), device_id_(device_id) { OrtAllocator::version = ORT_API_VERSION; @@ -62,3 +64,5 @@ struct CUDAPinnedAllocator : OrtAllocator { DeviceId device_id_ = 0; const OrtMemoryInfo* mem_info_ = nullptr; }; + +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index 4e803383..34bd9647 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -9,6 +9,7 @@ #include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" +namespace trt_ep { extern TensorrtLogger& 
GetTensorrtLogger(bool verbose_log, const OrtLogger& ort_default_logger, const OrtApi* ort_api); @@ -272,4 +273,5 @@ std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; return refitted_engine_cache_path; -} \ No newline at end of file +} +} // namespace trt_ep \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h index 75fa1b19..4979c07c 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -12,6 +12,7 @@ #include #include +namespace trt_ep { class EPContextNodeHelper : public ApiPtrs { public: EPContextNodeHelper(TensorrtExecutionProvider& ep, @@ -79,3 +80,4 @@ class EPContextNodeReader : public ApiPtrs { size_t onnx_external_data_bytestream_size_; bool detailed_build_log_; }; // TRTCacheModelHandler +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 0314d46a..cb286f29 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -35,6 +35,8 @@ const OrtApi* g_ort_api = nullptr; const OrtEpApi* g_ep_api = nullptr; const OrtModelEditorApi* g_model_editor_api = nullptr; +namespace trt_ep { + void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } @@ -3597,3 +3599,4 @@ void TRTEpEpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_pt (void)trt_ep_compute_state; // Do nothing for here. 
} +} diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index a9757260..e95157bd 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -17,6 +17,11 @@ #define EXPORT_API #endif +using HashValue = uint64_t; +using AllocateFunc = void* (*)(void*, size_t, size_t); +using DestroyFunc = void (*)(void*, void*); + +namespace trt_ep { namespace tensorrt_env_vars { static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_ITERATIONS"; static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; @@ -60,10 +65,6 @@ static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX"; static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; } // namespace tensorrt_env_vars -using HashValue = uint64_t; -using AllocateFunc = void* (*)(void*, size_t, size_t); -using DestroyFunc = void (*)(void*, void*); - class TensorrtLogger : public nvinfer1::ILogger { nvinfer1::ILogger::Severity verbosity_; const OrtLogger& ort_default_logger_; @@ -94,8 +95,7 @@ class TensorrtLogger : public nvinfer1::ILogger { OrtLoggingLevel ort_severity; if (severity <= Severity::kERROR) { ort_severity = ORT_LOGGING_LEVEL_ERROR; - } - else { + } else { ort_severity = ORT_LOGGING_LEVEL_WARNING; } @@ -472,3 +472,4 @@ struct TRTEpEpContextNodeComputeInfo : OrtNodeComputeInfo { TensorrtExecutionProvider& ep; }; +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc index fe1bc675..b0716b04 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -7,6 +7,8 @@ #include #include +namespace trt_ep { + void CUDA_RETURN_IF_ERROR(cudaError_t res); /*static*/ @@ -107,3 +109,4 @@ void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) // delete static_cast(this_ptr); ; } +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h index 42c83007..34221f3a 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h @@ -6,6 +6,8 @@ #include "ep_utils.h" #include "onnxruntime_c_api.h" +namespace trt_ep { + struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector& device_mem_infos, std::vector& shared_mem_infos) @@ -28,4 +30,5 @@ struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { private: std::vector& cuda_gpu_mem_devices_; std::vector& cuda_pinned_mem_devices_; -}; \ No newline at end of file +}; +} // namespace trt_ep \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc index 59f3d1e1..a6a95451 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc @@ -8,6 +8,8 @@ #include "cuda/cuda_common.h" #include 
"cuda/cuda_call.h" +namespace trt_ep { + // // TrtSyncStreamImpl implementation // @@ -117,3 +119,4 @@ OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::WaitOnHostImpl(_In_ OrtSyncNoti void ORT_API_CALL TrtSyncNotificationImpl::ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { delete static_cast(this_ptr); } +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h index 83a0fabc..7242c247 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h @@ -9,6 +9,7 @@ #include +namespace trt_ep { // // Class implementing Stream support for synchronization. // @@ -60,3 +61,4 @@ struct TrtSyncNotificationImpl : public OrtSyncNotificationImpl, public ApiPtrs cudaStream_t& stream_; cudaEvent_t event_; }; +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc index 458cdc0a..6ff43e78 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -11,6 +11,8 @@ #include #include +namespace trt_ep { + TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis) : ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. @@ -268,6 +270,8 @@ bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtE return true; } +} // namespace trt_ep + // To make symbols visible on macOS/iOS #ifdef __APPLE__ #define EXPORT_SYMBOL __attribute__((visibility("default"))) @@ -287,7 +291,7 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); // Factory could use registration_name or define its own EP name. - std::unique_ptr factory = std::make_unique(registration_name, *default_logger, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); + std::unique_ptr factory = std::make_unique(registration_name, *default_logger, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, @@ -301,7 +305,7 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const } EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { - delete static_cast(factory); + delete static_cast(factory); return nullptr; } diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h index 9ace8dd9..fcb0eba1 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -6,6 +6,8 @@ using MemoryInfoUniquePtr = std::unique_ptr>; +namespace trt_ep { + /// /// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices. 
/// @@ -63,4 +65,5 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { const std::string vendor_{"Nvidia"}; // EP vendor name const std::string ep_version_{"0.1.0"}; // EP version const OrtLogger& default_logger_; -}; \ No newline at end of file +}; +} // namespace trt_ep \ No newline at end of file From 632d2249166bc479a5de2981320240b4beccb7c8 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 18 Sep 2025 11:20:32 -0700 Subject: [PATCH 57/60] remove unnecessary files --- .../tensorrt/utils/code_location.h | 58 --- .../tensorrt/utils/common.h | 169 --------- .../tensorrt/utils/endian.h | 27 -- .../tensorrt/utils/exceptions.h | 91 ----- .../tensorrt/utils/murmurhash3.cc | 349 ------------------ .../tensorrt/utils/murmurhash3.h | 16 - .../tensorrt/utils/path_string.h | 70 ---- .../tensorrt/utils/status.h | 192 ---------- 8 files changed, 972 deletions(-) delete mode 100644 plugin_execution_providers/tensorrt/utils/code_location.h delete mode 100644 plugin_execution_providers/tensorrt/utils/common.h delete mode 100644 plugin_execution_providers/tensorrt/utils/endian.h delete mode 100644 plugin_execution_providers/tensorrt/utils/exceptions.h delete mode 100644 plugin_execution_providers/tensorrt/utils/murmurhash3.cc delete mode 100644 plugin_execution_providers/tensorrt/utils/murmurhash3.h delete mode 100644 plugin_execution_providers/tensorrt/utils/path_string.h delete mode 100644 plugin_execution_providers/tensorrt/utils/status.h diff --git a/plugin_execution_providers/tensorrt/utils/code_location.h b/plugin_execution_providers/tensorrt/utils/code_location.h deleted file mode 100644 index dbff6909..00000000 --- a/plugin_execution_providers/tensorrt/utils/code_location.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -namespace onnxruntime { -/** - CodeLocation captures information on where in the source code a message came from. -*/ -struct CodeLocation { - /** - @param file_path Usually the value of __FILE__ - @param line Usually the value of __LINE__ - @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ - */ - CodeLocation(const char* file_path, const int line, const char* func) - : file_and_path{file_path}, line_num{line}, function{func} { - } - - /** - @param file_path Usually the value of __FILE__ - @param line Usually the value of __LINE__ - @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ - @param stacktrace Stacktrace from source of message. - */ - CodeLocation(const char* file_path, const int line, const char* func, const std::vector& stacktrace) - : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) { - } - - std::string FileNoPath() const { - // assuming we always have work to do, so not trying to avoid creating a new string if - // no path was removed. - return file_and_path.substr(file_and_path.find_last_of("/\\") + 1); - } - - enum Format { - kFilename, - kFilenameAndPath - }; - - std::string ToString(Format format = Format::kFilename) const { - std::ostringstream out; - out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function; - return out.str(); - } - // utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8. 
- const std::string file_and_path; - const int line_num; - // utf-8 - const std::string function; - const std::vector stacktrace; -}; - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/common.h b/plugin_execution_providers/tensorrt/utils/common.h deleted file mode 100644 index 72e9ffbf..00000000 --- a/plugin_execution_providers/tensorrt/utils/common.h +++ /dev/null @@ -1,169 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Portions Copyright (c) Microsoft Corporation - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "code_location.h" -#include "exceptions.h" -#include "make_string.h" -#include "status.h" - -namespace onnxruntime { - -// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER -// so we only define it as one for MSVC -#if (_MSC_VER && !defined(__PRETTY_FUNCTION__)) -#define __PRETTY_FUNCTION__ __FUNCTION__ -#endif - -// Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__ -#define ORT_WHERE ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast(__FUNCTION__)) - -#define ORT_WHERE_WITH_STACK \ - ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast(__PRETTY_FUNCTION__), ::onnxruntime::GetStackTrace()) - -// Throw an exception with optional message. -// NOTE: The arguments get streamed into a string via ostringstream::operator<< -// DO NOT use a printf format string, as that will not work as you expect. -/* -#define ORT_THROW(...) \ - throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__)) -*/ -#define ORT_THROW(...) \ - throw ::onnxruntime::OnnxRuntimeException(::onnxruntime::MakeString(__VA_ARGS__)) - -// Just in order to mark things as not implemented. Do not use in final code. -#define ORT_NOT_IMPLEMENTED(...) \ - throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__)) - -// Check condition. -// NOTE: The arguments get streamed into a string via ostringstream::operator<< -// DO NOT use a printf format string, as that will not work as you expect. -#define ORT_ENFORCE(condition, ...) \ - do { \ - if (!(condition)) { \ - throw ::onnxruntime::OnnxRuntimeException(#condition, \ - ::onnxruntime::MakeString(__VA_ARGS__)); \ - } \ - } while (false) - -#define ORT_THROW_EX(ex, ...) \ - throw ex(__VA_ARGS__) - -#define ORT_MAKE_STATUS(category, code, ...) \ - ::onnxruntime::common::Status(::onnxruntime::common::category, \ - ::onnxruntime::common::code, \ - ::onnxruntime::MakeString(__VA_ARGS__)) - -// Check condition. if met, return status. -#define ORT_RETURN_IF(condition, ...) 
\ - do { \ - if (condition) { \ - return ::onnxruntime::common::Status(::onnxruntime::common::ONNXRUNTIME, \ - ::onnxruntime::common::FAIL, \ - ::onnxruntime::MakeString(ORT_WHERE.ToString(), " ", __VA_ARGS__)); \ - } \ - } while (false) - -// Check condition. if not met, return status. -#define ORT_RETURN_IF_NOT(condition, ...) \ - ORT_RETURN_IF(!(condition), __VA_ARGS__) - -// Macros to disable the copy and/or move ctor and assignment methods -// These are usually placed in the private: declarations for a class. - -#define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete - -#define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete - -#define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \ - ORT_DISALLOW_COPY(TypeName); \ - ORT_DISALLOW_ASSIGNMENT(TypeName) - -#define ORT_DISALLOW_MOVE(TypeName) \ - TypeName(TypeName&&) = delete; \ - TypeName& operator=(TypeName&&) = delete - -#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \ - ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \ - ORT_DISALLOW_MOVE(TypeName) - -#define ORT_RETURN_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - return _status; \ - } \ - } while (0) - -#define ORT_THROW_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - ORT_THROW(_status); \ - } \ - } while (0) - -// use this macro when cannot early return -#define ORT_CHECK_AND_SET_RETVAL(expr) \ - do { \ - if (retval.IsOK()) { \ - retval = (expr); \ - } \ - } while (0) - -struct null_type {}; -inline std::string ToUTF8String(const std::string& s) { return s; } -#ifdef _WIN32 -/** - * Convert a wide character string to a UTF-8 string - */ -std::string ToUTF8String(const std::wstring& s); - -std::wstring ToWideString(const std::string& s); -inline std::wstring ToWideString(const std::wstring& s) { return s; } -#else -inline std::string ToWideString(const std::string& s) { return s; } -#endif - -constexpr size_t kMaxStrLen = 2048; - -// Returns whether `key` is in `container`. -// Like C++20's map/set contains() member function. -template typename AssociativeContainer, - typename LookupKey> -inline bool Contains(const AssociativeContainer& container, LookupKey&& key) { - return container.find(std::forward(key)) != container.end(); -} - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/endian.h b/plugin_execution_providers/tensorrt/utils/endian.h deleted file mode 100644 index 629fb78f..00000000 --- a/plugin_execution_providers/tensorrt/utils/endian.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { - -// the semantics of this enum should match std::endian from C++20 -enum class endian { -#if defined(_WIN32) - little = 0, - big = 1, - native = little, -#elif defined(__GNUC__) || defined(__clang__) - little = __ORDER_LITTLE_ENDIAN__, - big = __ORDER_BIG_ENDIAN__, - native = __BYTE_ORDER__, -#else -#error onnxruntime::endian is not implemented in this environment. 
-#endif -}; - -static_assert( - endian::native == endian::little || endian::native == endian::big, - "Only little-endian or big-endian native byte orders are supported."); - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/exceptions.h b/plugin_execution_providers/tensorrt/utils/exceptions.h deleted file mode 100644 index 4f166da3..00000000 --- a/plugin_execution_providers/tensorrt/utils/exceptions.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "common.h" -// #include "code_location.h" - -namespace onnxruntime { - -class NotImplementedException : public std::logic_error { - public: - explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; - explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; -}; - -class TypeMismatchException : public std::logic_error { - public: - TypeMismatchException() noexcept : logic_error("Type mismatch") {}; -}; - -class OnnxRuntimeException : public std::exception { - public: - // code location is not provided for now - /* - OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept - : OnnxRuntimeException(location, nullptr, msg) { - } - */ - - /** - Create a new exception that captures the location it was thrown from. - @param location Location in the source code the exception is being thrown from - @param failed_condition Optional string containing the condition that failed. - e.g. "tensor.Size() == input.Size()". May be nullptr. - @param msg Message containing additional information about the exception cause. - */ - /* - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { - std::ostringstream ss; - - ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous - if (failed_condition != nullptr) { - ss << " " << failed_condition << " was false."; - } - - ss << " " << msg << "\n"; - if (!location.stacktrace.empty()) { - ss << "Stacktrace:\n"; - // skip the first entry in the stacktrace as we have that information from location.ToString() - std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); - } - - what_ = ss.str(); - } - */ - - OnnxRuntimeException(const std::string& msg) noexcept - : OnnxRuntimeException(nullptr, msg) { - } - - OnnxRuntimeException(const char* failed_condition, const std::string& msg) { - std::ostringstream ss; - - if (failed_condition != nullptr) { - ss << failed_condition << " was false."; - } - - ss << " " << msg << "\n"; - what_ = ss.str(); - } - - const char* what() const noexcept override { - return what_.c_str(); - } - - private: - // const CodeLocation location_; - const std::vector stacktrace_; - std::string what_; -}; - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/murmurhash3.cc b/plugin_execution_providers/tensorrt/utils/murmurhash3.cc deleted file mode 100644 index 49fcb2ef..00000000 --- a/plugin_execution_providers/tensorrt/utils/murmurhash3.cc +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "murmurhash3.h" - -// Original source: https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp -//----------------------------------------------------------------------------- -// MurmurHash3 was written by Austin Appleby, and is placed in the public -// domain. The author hereby disclaims copyright to this source code. - -// Note - The x86 and x64 versions do _not_ produce the same results, as the -// algorithms are optimized for their respective platforms. You can still -// compile and run any of them on any platform, but your performance with the -// non-native version will be less than optimal. - -/* Modifications Copyright (c) Microsoft. */ - -#include "endian.h" - -//----------------------------------------------------------------------------- -// Platform-specific functions and macros - -// Microsoft Visual Studio - -#if defined(_MSC_VER) - -#define FORCE_INLINE __forceinline - -#include - -#define ROTL32(x, y) _rotl(x, y) -#define ROTL64(x, y) _rotl64(x, y) - -#define BIG_CONSTANT(x) (x) - -// Other compilers - -#else // defined(_MSC_VER) - -#define FORCE_INLINE inline __attribute__((always_inline)) - -inline uint32_t rotl32(uint32_t x, int8_t r) { - return (x << r) | (x >> (32 - r)); -} - -inline uint64_t rotl64(uint64_t x, int8_t r) { - return (x << r) | (x >> (64 - r)); -} - -#define ROTL32(x, y) rotl32(x, y) -#define ROTL64(x, y) rotl64(x, y) - -#define BIG_CONSTANT(x) (x##LLU) - -#endif // !defined(_MSC_VER) -#include -//----------------------------------------------------------------------------- -// Block read - on little-endian machines this is a single load, -// while on big-endian or unknown machines the byte accesses should -// still get optimized into the most efficient instruction. -// -// Changes to support big-endian from https://github.com/explosion/murmurhash/pull/27/ -// were manually applied to original murmurhash3 source code. 
-FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { - if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { - return p[i]; - } else { - const uint8_t* c = (const uint8_t*)&p[i]; - return (uint32_t)c[0] | - (uint32_t)c[1] << 8 | - (uint32_t)c[2] << 16 | - (uint32_t)c[3] << 24; - } -} - -FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { - if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { - return p[i]; - } else { - const uint8_t* c = (const uint8_t*)&p[i]; - return (uint64_t)c[0] | - (uint64_t)c[1] << 8 | - (uint64_t)c[2] << 16 | - (uint64_t)c[3] << 24 | - (uint64_t)c[4] << 32 | - (uint64_t)c[5] << 40 | - (uint64_t)c[6] << 48 | - (uint64_t)c[7] << 56; - } -} - -//----------------------------------------------------------------------------- -// Finalization mix - force all bits of a hash block to avalanche - -FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) { - h ^= h >> 16; - h *= 0x85ebca6b; - h ^= h >> 13; - h *= 0xc2b2ae35; - h ^= h >> 16; - - return h; -} - -//---------- - -FORCE_INLINE constexpr uint64_t fmix64(uint64_t k) { - k ^= k >> 33; - k *= BIG_CONSTANT(0xff51afd7ed558ccd); - k ^= k >> 33; - k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); - k ^= k >> 33; - - return k; -} - -//----------------------------------------------------------------------------- - -namespace onnxruntime { -void MurmurHash3::x86_32(const void* key, int len, - uint32_t seed, void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 4; - - uint32_t h1 = seed; - - constexpr uint32_t c1 = 0xcc9e2d51; - constexpr uint32_t c2 = 0x1b873593; - - //---------- - // body - - const uint32_t* blocks = (const uint32_t*)(data + static_cast(nblocks) * 4); - - for (int i = -nblocks; i; i++) { - uint32_t k1 = getblock32(blocks, i); - - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - - h1 ^= k1; - h1 = ROTL32(h1, 13); - h1 = h1 * 5 + 0xe6546b64; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + static_cast(nblocks) * 4); - - uint32_t k1 = 0; - - switch (len & 3) { - case 3: - k1 ^= tail[2] << 16; - [[fallthrough]]; - case 2: - k1 ^= tail[1] << 8; - [[fallthrough]]; - case 1: - k1 ^= tail[0]; - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - - h1 = fmix32(h1); - - *(uint32_t*)out = h1; -} - -//----------------------------------------------------------------------------- - -void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 16; - - uint32_t h1 = seed; - uint32_t h2 = seed; - uint32_t h3 = seed; - uint32_t h4 = seed; - - constexpr uint32_t c1 = 0x239b961b; - constexpr uint32_t c2 = 0xab0e9789; - constexpr uint32_t c3 = 0x38b34ae5; - constexpr uint32_t c4 = 0xa1e38b93; - - //---------- - // body - - const uint32_t* blocks = (const uint32_t*)(data + static_cast(nblocks) * 16); - - for (int i = -nblocks; i; i++) { - uint32_t k1 = getblock32(blocks, i * 4 + 0); - uint32_t k2 = getblock32(blocks, i * 4 + 1); - uint32_t k3 = getblock32(blocks, i * 4 + 2); - uint32_t k4 = getblock32(blocks, i * 4 + 3); - - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - - h1 = ROTL32(h1, 19); - h1 += h2; - h1 = h1 * 5 + 0x561ccd1b; - - k2 *= c2; - k2 = ROTL32(k2, 16); - k2 *= c3; - h2 ^= k2; - - h2 = ROTL32(h2, 17); - h2 += h3; - h2 = h2 * 5 + 0x0bcaa747; - - k3 *= c3; - k3 = ROTL32(k3, 17); - k3 *= c4; - h3 ^= k3; - - h3 = ROTL32(h3, 15); - h3 += h4; - h3 = 
h3 * 5 + 0x96cd1c35; - - k4 *= c4; - k4 = ROTL32(k4, 18); - k4 *= c1; - h4 ^= k4; - - h4 = ROTL32(h4, 13); - h4 += h1; - h4 = h4 * 5 + 0x32ac3b17; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + static_cast(nblocks) * 16); - - uint32_t k1 = 0; - uint32_t k2 = 0; - uint32_t k3 = 0; - uint32_t k4 = 0; - - switch (len & 15) { - case 15: - k4 ^= tail[14] << 16; - [[fallthrough]]; - case 14: - k4 ^= tail[13] << 8; - [[fallthrough]]; - case 13: - k4 ^= tail[12] << 0; - k4 *= c4; - k4 = ROTL32(k4, 18); - k4 *= c1; - h4 ^= k4; - [[fallthrough]]; - case 12: - k3 ^= tail[11] << 24; - [[fallthrough]]; - case 11: - k3 ^= tail[10] << 16; - [[fallthrough]]; - case 10: - k3 ^= tail[9] << 8; - [[fallthrough]]; - case 9: - k3 ^= tail[8] << 0; - k3 *= c3; - k3 = ROTL32(k3, 17); - k3 *= c4; - h3 ^= k3; - [[fallthrough]]; - case 8: - k2 ^= tail[7] << 24; - [[fallthrough]]; - case 7: - k2 ^= tail[6] << 16; - [[fallthrough]]; - case 6: - k2 ^= tail[5] << 8; - [[fallthrough]]; - case 5: - k2 ^= tail[4] << 0; - k2 *= c2; - k2 = ROTL32(k2, 16); - k2 *= c3; - h2 ^= k2; - [[fallthrough]]; - case 4: - k1 ^= tail[3] << 24; - [[fallthrough]]; - case 3: - k1 ^= tail[2] << 16; - [[fallthrough]]; - case 2: - k1 ^= tail[1] << 8; - [[fallthrough]]; - case 1: - k1 ^= tail[0] << 0; - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - h2 ^= len; - h3 ^= len; - h4 ^= len; - - h1 += h2; - h1 += h3; - h1 += h4; - h2 += h1; - h3 += h1; - h4 += h1; - - h1 = fmix32(h1); - h2 = fmix32(h2); - h3 = fmix32(h3); - h4 = fmix32(h4); - - h1 += h2; - h1 += h3; - h1 += h4; - h2 += h1; - h3 += h1; - h4 += h1; - - ((uint32_t*)out)[0] = h1; - ((uint32_t*)out)[1] = h2; - ((uint32_t*)out)[2] = h3; - ((uint32_t*)out)[3] = h4; -} - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/murmurhash3.h b/plugin_execution_providers/tensorrt/utils/murmurhash3.h deleted file mode 100644 index ab86a3e5..00000000 --- a/plugin_execution_providers/tensorrt/utils/murmurhash3.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace onnxruntime { -struct MurmurHash3 { - // generate 32-bit hash from input and write to 'out' - static void x86_32(const void* key, int len, uint32_t seed, void* out); - - // generate 128-bit hash from input and write to 'out'. - static void x86_128(const void* key, int len, uint32_t seed, void* out); -}; -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/path_string.h b/plugin_execution_providers/tensorrt/utils/path_string.h deleted file mode 100644 index 3856c7c2..00000000 --- a/plugin_execution_providers/tensorrt/utils/path_string.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include - -// for std::tolower or std::towlower -#ifdef _WIN32 -#include -#else -#include -#endif - -// for converting / printing ORT_TSTR path strings to std::string -#ifdef _WIN32 -#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) std::wstring_convert>().to_bytes(X) -#define ORT_TSTR_CONVERT_FROM_STRING(X) std::wstring_convert>().from_bytes(X); -#else -#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) X -#define ORT_TSTR_CONVERT_FROM_STRING(X) X -#endif - -// #include "core/common/common.h" -// #include "core/session/onnxruntime_c_api.h" - -// #include "common.h" - -namespace onnxruntime { -// char type for filesystem paths -using PathChar = ORTCHAR_T; -// string type for filesystem paths -using PathString = std::basic_string; - -inline PathString ToPathString(const PathString& s) { - return s; -} - -#ifdef _WIN32 - -static_assert(std::is_same::value, "PathString is not std::wstring!"); - -inline PathString ToPathString(const std::string& s) { - return ToWideString(s); -} - -inline PathChar ToLowerPathChar(PathChar c) { - return std::towlower(c); -} - -inline std::string PathToUTF8String(const PathString& s) { - return ToUTF8String(s); -} - -#else - -static_assert(std::is_same::value, "PathString is not std::string!"); - -inline PathChar ToLowerPathChar(PathChar c) { - return std::tolower(c); -} - -inline std::string PathToUTF8String(const PathString& s) { - return s; -} - -#endif - -} // namespace onnxruntime diff --git a/plugin_execution_providers/tensorrt/utils/status.h b/plugin_execution_providers/tensorrt/utils/status.h deleted file mode 100644 index 80bf7caf..00000000 --- a/plugin_execution_providers/tensorrt/utils/status.h +++ /dev/null @@ -1,192 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Modifications Copyright (c) Microsoft. - -#pragma once - -#include -#include -#include -#ifdef _WIN32 -#include -#endif - -namespace onnxruntime { -namespace common { - -enum StatusCategory { - NONE = 0, - SYSTEM = 1, - ONNXRUNTIME = 2, -}; - -/** - Error code for ONNXRuntime. 
-*/ -enum StatusCode { - OK = 0, - FAIL = 1, - INVALID_ARGUMENT = 2, - NO_SUCHFILE = 3, - NO_MODEL = 4, - ENGINE_ERROR = 5, - RUNTIME_EXCEPTION = 6, - INVALID_PROTOBUF = 7, - MODEL_LOADED = 8, - NOT_IMPLEMENTED = 9, - INVALID_GRAPH = 10, - EP_FAIL = 11 -}; - -constexpr const char* StatusCodeToString(StatusCode status) noexcept { - switch (status) { - case StatusCode::OK: - return "SUCCESS"; - case StatusCode::FAIL: - return "FAIL"; - case StatusCode::INVALID_ARGUMENT: - return "INVALID_ARGUMENT"; - case StatusCode::NO_SUCHFILE: - return "NO_SUCHFILE"; - case StatusCode::NO_MODEL: - return "NO_MODEL"; - case StatusCode::ENGINE_ERROR: - return "ENGINE_ERROR"; - case StatusCode::RUNTIME_EXCEPTION: - return "RUNTIME_EXCEPTION"; - case StatusCode::INVALID_PROTOBUF: - return "INVALID_PROTOBUF"; - case StatusCode::MODEL_LOADED: - return "MODEL_LOADED"; - case StatusCode::NOT_IMPLEMENTED: - return "NOT_IMPLEMENTED"; - case StatusCode::INVALID_GRAPH: - return "INVALID_GRAPH"; - case StatusCode::EP_FAIL: - return "EP_FAIL"; - default: - return "GENERAL ERROR"; - } -} - -#ifdef _WIN32 -constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { - switch (status) { - case StatusCode::OK: - return S_OK; - case StatusCode::FAIL: - return E_FAIL; - case StatusCode::INVALID_ARGUMENT: - return E_INVALIDARG; - case StatusCode::NO_SUCHFILE: - return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); - case StatusCode::NO_MODEL: - return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); - case StatusCode::ENGINE_ERROR: - return E_FAIL; - case StatusCode::RUNTIME_EXCEPTION: - return E_FAIL; - case StatusCode::INVALID_PROTOBUF: - return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); - case StatusCode::MODEL_LOADED: - return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); - case StatusCode::NOT_IMPLEMENTED: - return E_NOTIMPL; - case StatusCode::INVALID_GRAPH: - return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); - case StatusCode::EP_FAIL: - return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); - default: - return E_FAIL; - } -} -#endif - -class [[nodiscard]] Status { - public: - Status() noexcept = default; - - Status(StatusCategory category, int code, const std::string& msg); - - Status(StatusCategory category, int code, const char* msg); - - Status(StatusCategory category, int code); - - Status(const Status& other) - : state_((other.state_ == nullptr) ? 
nullptr : new State(*other.state_)) {} - Status& operator=(const Status& other) { - if (state_ != other.state_) { - if (other.state_ == nullptr) { - state_.reset(); - } else { - state_.reset(new State(*other.state_)); - } - } - return *this; - } - - Status(Status&&) = default; - Status& operator=(Status&&) = default; - ~Status() = default; - - bool IsOK() const { - return (state_ == nullptr); - } - - int Code() const noexcept; - - StatusCategory Category() const noexcept; - - const std::string& ErrorMessage() const noexcept; - - std::string ToString() const; - - bool operator==(const Status& other) const { - return (this->state_ == other.state_) || (ToString() == other.ToString()); - } - - bool operator!=(const Status& other) const { - return !(*this == other); - } - - static Status OK() { - return Status(); - } - - private: - static const std::string& EmptyString() noexcept; - - struct State { - State(StatusCategory cat0, int code0, const std::string& msg0) - : category(cat0), code(code0), msg(msg0) {} - - State(StatusCategory cat0, int code0, const char* msg0) - : category(cat0), code(code0), msg(msg0) {} - - const StatusCategory category; - const int code; - const std::string msg; - }; - - // As long as Code() is OK, state_ == nullptr. - std::unique_ptr state_; -}; - -inline std::ostream& operator<<(std::ostream& out, const Status& status) { - return out << status.ToString(); -} -} // namespace common - -// make Status directly available in the onnxruntime namespace as it is widely used -using common::Status; - -} // namespace onnxruntime From 4d328677589b19f5872a669c77163c37d7c19cd3 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 Sep 2025 16:46:49 -0700 Subject: [PATCH 58/60] update GetCapabilityImpl() --- .../tensorrt/tensorrt_execution_provider.cc | 201 ++++++++++-------- .../tensorrt/tensorrt_execution_provider.h | 51 +---- 2 files changed, 120 insertions(+), 132 deletions(-) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index cb286f29..7840ece5 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "onnxruntime_cxx_api.h" @@ -737,11 +738,61 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, return nullptr; } +bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraph* graph, const std::string& provider_type) const { + size_t num_nodes = 0; + THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + // Get all the nodes from the graph + std::vector nodes(num_nodes); + THROW_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + for (const auto node : nodes) { + const char* ep_name; + THROW_IF_ERROR(ort_api.Node_GetEpName(node, &ep_name)); + + if (std::string(ep_name) != provider_type) { + return false; + } + } + + return num_nodes != 0; +} + +// Check the graph is the subgraph of control flow op +bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraph* graph) const { + const OrtNode* parent_node = nullptr; + THROW_IF_ERROR(ort_api.Graph_GetParentNode(graph, &parent_node)); + if (parent_node) { + const char* op_type = nullptr; + THROW_IF_ERROR(ort_api.Node_GetOperatorType(parent_node, &op_type)); + + if (control_flow_op_set_.find(std::string(op_type)) != control_flow_op_set_.end()) { + return true; + } + } + return false; +} + +// Check whether 
all the nodes of subgraph are supported +bool TensorrtExecutionProvider::IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const { + size_t num_nodes = 0; + THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + int number_of_trt_nodes = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + number_of_trt_nodes += static_cast(group.first.size()); + } + } + + return number_of_trt_nodes == num_nodes; +} + SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, const OrtGraph* graph, bool* early_termination) const { - // Return if iterations are exceeding predefined number - SubGraphCollection_t nodes_list_output; + // Temporarily make all nodes supported + SubGraphCollection_t nodes_list_output = nodes_vector_input; return nodes_list_output; } @@ -750,6 +801,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this OrtEpGraphSupportInfo* graph_support_info) noexcept { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; + auto ort_graph = Ort::ConstGraph(graph); size_t num_nodes = 0; RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); @@ -776,8 +828,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this return set; }; - // auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); - auto exclude_ops_set = get_exclude_ops_set(""); + auto exclude_ops_set = get_exclude_ops_set(ep->op_types_to_exclude_); /* Iterate all the nodes and exclude the node if: * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. @@ -821,12 +872,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this continue; } - /* - if (!ep->AllNodesAssignedToSpecificEP(*(subgraph->CreateGraphViewer()), kTensorrtExecutionProvider)) { + if (!ep->AllNodesAssignedToSpecificEP(subgraph, ep->name_)) { // if not all its subgraphs are supported, we need to exclude this control flow op return false; } - */ } return true; }; @@ -862,9 +911,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this supported_nodes_vector.clear(); } - // Temporarily make all nodes supported - supported_nodes_vector = parser_nodes_vector; - // Remove subgraphs if its size is less than the predefined minimal size for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { const size_t subgraph_size = it->first.size(); @@ -873,108 +919,83 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this } } - // Detect and remove cycles from supported node list - /* ep->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); */ - - // Consolidate supported node list - /* - if (supported_nodes_vector.size() > 1) { - nodes_vector.clear(); - for (const auto& group : supported_nodes_vector) { - if (!group.first.empty()) { - nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end()); - } - } - SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}}; - if (p->DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; - } else { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph"; - 
supported_nodes_vector = consolidated_supported_nodes_vector; - } - } - */ + // TODO: Detect and remove cycles from supported node list + // TODO: Consolidate supported node list + // Handle the case where the graph is subgraph of control flow op. // The purpose is to make control flow op as well as its subgraphs run on TRT. // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. - /* - if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { + if (ep->IsSubGraphOfControlFlowOp(graph) && ep->IsSubGraphFullySupported(graph, supported_nodes_vector)) { + //const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); bool all_subgraphs_are_supported = true; // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. - const OrtNode* parent_node = nullptr; - graph_api_->OrtGraph_GetParenNode(graph, &parent_node); - const char* parent_node_op_type = nullptr; - graph_api_->OrtNode_GetOpType(parent_node, &parent_node_op_type); - if (strcmp(parent_node_op_type, "If") == 0) { + Ort::ConstNode parent_node = ort_graph.GetParentNode(); + if (parent_node.GetOperatorType() == "If") { all_subgraphs_are_supported = false; SubGraphCollection_t subgraph_supported_nodes_vector; - const OrtGraphViewer** subgraphs = nullptr; - size_t subgraph_count = 0; - graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count); - for (size_t i = 0; i < subgraph_count; i++) { - bool same_graph = false; - graph_api_->OrtGraph_IsSameGraph(graph, subgraphs[i], &same_graph); - if (same_graph) { - continue; - } - int number_of_ort_subgraph_nodes = 0; - graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes); - std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); - std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); - SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; - bool subgraph_early_termination = false; - - // Another subgraph of "If" control flow op has no nodes. - // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. - if (number_of_ort_subgraph_nodes == 0) { - all_subgraphs_are_supported = true; - break; - } - // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. - else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) { - all_subgraphs_are_supported = true; - break; - } - // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. - // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) - else if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "")) { - all_subgraphs_are_supported = false; + + std::vector attr_name_subgraphs = parent_node.GetSubgraphs(); + for (auto attr_name_subgraph : attr_name_subgraphs) { + auto subgraph = attr_name_subgraph.sub_graph; + const OrtGraph* subgraph_raw_pointer = subgraph; + if (subgraph_raw_pointer != graph) { + + size_t num_subgraph_nodes = 0; + THROW_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); + + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. 
+ if (num_subgraph_nodes == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. + else if (ep->AllNodesAssignedToSpecificEP(subgraph, ep->name_)) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. + // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) + else if (!ep->AllNodesAssignedToSpecificEP(subgraph, "")) { + all_subgraphs_are_supported = false; + break; + } + + std::vector subgraph_nodes_vector(num_subgraph_nodes); + std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); + SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; + bool subgraph_early_termination = false; + + // Another subgraph of "If" control flow has not yet been parsed by GetCapability. + subgraph_supported_nodes_vector = ep->GetSupportedList(parser_subgraph_nodes_vector, 0, ep->max_partition_iterations_, subgraph, &subgraph_early_termination); + all_subgraphs_are_supported = ep->IsSubGraphFullySupported(subgraph, subgraph_supported_nodes_vector); break; } - - // Another subgraph of "If" control flow has not yet been parsed by GetCapability. - subgraph_supported_nodes_vector = p->GetSupportedList(parser_subgraph_nodes_vector, 0, p->max_partition_iterations_, subgraphs[i], &subgraph_early_termination); - all_subgraphs_are_supported = p->IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); - break; } - graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); } if (all_subgraphs_are_supported) { + // We want the subgraph nodes to be assigned to TRT EP but don't want them to be fused until later at the control flow op level. + // Simply request the subgraph nodes with a single ComputeCapability for each with no MetaDef (i.e. what the default implementation for IExecutionProvider::GetCapability does). for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { for (const auto& index : group.first) { - std::unique_ptr sub_graph = std::make_unique(); - sub_graph->node_index_len = 1; - sub_graph->node_index = new size_t[sub_graph->node_index_len]; - sub_graph->node_index[0] = nodes_index[index]; - cache.push_back(sub_graph.release()); + const OrtNode* supported_node = nodes[index]; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_node)); } } } - *cnt = cache.size(); - *indexed_sub_graph = new OrtIndexedSubGraph*[*cnt]; - for (size_t i = 0; i < *cnt; i++) { - (*indexed_sub_graph)[i] = cache[i]; - } - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; - return; + std::string message = "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + return nullptr; } } - */ int number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { @@ -2251,7 +2272,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to // the session option configurations with the key prefix "ep..". 
- // We extract those EP options to create a new "provider options" key/value map. + // We extract those EP options to create a new "provider options" key-value map. std::string lowercase_ep_name = name_.c_str(); std::transform(lowercase_ep_name.begin(), lowercase_ep_name.end(), lowercase_ep_name.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); @@ -2289,7 +2310,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options); info_.has_trt_options = true; device_id_ = info_.device_id; - // api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -2358,6 +2378,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa profile_opt_shapes = info_.profile_opt_shapes; cuda_graph_enable_ = info_.cuda_graph_enable; engine_hw_compatible_ = info_.engine_hw_compatible; + op_types_to_exclude_ = info_.op_types_to_exclude; } else { // deprecate env provider option } diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index e95157bd..90a056e2 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -22,48 +22,6 @@ using AllocateFunc = void* (*)(void*, size_t, size_t); using DestroyFunc = void (*)(void*, void*); namespace trt_ep { -namespace tensorrt_env_vars { -static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_ITERATIONS"; -static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; -static const std::string kMaxWorkspaceSize = "ORT_TENSORRT_MAX_WORKSPACE_SIZE"; -static const std::string kFP16Enable = "ORT_TENSORRT_FP16_ENABLE"; -static const std::string kINT8Enable = "ORT_TENSORRT_INT8_ENABLE"; -static const std::string kINT8CalibrationTableName = "ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"; -static const std::string kINT8UseNativeTensorrtCalibrationTable = "ORT_TENSORRT_INT8_USE_NATIVE_CALIBRATION_TABLE"; -static const std::string kDLAEnable = "ORT_TENSORRT_DLA_ENABLE"; -static const std::string kDLACore = "ORT_TENSORRT_DLA_CORE"; -static const std::string kDumpSubgraphs = "ORT_TENSORRT_DUMP_SUBGRAPHS"; -static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"; -static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH"; -static const std::string kWeightStrippedEngineEnable = "ORT_TENSORRT_WEIGHT_STRIPPED_ENGINE_ENABLE"; -static const std::string kOnnxModelFolderPath = "ORT_TENSORRT_ONNX_MODEL_FOLDER_PATH"; -// As a timing cache can be used across multiple ONNX files it makes sense to have a separate cache path -static const std::string kTimingCachePath = "ORT_TENSORRT_GLOBAL_CACHE_PATH"; -static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE"; -static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH"; -static const std::string kForceSequentialEngineBuild = "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD"; -static const std::string kContextMemorySharingEnable = "ORT_TENSORRT_CONTEXT_MEMORY_SHARING_ENABLE"; -static const std::string kLayerNormFP32Fallback = "ORT_TENSORRT_LAYER_NORM_FP32_FALLBACK"; -static const std::string kTimingCacheEnable = "ORT_TENSORRT_TIMING_CACHE_ENABLE"; 
-static const std::string kForceTimingCache = "ORT_TENSORRT_FORCE_TIMING_CACHE_ENABLE"; -static const std::string kDetailedBuildLog = "ORT_TENSORRT_DETAILED_BUILD_LOG_ENABLE"; -static const std::string kBuildHeuristics = "ORT_TENSORRT_BUILD_HEURISTICS_ENABLE"; -static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE"; -static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL"; -static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS"; -static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES"; -static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS"; -static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES"; -static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES"; -static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES"; -static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE"; -static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL"; -static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE"; -static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE"; -static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX"; -// Old env variable for backward compatibility -static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; -} // namespace tensorrt_env_vars class TensorrtLogger : public nvinfer1::ILogger { nvinfer1::ILogger::Severity verbosity_; @@ -386,6 +344,7 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { bool cuda_graph_enable_ = false; std::string cache_prefix_; bool engine_hw_compatible_ = false; + std::string op_types_to_exclude_; // For create/dump EP context node model bool dump_ep_context_model_ = false; @@ -399,6 +358,8 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::vector extra_attr_keys_; std::vector extra_attr_values_; + std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; + // std::unique_ptr model_proto_ = ONNX_NAMESPACE::ModelProto::Create(); // mutable std::unordered_map> subgraph_context_map_; @@ -442,6 +403,12 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { bool IsGraphCaptureAllowed() const { return false; }; nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; + + bool AllNodesAssignedToSpecificEP(const OrtGraph* graph, const std::string& provider_type) const; + + bool IsSubGraphOfControlFlowOp(const OrtGraph* graph) const; + + bool IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const; }; /// From ae9686f213b41fa3145daac6b806ee75f648e92f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 26 Sep 2025 09:34:43 -0700 Subject: [PATCH 59/60] Add code for updating cache path for EPContext node --- .../tensorrt/onnx_ctx_model_helper.cc | 47 +++++ .../tensorrt/onnx_ctx_model_helper.h | 4 + .../tensorrt/tensorrt_execution_provider.cc | 160 +++++++++--------- .../tensorrt/tensorrt_execution_provider.h | 2 +- .../tensorrt/utils/helper.cc | 64 +++++++ .../tensorrt/utils/path_string.h | 98 +++++++++++ 6 files changed, 292 insertions(+), 83 deletions(-) create mode 100644 plugin_execution_providers/tensorrt/utils/helper.cc create mode 100644 plugin_execution_providers/tensorrt/utils/path_string.h diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc 
b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index 34bd9647..a75498be 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -6,6 +6,7 @@ #include #include "ep_utils.h" +#include "path_string.h" #include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" @@ -13,6 +14,52 @@ namespace trt_ep { extern TensorrtLogger& GetTensorrtLogger(bool verbose_log, const OrtLogger& ort_default_logger, const OrtApi* ort_api); +bool IsAbsolutePath(const std::string& path_string) { +#ifdef _WIN32 + PathString ort_path_string = ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + return path.is_absolute(); +#else + if (!path_string.empty() && path_string[0] == '/') { + return true; + } + return false; +#endif +} + +// Like "../file_path" +bool IsRelativePathToParentPath(const std::string& path_string) { +#ifdef _WIN32 + PathString ort_path_string = ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + auto relative_path = path.lexically_normal().make_preferred().wstring(); + if (relative_path.find(L"..", 0) != std::string::npos) { + return true; + } + return false; +#else + if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) { + return true; + } + return false; +#endif +} + +/* + * Return the directory where the ep context model locates + */ +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) { + if (ep_context_file_path.empty()) { + return std::filesystem::path(); + } + std::filesystem::path ctx_path(ep_context_file_path); + if (std::filesystem::is_directory(ep_context_file_path)) { + return ctx_path; + } else { + return ctx_path.parent_path(); + } +} + /* * Check whether the graph has the EP context node. * The node can contain the precompiled engine info for TRT EP to directly load the engine. 
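The three path helpers added above exist to keep engine caches inside the EP context model's directory; they are combined later in this patch when the constructor rewrites cache_path_. A minimal usage sketch follows, assuming it sits in namespace trt_ep next to the helpers; ResolveEngineCachePathExample is an illustrative name, not part of the EP.

    #include <filesystem>
    #include <string>

    // Returns the sanitized cache location, or an empty string when the input could
    // escape the context model directory (the EP logs an error in that situation).
    std::string ResolveEngineCachePathExample(const std::string& cache_path,
                                              const std::string& ep_context_file_path) {
      if (IsAbsolutePath(cache_path) || IsRelativePathToParentPath(cache_path)) {
        return {};
      }
      // e.g. cache_path "engine_cache_dir" with context model "./context_model_dir/model.onnx"
      //      -> "./context_model_dir/engine_cache_dir"
      return GetPathOrParentPathOfCtxModel(ep_context_file_path).append(cache_path).string();
    }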
diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h index 4979c07c..568ff20e 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -13,6 +13,10 @@ #include namespace trt_ep { +bool IsAbsolutePath(const std::string& path_string); +bool IsRelativePathToParentPath(const std::string& path_string); +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); + class EPContextNodeHelper : public ApiPtrs { public: EPContextNodeHelper(TensorrtExecutionProvider& ep, diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index 7840ece5..af746551 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -151,7 +151,7 @@ bool ApplyProfileShapesFromProviderOptions(std::vectorLogger_LogMessage(logger, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, @@ -777,7 +777,7 @@ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraph* graph) bool TensorrtExecutionProvider::IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const { size_t num_nodes = 0; THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); - + int number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { @@ -922,12 +922,12 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this // TODO: Detect and remove cycles from supported node list // TODO: Consolidate supported node list - + // Handle the case where the graph is subgraph of control flow op. // The purpose is to make control flow op as well as its subgraphs run on TRT. // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. if (ep->IsSubGraphOfControlFlowOp(graph) && ep->IsSubGraphFullySupported(graph, supported_nodes_vector)) { - //const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); + // const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); bool all_subgraphs_are_supported = true; // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. 
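The hunk above touches only whitespace inside IsSubGraphFullySupported, so most of its body is elided. For context, the check boils down to comparing how many nodes the supported groups claim against the graph's total node count. A simplified free-function sketch is below; the real member function queries the count via ort_api.Graph_GetNumNodes, and the single-partition condition mirrors the in-tree TRT EP rather than lines visible in this patch.

    #include <cstddef>

    // Simplified stand-in for TensorrtExecutionProvider::IsSubGraphFullySupported.
    static bool IsSubGraphFullySupportedExample(std::size_t num_nodes,
                                                const SubGraphCollection_t& supported_nodes_vector) {
      std::size_t number_of_trt_nodes = 0;
      std::size_t non_empty_groups = 0;
      for (const auto& group : supported_nodes_vector) {
        if (!group.first.empty()) {
          number_of_trt_nodes += group.first.size();  // group.first holds supported node indices
          ++non_empty_groups;
        }
      }
      // Fully supported: every node is claimed and the claimed nodes form one partition.
      return number_of_trt_nodes == num_nodes && non_empty_groups == 1;
    }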
@@ -942,7 +942,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this auto subgraph = attr_name_subgraph.sub_graph; const OrtGraph* subgraph_raw_pointer = subgraph; if (subgraph_raw_pointer != graph) { - size_t num_subgraph_nodes = 0; THROW_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); @@ -1128,8 +1127,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { std::string message = "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, - OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, - message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); layer->setPrecision(nvinfer1::DataType::kFLOAT); next_layer->setPrecision(nvinfer1::DataType::kFLOAT); layer->setOutputType(0, nvinfer1::DataType::kFLOAT); @@ -1371,7 +1370,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } else { if (dla_core_ >= number_of_dla_core) { std::string message = "[TensorRT EP] Try to use DLA core #" + std::to_string(dla_core_) + - std::string(", but it exceeds platform's maximum DLA core number ") + std::to_string(number_of_dla_core) + + std::string(", but it exceeds platform's maximum DLA core number ") + std::to_string(number_of_dla_core) + std::string(". Use DLA core 0 instead."); Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, @@ -1412,14 +1411,12 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 if (build_heuristics_enable_) { if (builder_optimization_level_ == 2) { - std::string message = "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " - + std::string("level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."); + std::string message = "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " + std::string("level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."); Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } else { - std::string message = "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set " - + std::string("builder optimization level as 2 to enable builder heuristics."); + std::string message = "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. 
Please set " + std::string("builder optimization level as 2 to enable builder heuristics."); Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); @@ -2135,7 +2132,6 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( std::string onnx_model_filename, std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log) { - #if NV_TENSORRT_MAJOR >= 10 bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; std::filesystem::path onnx_model_path{onnx_model_folder_path}; @@ -2258,12 +2254,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // Initialize the execution provider. - auto status = ort_api.Logger_LogMessage(&logger_, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - ("Plugin EP has been created with name " + name_).c_str(), - ORT_FILE, __LINE__, __FUNCTION__); + auto ort_status = ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + ("Plugin EP has been created with name " + name_).c_str(), + ORT_FILE, __LINE__, __FUNCTION__); // ignore status for now - (void)status; + (void)ort_status; // populate apis as global for utility functions g_ort_api = &ort_api; @@ -2343,7 +2339,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa force_timing_cache_match_ = info_.force_timing_cache; detailed_build_log_ = info_.detailed_build_log; dump_ep_context_model_ = info_.dump_ep_context_model; - // dump_ep_context_model_ = true; ep_context_file_path_ = info_.ep_context_file_path; ep_context_embed_mode_ = info_.ep_context_embed_mode; enable_engine_cache_for_ep_context_model(); @@ -2420,7 +2415,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa } } - /* // If dump_ep_context_model_ is enable, TRT EP forces cache_path_ to be the relative path of ep_context_file_path_. // For example, // - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir" @@ -2429,20 +2423,25 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa // For security reason, it needs to make sure the engine cache is saved inside context model directory. 
if (dump_ep_context_model_ && engine_cache_enable_) { if (IsAbsolutePath(cache_path_)) { - // LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, the trt_engine_cache_path should be set with a relative path, but it is an absolute path: " << cache_path_; + std::string message = "In the case of dumping context model and for security purpose, the trt_engine_cache_path should be set with a relative path, but it is an absolute path: " + cache_path_; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } if (IsRelativePathToParentPath(cache_path_)) { - // LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory."; + std::string message = "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } // Engine cache relative path to context model directory. // It's used when dumping the "ep_cache_context" node attribute. - engine_cache_relative_path_to_context_model_dir = cache_path_; + engine_cache_relative_path_to_context_model_dir_ = cache_path_; // Make cache_path_ to be the relative path of ep_context_file_path_ cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string(); } - */ // Hardware compatibility: pre-check on environment if (engine_cache_enable_ && engine_hw_compatible_) { @@ -2485,16 +2484,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa if (engine_decryption_enable_) { LIBTYPE handle = OPENLIB(engine_decryption_lib_path_.c_str()); if (handle == nullptr) { - // TODO(yang) - // ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - // "TensorRT EP could not open shared library from " + engine_decryption_lib_path_)); + std::string message = "TensorRT EP could not open shared library from " + engine_decryption_lib_path_; + THROW_IF_ERROR(ort_api.CreateStatus(ORT_EP_FAIL, message.c_str())); } engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); if (engine_decryption_ == nullptr) { - // TODO(yang) - // ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - // "TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); + std::string message = "TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_; + THROW_IF_ERROR(ort_api.CreateStatus(ORT_EP_FAIL, message.c_str())); } } @@ -2512,56 +2509,56 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa * Please refer to ParserProfileShapes() for more details) * */ - // bool status = true; - // if (status) { - // status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_); - // if (!status) { - // profile_min_shapes_.clear(); - // std::string message = "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; - // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, - // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - // 
message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - // } - // } + bool status = true; + if (status) { + status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_); + if (!status) { + profile_min_shapes_.clear(); + std::string message = "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } - // if (status) { - // status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_); - // if (!status) { - // profile_max_shapes_.clear(); - // std::string message = "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; - // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, - // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - // } - // } + if (status) { + status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_); + if (!status) { + profile_max_shapes_.clear(); + std::string message = "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } - // if (status) { - // status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_); - // if (!status) { - // profile_opt_shapes_.clear(); - // std::string message = "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; - // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, - // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - // } - // } + if (status) { + status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_); + if (!status) { + profile_opt_shapes_.clear(); + std::string message = "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } - // if (status) { - // status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); - // if (!status) { - // std::string message = "[TensorRT EP] Profile shapes validation failed. 
Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile."; - // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, - // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - // message = "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you."; - // Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, - // OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - // message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - // profile_min_shapes_.clear(); - // profile_max_shapes_.clear(); - // profile_opt_shapes_.clear(); - // } - // } + if (status) { + status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + if (!status) { + std::string message = "[TensorRT EP] Profile shapes validation failed. Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + message = "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + profile_min_shapes_.clear(); + profile_max_shapes_.clear(); + profile_opt_shapes_.clear(); + } + } // cuda graph: // cudaStreamSynchronize() is not allowed in cuda graph capture. @@ -3037,8 +3034,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } if (detailed_build_log) { auto engine_build_stop = std::chrono::steady_clock::now(); - std::string message = "TensorRT engine build for " + trt_state->trt_node_name_with_precision + " took: " - + std::to_string(std::chrono::duration_cast(engine_build_stop - engine_build_start).count()) + "ms"; + std::string message = "TensorRT engine build for " + trt_state->trt_node_name_with_precision + " took: " + std::to_string(std::chrono::duration_cast(engine_build_stop - engine_build_start).count()) + "ms"; Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); @@ -3075,8 +3071,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } else { std::string message = "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, - OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); @@ -3620,4 +3616,4 @@ void TRTEpEpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_pt (void)trt_ep_compute_state; // Do nothing for here. 
} -} +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index 90a056e2..d75695bc 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -352,7 +352,7 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { int ep_context_embed_mode_ = 0; std::string ctx_model_path_; std::string ep_cache_context_attr_; - std::string engine_cache_relative_path_to_context_model_dir; + std::string engine_cache_relative_path_to_context_model_dir_; OrtGraph* ep_ctx_graph_ = nullptr; std::vector extra_attr_keys_; diff --git a/plugin_execution_providers/tensorrt/utils/helper.cc b/plugin_execution_providers/tensorrt/utils/helper.cc new file mode 100644 index 00000000..8bf59d2f --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/helper.cc @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef _WIN32 +#include +#include +#endif + +#ifdef ORT_NO_EXCEPTIONS +#if defined(__ANDROID__) +#include +#else +#include +#endif +#endif + +#include +#include "ep_utils.h" + +#ifdef _WIN32 +std::string ToUTF8String(std::wstring_view s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr); + assert(len > 0); + std::string ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} + +std::wstring ToWideString(std::string_view s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0); + assert(len > 0); + std::wstring ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} +#endif // #ifdef _WIN32 + +#ifdef NO_EXCEPTIONS +void PrintFinalMessage(const char* msg) { +#if defined(__ANDROID__) + __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg); +#else + // TODO, consider changing the output of the error message from std::cerr to logging when the + // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output + // might not be easily accessible on some systems such as mobile + // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS + std::cerr << msg << std::endl; +#endif +} +#endif // #ifdef NO_EXCEPTIONS diff --git a/plugin_execution_providers/tensorrt/utils/path_string.h b/plugin_execution_providers/tensorrt/utils/path_string.h new file mode 100644 index 00000000..7e7b5310 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/path_string.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include + +// for std::tolower or std::towlower +#ifdef _WIN32 +#include +#else +#include +#endif + +#include "onnxruntime_c_api.h" + +// char type for filesystem paths +using PathChar = ORTCHAR_T; +// string type for filesystem paths +using PathString = std::basic_string; + + +inline std::string ToUTF8String(const std::string& s) { return s; } +#ifdef _WIN32 +/** + * Convert a wide character string to a UTF-8 string + */ +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } +#else +inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline std::string ToWideString(std::string_view s) { return std::string{s}; } +#endif + + +inline PathString ToPathString(const PathString& s) { + return s; +} + +#ifdef _WIN32 + +static_assert(std::is_same::value, "PathString is not std::wstring!"); + +inline PathString ToPathString(std::string_view s) { + return ToWideString(s); +} +inline PathString ToPathString(const char* s) { + return ToWideString(s); +} +inline PathString ToPathString(const std::string& s) { + return ToWideString(s); +} + +inline PathChar ToLowerPathChar(PathChar c) { + return std::towlower(c); +} + +inline std::string PathToUTF8String(const PathString& s) { + return ToUTF8String(s); +} + +#else + +static_assert(std::is_same::value, "PathString is not std::string!"); + +inline PathString ToPathString(const char* s) { + return s; +} + +inline PathString ToPathString(std::string_view s) { + return PathString{s}; +} + +inline PathChar ToLowerPathChar(PathChar c) { + return std::tolower(c); +} + +inline std::string PathToUTF8String(const PathString& s) { + return s; +} + +#endif From c8a6ae60d5f5cbf04302b1bb45c2fa55c8098268 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 26 Sep 2025 14:51:58 -0700 Subject: [PATCH 60/60] add onnx_external_data_bytestream support for refitting the engine --- .../tensorrt/onnx_ctx_model_helper.cc | 182 +++++++------ .../tensorrt/onnx_ctx_model_helper.h | 16 +- .../tensorrt/tensorrt_execution_provider.cc | 245 +++++++++++++++--- .../tensorrt/tensorrt_execution_provider.h | 65 ++--- .../tensorrt_execution_provider_info.cc | 239 +++-------------- .../tensorrt_execution_provider_info.h | 12 +- 6 files changed, 383 insertions(+), 376 deletions(-) diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc index a75498be..be814b62 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -60,33 +60,13 @@ std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_contex } } -/* - * Check whether the graph has the EP context node. - * The node can contain the precompiled engine info for TRT EP to directly load the engine. 
- * - * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc - */ -bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api) { - size_t num_nodes = 0; - RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); - - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); - - for (size_t i = 0; i < num_nodes; ++i) { - auto node = nodes[i]; - - const char* op_type = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type)); - if (node != nullptr && std::string(op_type) == "EPContext") { - return true; - } - } - return false; +bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { + // The weight-stripped engine cache has the naming of xxx.stripped.engine + return engine_cache_path.stem().extension().string() == ".stripped"; } /* - * Create EPContext OrtNode from a fused_node + * Create an EPContext OrtNode from a fused_node */ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_cache_path, char* engine_data, @@ -158,12 +138,54 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca return nullptr; } +/* + * Check whether the graph has the EP context node. + * The node can contain the precompiled engine info for TRT EP to directly load the engine. + * + * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc + */ +bool EPContextNodeReader::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api) { + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + for (size_t i = 0; i < num_nodes; ++i) { + auto node = nodes[i]; + + const char* op_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type)); + if (node != nullptr && std::string(op_type) == "EPContext") { + return true; + } + } + return false; +} + +/* + * The sanity check for EP context contrib op. 
+ */ +bool EPContextNodeReader::ValidateEPCtxNode(const OrtGraph* graph) const { + size_t num_nodes = 0; + THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + ENFORCE(num_nodes == 1); + + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + const char* op_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(nodes[0], &op_type)); + ENFORCE(std::string(op_type) == "EPContext"); + + // TODO: Check compute capability and others + return true; +} + OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { - /* - if (!ValidateEPCtxNode(graph)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node"); + if (!ValidateEPCtxNode(&graph)) { + return ort_api.CreateStatus(ORT_EP_FAIL, "It's not a valid EP Context node"); } - */ size_t num_nodes = 0; RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); @@ -173,14 +195,6 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { auto node = nodes[0]; - size_t num_node_attributes = 0; - RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(node, &num_node_attributes)); - - /* - std::vector node_attributes(num_node_attributes); - RETURN_IF_ERROR(ort_api.Node_GetAttributes(node, node_attributes.data(), node_attributes.size())); - */ - const OrtOpAttr* node_attr = nullptr; RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "embed_mode", &node_attr)); const int64_t embed_mode = reinterpret_cast(node_attr)->i(); @@ -197,52 +211,61 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; + + std::string message = "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); if (!(*trt_engine_)) { return ort_api.CreateStatus(ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); } - /* if (weight_stripped_engine_refit_) { - const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); + node_attr = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "onnx_model_filename", &node_attr)); + const std::string onnx_model_filename = reinterpret_cast(node_attr)->s(); std::string placeholder; - auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, - onnx_model_folder_path_, - placeholder, - make_secure_path_checks, - onnx_model_bytestream_, - onnx_model_bytestream_size_, - onnx_external_data_bytestream_, - onnx_external_data_bytestream_size_, - (*trt_engine_).get(), - false, // serialize refitted engine to disk - detailed_build_log_); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + auto status = ep_.RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + placeholder, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + (*trt_engine_).get(), + false, // serialize refitted engine to disk + detailed_build_log_); + if (status != nullptr) { + return 
ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); } } - */ } else { // Get engine from cache file. node_attr = nullptr; RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr)); std::string cache_path = reinterpret_cast(node_attr)->s(); - /* // For security purpose, in the case of running context model, TRT EP won't allow // engine cache path to be the relative path like "../file_path" or the absolute path. // It only allows the engine cache to be in the same directory or sub directory of the context model. if (IsAbsolutePath(cache_path)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path); + std::string message = "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); } if (IsRelativePathToParentPath(cache_path)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory."); + std::string message = "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory."; + return ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); } // The engine cache and context model (current model) should be in the same directory std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); auto engine_cache_path = ctx_model_dir.append(cache_path); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + + std::string message = "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled if (!weight_stripped_engine_refit_) { @@ -253,14 +276,15 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { if (weight_stripped_engine_refit_) { const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); if (std::filesystem::exists(refitted_engine_cache_path)) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; + std::string message = "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); engine_cache_path = refitted_engine_cache_path.string(); weight_stripped_engine_refit_ = false; } } - */ - std::filesystem::path engine_cache_path(cache_path); if (!std::filesystem::exists(engine_cache_path)) { std::string error_msg = "TensorRT EP can't find engine cache: " + engine_cache_path.string() + @@ -279,28 +303,32 @@ OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { std::string error_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string(); return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); - /* + message = "[TensorRT EP] 
DeSerialized " + engine_cache_path.string(); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (weight_stripped_engine_refit_) { - const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); + node_attr = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "onnx_model_filename", &node_attr)); + const std::string onnx_model_filename = reinterpret_cast(node_attr)->s(); std::string weight_stripped_engine_cache = engine_cache_path.string(); - auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, - onnx_model_folder_path_, - weight_stripped_engine_cache, - make_secure_path_checks, - onnx_model_bytestream_, - onnx_model_bytestream_size_, - onnx_external_data_bytestream_, - onnx_external_data_bytestream_size_, - (*trt_engine_).get(), - true, // serialize refitted engine to disk - detailed_build_log_); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + auto status = ep_.RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + weight_stripped_engine_cache, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + (*trt_engine_).get(), + true, // serialize refitted engine to disk + detailed_build_log_); + if (status != nullptr) { + return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); } } - */ } return nullptr; } diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h index 568ff20e..bc2c53d7 100644 --- a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -17,6 +17,10 @@ bool IsAbsolutePath(const std::string& path_string); bool IsRelativePathToParentPath(const std::string& path_string); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); +// Class to create an EPContext node from an ORT's fused_node. +// +// Note: The class can be instantiated many times during EP's Compile() as to generate the EPContext nodes from fused_nodes/subgraphs and returns them to ORT via Compile(), +// ORT will end up creating the EPContext model. class EPContextNodeHelper : public ApiPtrs { public: EPContextNodeHelper(TensorrtExecutionProvider& ep, @@ -24,8 +28,6 @@ class EPContextNodeHelper : public ApiPtrs { const OrtNode* fused_node) : ApiPtrs{static_cast(ep)}, graph_(graph), fused_node_(fused_node) {} - static bool GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api); - OrtStatus* CreateEPContextNode(const std::string& engine_cache_path, char* engine_data, size_t size, @@ -39,9 +41,11 @@ class EPContextNodeHelper : public ApiPtrs { const OrtNode* fused_node_ = nullptr; }; +// Class to read an OrtGraph that contains an EPContext node and get the engine binary accordingly. 
class EPContextNodeReader : public ApiPtrs { public: EPContextNodeReader(TensorrtExecutionProvider& ep, + const OrtLogger& logger, std::unique_ptr* trt_engine, nvinfer1::IRuntime* trt_runtime, std::string ep_context_model_path, @@ -54,6 +58,8 @@ class EPContextNodeReader : public ApiPtrs { size_t onnx_external_data_bytestream_size, bool detailed_build_log) : ApiPtrs{static_cast(ep)}, + ep_(ep), + logger_(logger), trt_engine_(trt_engine), trt_runtime_(trt_runtime), ep_context_model_path_(ep_context_model_path), @@ -67,11 +73,15 @@ class EPContextNodeReader : public ApiPtrs { detailed_build_log_(detailed_build_log) { } - // bool ValidateEPCtxNode(const OrtGraph& graph); + static bool GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api); + + bool ValidateEPCtxNode(const OrtGraph* graph) const; OrtStatus* GetEpContextFromGraph(const OrtGraph& graph); private: + TensorrtExecutionProvider& ep_; + const OrtLogger& logger_; std::unique_ptr* trt_engine_; nvinfer1::IRuntime* trt_runtime_; std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc index af746551..945e93e9 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -36,6 +36,16 @@ const OrtApi* g_ort_api = nullptr; const OrtEpApi* g_ep_api = nullptr; const OrtModelEditorApi* g_model_editor_api = nullptr; +namespace ONNX_NAMESPACE { +using int64s = google::protobuf::RepeatedField; +using float32s = google::protobuf::RepeatedField; +using StringStringEntryProtos = google::protobuf::RepeatedPtrField; +using TensorProtos = google::protobuf::RepeatedPtrField; +using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; +using ValueInfoProtos = google::protobuf::RepeatedPtrField; +using FunctionProtos = google::protobuf::RepeatedPtrField; +} // namespace ONNX_NAMESPACE + namespace trt_ep { void CUDA_RETURN_IF_ERROR(cudaError_t res) { @@ -1048,6 +1058,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this /* out */ OrtNode** ep_context_node) { TensorrtExecutionProvider* ep = static_cast(this_ptr); + // Comment out following code if you want the "large" initializers to be saved to a external file. 
/* //Save initializers to external file std::string ext_ini_file_path = "model_serialized.bin"; @@ -1735,10 +1746,15 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - char* onnx = string_buf.data(); - size_t onnx_size = string_buf.size(); - auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, - false /* path check for security */, onnx, onnx_size, trt_engine.get(), + auto status = RefitEngine(model_path_, + onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + trt_engine.get(), true /* serialize refitted engine to disk */, detailed_build_log_); if (status != nullptr) { return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); @@ -1887,6 +1903,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this onnx_model_folder_path_, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, cache_prefix_, cache_suffix, engine_hw_compatible_, @@ -1920,6 +1938,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine // Get engine binary data and deserialize it std::unique_ptr ep_context_node_reader = std::make_unique(*ep, + logger_, &trt_engine, runtime_.get(), model_path_, @@ -2082,7 +2101,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ } OrtStatus* status; - if (EPContextNodeHelper::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { + if (EPContextNodeReader::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, input_map, output_map, &node_compute_infos_result[fused_node_idx])); @@ -2128,12 +2147,21 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CreateSyncStreamForDeviceImpl /** * Refit the weight-stripped engine */ -OrtStatus* TensorrtExecutionProvider::RefitEngine( - std::string onnx_model_filename, std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, - bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, - nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log) { +OrtStatus* TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, + bool detailed_build_log) { #if NV_TENSORRT_MAJOR >= 10 bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; + bool refit_with_external_data = onnx_external_data_bytestream != nullptr && onnx_external_data_bytestream_size != 0; + bool refit_complete = false; std::filesystem::path onnx_model_path{onnx_model_folder_path}; if (refit_from_file) { if (!onnx_model_filename.empty()) { @@ -2143,12 +2171,11 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( std::string err_msg = "The ONNX model was not provided as path. 
Please use provide an ONNX bytestream to enable refitting the weightless engine."; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } else { - /* // check if file path to ONNX is legal if (path_check && IsAbsolutePath(onnx_model_path.string())) { std::string err_msg = "For security purpose, the ONNX model path should be set with a relative path, but it is an absolute path: " + onnx_model_path.string(); - "weightless engine."; + "weightless engine."; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { @@ -2156,7 +2183,6 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( "The ONNX model path has '..'. For security purpose, it's not allowed to point outside the directory."; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - */ if (!(std::filesystem::exists(onnx_model_path) && std::filesystem::is_regular_file(onnx_model_path))) { std::string err_msg = "The ONNX model " + onnx_model_path.string() + " does not exist."; @@ -2170,30 +2196,164 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine( auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); auto parser_refitter = std::unique_ptr(nvonnxparser::createParserRefitter(*refitter, trt_logger)); - if (refit_from_file) { - std::string message = "[TensorRT EP] Refitting from file on disk: " + onnx_model_path.string(); + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + // New refit APIs + if (refit_with_external_data) { + // A valid model bytestream must be passed. + if (refit_from_file) { + std::string err_msg = "TensorRT EP's refit with external data must be called with a valid ONNX model bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + if (!parser_refitter->loadModelProto(onnx_model_bytestream, onnx_model_bytestream_size, nullptr)) { + std::string err_msg = "TensorRT EP's IParserRefitter could not load model from provided onnx_model_bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // Extract weight information from the Refitter. + int required_weights = refitter->getAllWeights(0, nullptr); + std::vector refit_names(required_weights); + refitter->getAllWeights(required_weights, refit_names.data()); + std::string message = "[TensorRT EP] Refitter requires " + std::to_string(required_weights) + " weights"; Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { - std::string err_msg = - "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " - "weights contained in: " + - onnx_model_path.string(); + + // Vectors to keep track of data pointers. + std::vector names; + names.reserve(required_weights); + std::vector bytes; + bytes.reserve(required_weights); + std::vector sizes; + sizes.reserve(required_weights); + + auto onnx_model = std::make_unique(); + ONNX_NAMESPACE::TensorProtos* allInitializers_byte_stream; + + // Reconstruct onnx model view. 
+ const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { + std::string err_msg = "The provided ONNX bytestream to refit could not be parsed."; return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } - } else { - std::string message = "[TensorRT EP] Refitting from byte array"; + + // Extract graph and initializer information. + auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + message = "[TensorRT EP] Initializers that were found " + std::to_string(allInitializers_byte_stream->size()); Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { - std::string err_msg = - "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with " - "weights contained in the provided bytestraem"; + + // Loop through all initializers + int missing_initializer_data = 0; + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end(); + if (weight_is_refittable) { + if (proto.has_data_location()) { + if (proto.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. + int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); + } + } + names.push_back(proto.name()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + std::string err_msg = "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } else if (proto.has_raw_data()) { + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); + } else { + message = "[TensorRT EP] Proto: " + proto_name + " has no raw nor external data."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + ++missing_initializer_data; + } + } else { + message = "[TensorRT EP] Initializer with name: " + proto_name + " was not marked as refittable"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + if (missing_initializer_data) { + std::string err_msg = "[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers."; return 
ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } + + // Load extracted initializers into the parser + if (!names.empty()) { + message = "[TensorRT EP] Number of initializers submitted to refitter " + std::to_string(names.size()); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + for (size_t i = 0; i < names.size(); i++) { + bool refloadInit = parser_refitter->loadInitializer(names[i].c_str(), bytes[i], sizes[i]); + if (!refloadInit) { + std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + } + // Perform refit. + if (!parser_refitter->refitModelProto()) { + std::string err_msg = "TensorRT EP's IParserRefitter refitModelProto() failed with the provided external data bytestream."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + refit_complete = true; + } +#else + // Refitting with external data is not supported prior to TensorRT 10.13. Log a warning in this case for the user. + if (refit_with_external_data) { + message = "[TensorRT EP] Refitting with an onnx_external_data_bytestream is only supported on TensorRT versions >= 10.13! This parameter will be ignored for refitting, and the resulting refitted engine may be incorrect."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + // If new refit flow was not completed, then fallback to refit_from_file. + if (!refit_complete) { + if (refit_from_file) { + std::string message = "[TensorRT EP] Refitting from file on disk: " + onnx_model_path.string(); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string(); + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } else { + std::string message = "[TensorRT EP] Refitting from byte array"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } } + if (refitter->refitCudaEngine()) { std::string message = "[TensorRT EP] Successfully refitted the weight-stripped engine."; Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, @@ -2335,6 +2495,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa engine_cache_enable_ = info_.engine_cache_enable; weight_stripped_engine_enable_ = info_.weight_stripped_engine_enable; onnx_model_folder_path_ = info_.onnx_model_folder_path; + onnx_model_bytestream_ = info_.onnx_bytestream; + onnx_model_bytestream_size_ = info_.onnx_bytestream_size; + onnx_external_data_bytestream_ = 
info_.external_data_bytestream; + onnx_external_data_bytestream_size_ = info_.external_data_bytestream_size; timing_cache_enable_ = info_.timing_cache_enable; force_timing_cache_match_ = info_.force_timing_cache; detailed_build_log_ = info_.detailed_build_log; @@ -2674,6 +2838,8 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* auto onnx_model_folder_path = trt_state->onnx_model_folder_path; auto onnx_model_bytestream = trt_state->onnx_model_bytestream; auto onnx_model_bytestream_size = trt_state->onnx_model_bytestream_size; + auto onnx_external_data_bytestream = trt_state->onnx_external_data_bytestream; + auto onnx_external_data_bytestream_size = trt_state->onnx_external_data_bytestream_size; auto sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; @@ -2703,9 +2869,6 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); - // cudaStream_t stream; - // cudaStreamCreate(&stream); - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even // if they share the same compute capacity Prepare cache name @@ -3101,20 +3264,22 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } } - /* - // dump ep context model - if (dump_ep_context_model_ && ep_context_embed_mode_) { - UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), - serialized_engine->size()); - DumpCtxModel(model_proto_.get(), ctx_model_path_); - } - */ + // TODO: In current ORT's EPContext design, there is no way TRT EP can update the engine cache binary in EPContext node with the rebuilt engine. + // The hacky way is to directly modify the EPContext model that graph_partitioner generates in session initialization. + context_update = true; if (weight_stripped_engine_refit) { auto status = - ep.RefitEngine(model_path, onnx_model_folder_path, engine_cache_path, false /* path check for security */, - onnx_model_bytestream, onnx_model_bytestream_size, trt_engine, + ep.RefitEngine(model_path, + onnx_model_folder_path, + engine_cache_path, + false /* path check for security */, + onnx_model_bytestream, + onnx_model_bytestream_size, + onnx_external_data_bytestream, + onnx_external_data_bytestream_size, + trt_engine, true /* serialize refitted engine to disk */, detailed_build_log); if (status != nullptr) { return ep.ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); @@ -3235,6 +3400,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* trt_context->setDeviceMemory((*context_memory).get()); } + // TODO: Add support for CUDA graph for plugin ep. /* // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because @@ -3315,6 +3481,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* } } + // TODO: Add support for CUDA graph for plugin ep. /* // End CUDA graph capture. 
// Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream @@ -3509,6 +3676,7 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_p trt_context->setDeviceMemory((*context_memory).get()); } + // TODO: Add support for CUDA graph for plugin ep. /* // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because @@ -3589,6 +3757,7 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_p } } + // TODO: Add support for CUDA graph for plugin ep. /* // End CUDA graph capture. // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h index d75695bc..c6046988 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -166,13 +166,15 @@ struct TensorrtComputeState { std::string onnx_model_folder_path; const void* onnx_model_bytestream; size_t onnx_model_bytestream_size; + const void* onnx_external_data_bytestream; + size_t onnx_external_data_bytestream_size; std::string cache_prefix; std::string cache_suffix; bool engine_hw_compatible = false; bool sync_stream_after_enqueue = true; }; -// Minimum information to construct kernel function state for direct engine load code path +// Minimum information to construct kernel function state for EPContext workflow struct TensorrtComputeStateForEPContext { uint32_t device_id; std::string fused_node_name; @@ -229,35 +231,22 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { OrtNodeComputeInfo** node_compute_info, OrtNode** ep_context_node); - OrtStatus* RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, - const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, - nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, + OrtStatus* RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, bool detailed_build_log); std::unordered_map& GetDDSOutputAllocators() { return dds_output_allocator_maps_; } - /* - bool IsGraphCaptured(int graph_annotation_id) const { return false; } - - static OrtStatusPtr RefitEngine(std::string onnx_model_filename, - std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, - bool path_check, - nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, - bool detailed_build_log); - - std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index, - const OrtGraph* graph, const HashValue& model_hash, int subgraph_index) const; - SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, - const OrtGraph* graph, bool* early_termination) const; - - bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles = true) const; - */ - /** Get a 
unique_lock object to control the concurrency behavior. Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading) @@ -265,15 +254,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { */ std::unique_lock GetApiLock() const; - /**Check the graph is the subgraph of control flow op*/ - // bool IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const; - - /**Check whether all the nodes of the graph are assigned to specific ep*/ - // bool AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const; - - /**Check whether all the nodes of subgraph are supported*/ - // bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; - std::unordered_map trt_node_name_with_precision_; std::unordered_map> dynamic_range_map_; std::unordered_map cache_suffix_; @@ -360,10 +340,6 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; - // std::unique_ptr model_proto_ = ONNX_NAMESPACE::ModelProto::Create(); - - // mutable std::unordered_map> subgraph_context_map_; - mutable std::unique_ptr builder_; // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading. @@ -384,30 +360,37 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { std::unordered_map> profiles_; std::unordered_map dds_output_allocator_maps_; + // TODO: Add support for external cudnn and cublas. // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture - // cudnnHandle_t external_cudnn_handle_ = nullptr; - // cublasHandle_t external_cublas_handle_ = nullptr; + // cudnnHandle_t external_cudnn_handle_ = nullptr; + // cublasHandle_t external_cublas_handle_ = nullptr; // Call cudaStreamSynchronize() after TRT enqueueV3() mutable bool sync_stream_after_enqueue_ = true; - // CUDAGraph cuda_graph_; - // bool is_graph_captured_ = false; + // TODO: Add support for CUDA graph for plugin ep. + /* + CUDAGraph cuda_graph_; + bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs // to allocate enough memory in Arena before graph capturing. const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
+ */ bool IsGraphCaptureAllowed() const { return false; }; nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; + /**Check whether all the nodes of the graph are assigned to specific ep*/ bool AllNodesAssignedToSpecificEP(const OrtGraph* graph, const std::string& provider_type) const; + /**Check the graph is the subgraph of control flow op*/ bool IsSubGraphOfControlFlowOp(const OrtGraph* graph) const; + /**Check whether all the nodes of subgraph are supported*/ bool IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const; }; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc index fece820a..17c65ef4 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc @@ -52,6 +52,11 @@ constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode"; constexpr const char* kEpContextFilePath = "trt_ep_context_file_path"; constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; +constexpr const char* kONNXBytestream = "trt_onnx_bytestream"; +constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size"; +constexpr const char* kExternalDataBytestream = "trt_external_data_bytestream"; +constexpr const char* kExternalDataBytestreamSize = "trt_external_data_bytestream_size"; +constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude"; } // namespace provider_option_names } // namespace tensorrt @@ -60,6 +65,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions TensorrtExecutionProviderInfo info{}; void* user_compute_stream = nullptr; + void* onnx_bytestream = nullptr; + void* external_data_bytestream = nullptr; THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -121,218 +128,30 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path) .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible) + .AddValueParser( + tensorrt::provider_option_names::kONNXBytestream, + [&onnx_bytestream](const std::string& value_str) -> OrtStatus* { + size_t address; + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + onnx_bytestream = reinterpret_cast(address); + return nullptr; + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) + .AddValueParser( + tensorrt::provider_option_names::kExternalDataBytestream, + [&external_data_bytestream](const std::string& value_str) -> OrtStatus* { + size_t address; + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + external_data_bytestream = reinterpret_cast(address); + return nullptr; + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kExternalDataBytestreamSize, info.external_data_bytestream_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude) .Parse(options)); // add new provider option here. 
info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); + info.onnx_bytestream = onnx_bytestream; + info.external_data_bytestream = external_data_bytestream; return info; -} - -// ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtExecutionProviderInfo& info) { -// const ProviderOptions options{ -// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, -// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.max_partition_iterations)}, -// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, -// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, -// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)}, -// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, -// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, -// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, -// {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, -// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, -// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.dla_enable)}, -// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.dla_core)}, -// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, -// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, -// {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, -// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.weight_stripped_engine_enable)}, -// {tensorrt::provider_option_names::kOnnxModelFolderPath, MakeStringWithClassicLocale(info.onnx_model_folder_path)}, -// {tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)}, -// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, -// {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, -// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, -// // add new provider option here. 
-// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)}, -// {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)}, -// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)}, -// {tensorrt::provider_option_names::kTimingCachePath, MakeStringWithClassicLocale(info.timing_cache_path)}, -// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)}, -// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, -// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.build_heuristics_enable)}, -// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.sparsity_enable)}, -// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)}, -// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)}, -// {tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)}, -// {tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)}, -// {tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, -// {tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, -// {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, -// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, -// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)}, -// {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, -// {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, -// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, -// }; -// return options; -// } -// -// ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensorRTProviderOptionsV2& info) { -// auto empty_if_null = [](const char* s) { return s != nullptr ? 
std::string{s} : std::string{}; }; -// const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); -// const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); -// const std::string kEngineCachePrefix_ = empty_if_null(info.trt_engine_cache_prefix); -// const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); -// const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); -// const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); -// const std::string kExtraPluginLibPaths_ = empty_if_null(info.trt_extra_plugin_lib_paths); -// const std::string kProfilesMinShapes_ = empty_if_null(info.trt_profile_min_shapes); -// const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); -// const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); -// const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); -// const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path); -// -// const ProviderOptions options{ -// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, -// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, -// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, -// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.trt_max_partition_iterations)}, -// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)}, -// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)}, -// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.trt_fp16_enable)}, -// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.trt_int8_enable)}, -// {tensorrt::provider_option_names::kInt8CalibTable, kInt8CalibTable_}, -// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.trt_int8_use_native_calibration_table)}, -// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.trt_dla_enable)}, -// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.trt_dla_core)}, -// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, -// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, -// {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, -// {tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_}, -// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.trt_weight_stripped_engine_enable)}, -// {tensorrt::provider_option_names::kOnnxModelFolderPath, kOnnxModelFolderPath_}, -// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, -// {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, -// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, -// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.trt_context_memory_sharing_enable)}, -// 
{tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.trt_layer_norm_fp32_fallback)}, -// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.trt_timing_cache_enable)}, -// {tensorrt::provider_option_names::kTimingCachePath, kTimingCachePath_}, -// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.trt_force_timing_cache)}, -// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.trt_detailed_build_log)}, -// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.trt_build_heuristics_enable)}, -// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.trt_sparsity_enable)}, -// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.trt_builder_optimization_level)}, -// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.trt_auxiliary_streams)}, -// {tensorrt::provider_option_names::kTacticSources, kTacticSources_}, -// {tensorrt::provider_option_names::kExtraPluginLibPaths, kExtraPluginLibPaths_}, -// {tensorrt::provider_option_names::kProfilesMinShapes, kProfilesMinShapes_}, -// {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_}, -// {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_}, -// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, -// {tensorrt::provider_option_names::kEpContextFilePath, kEpContextFilePath_}, -// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)}, -// {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)}, -// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)}, -// }; -// return options; -// } -// -///** -// * Update OrtTensorRTProviderOptionsV2 instance with ProviderOptions (map of string-based key-value pairs) -// * -// * Please note that it will reset the OrtTensorRTProviderOptionsV2 instance first and then set up the provided provider options -// * See TensorrtExecutionProviderInfo::FromProviderOptions() for more details. This function will be called by the C API UpdateTensorRTProviderOptions() also. -// * -// * \param provider_options - a pointer to OrtTensorRTProviderOptionsV2 instance -// * \param options - a reference to ProviderOptions instance -// * \param string_copy - if it's true, it uses strncpy() to copy 'provider option' string from ProviderOptions instance to where the 'provider option' const char pointer in OrtTensorRTProviderOptionsV2 instance points to. -// * it it's false, it only saves the pointer and no strncpy(). -// * -// * Note: If there is strncpy involved, please remember to deallocate or simply call C API ReleaseTensorRTProviderOptions. 
-// */ -// void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) { -// if (provider_options == nullptr) { -// return; -// } -// auto copy_string_if_needed = [&](std::string& s_in) { -// if (string_copy) { -// char* dest = nullptr; -// auto str_size = s_in.size(); -// if (str_size == 0) { -// return (const char*)nullptr; -// } else { -// dest = new char[str_size + 1]; -// #ifdef _MSC_VER -// strncpy_s(dest, str_size + 1, s_in.c_str(), str_size); -// #else -// strncpy(dest, s_in.c_str(), str_size); -// #endif -// dest[str_size] = '\0'; -// return (const char*)dest; -// } -// } else { -// return s_in.c_str(); -// } -// }; -// -// TensorrtExecutionProviderInfo internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options); -// auto& trt_provider_options_v2 = *reinterpret_cast(provider_options); -// trt_provider_options_v2.device_id = internal_options.device_id; -// -// // The 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance can be set by C API UpdateTensorRTProviderOptionsWithValue() as well -// // We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options or user_compute_stream is provided -// if (options.find("has_user_compute_stream") != options.end()) { -// trt_provider_options_v2.has_user_compute_stream = internal_options.has_user_compute_stream; -// } -// if (options.find("user_compute_stream") != options.end() && internal_options.user_compute_stream != nullptr) { -// trt_provider_options_v2.user_compute_stream = internal_options.user_compute_stream; -// trt_provider_options_v2.has_user_compute_stream = true; -// } -// -// trt_provider_options_v2.trt_max_partition_iterations = internal_options.max_partition_iterations; -// trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size; -// trt_provider_options_v2.trt_max_workspace_size = internal_options.max_workspace_size; -// trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable; -// trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable; -// -// trt_provider_options_v2.trt_int8_calibration_table_name = copy_string_if_needed(internal_options.int8_calibration_table_name); -// -// trt_provider_options_v2.trt_int8_use_native_calibration_table = internal_options.int8_use_native_calibration_table; -// trt_provider_options_v2.trt_dla_enable = internal_options.dla_enable; -// trt_provider_options_v2.trt_dla_core = internal_options.dla_core; -// trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs; -// trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable; -// trt_provider_options_v2.trt_weight_stripped_engine_enable = internal_options.weight_stripped_engine_enable; -// trt_provider_options_v2.trt_onnx_model_folder_path = copy_string_if_needed(internal_options.onnx_model_folder_path); -// -// trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path); -// trt_provider_options_v2.trt_engine_cache_prefix = copy_string_if_needed(internal_options.engine_cache_prefix); -// trt_provider_options_v2.trt_timing_cache_path = copy_string_if_needed(internal_options.timing_cache_path); -// -// trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable; -// -// trt_provider_options_v2.trt_engine_decryption_lib_path = copy_string_if_needed(internal_options.engine_decryption_lib_path); 
-// -// trt_provider_options_v2.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build; -// trt_provider_options_v2.trt_context_memory_sharing_enable = internal_options.context_memory_sharing_enable; -// trt_provider_options_v2.trt_layer_norm_fp32_fallback = internal_options.layer_norm_fp32_fallback; -// trt_provider_options_v2.trt_timing_cache_enable = internal_options.timing_cache_enable; -// trt_provider_options_v2.trt_force_timing_cache = internal_options.force_timing_cache; -// trt_provider_options_v2.trt_detailed_build_log = internal_options.detailed_build_log; -// trt_provider_options_v2.trt_build_heuristics_enable = internal_options.build_heuristics_enable; -// trt_provider_options_v2.trt_sparsity_enable = internal_options.sparsity_enable; -// trt_provider_options_v2.trt_builder_optimization_level = internal_options.builder_optimization_level; -// trt_provider_options_v2.trt_auxiliary_streams = internal_options.auxiliary_streams; -// -// trt_provider_options_v2.trt_tactic_sources = copy_string_if_needed(internal_options.tactic_sources); -// trt_provider_options_v2.trt_extra_plugin_lib_paths = copy_string_if_needed(internal_options.extra_plugin_lib_paths); -// trt_provider_options_v2.trt_profile_min_shapes = copy_string_if_needed(internal_options.profile_min_shapes); -// trt_provider_options_v2.trt_profile_max_shapes = copy_string_if_needed(internal_options.profile_max_shapes); -// trt_provider_options_v2.trt_profile_opt_shapes = copy_string_if_needed(internal_options.profile_opt_shapes); -// -// trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable; -// trt_provider_options_v2.trt_dump_ep_context_model = internal_options.dump_ep_context_model; -// trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode; -// trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path); -// trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible; -//} +} \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h index e0596e42..df315cf9 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h @@ -7,8 +7,6 @@ #include -#define TRT_DEFAULT_OPTIMIZER_LEVEL 3 - // Information needed to construct trt execution providers. 
struct TensorrtExecutionProviderInfo { int device_id{0}; @@ -29,6 +27,10 @@ struct TensorrtExecutionProviderInfo { std::string engine_cache_path{""}; bool weight_stripped_engine_enable{false}; std::string onnx_model_folder_path{""}; + const void* onnx_bytestream{nullptr}; + size_t onnx_bytestream_size{0}; + const void* external_data_bytestream{nullptr}; + size_t external_data_bytestream_size{0}; bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; @@ -53,11 +55,7 @@ struct TensorrtExecutionProviderInfo { int ep_context_embed_mode{0}; std::string engine_cache_prefix{""}; bool engine_hw_compatible{false}; + std::string op_types_to_exclude{""}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); - // static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); - // static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); - // static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy); - // - // std::vector custom_op_domain_list; };
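For reference, a minimal caller-side sketch of how the new byte-stream options are expected to be wired up. The buffer names (model_bytes, ext_data) and the LoadFile helper are hypothetical, and the snippet assumes the plugin's ProviderOptions / TensorrtExecutionProviderInfo headers are available; the option keys and the pointer-as-decimal-string convention follow the AddValueParser logic in FromProviderOptions above.

    // The application owns the in-memory ONNX model and its external-data blob; only raw
    // pointers are stored in the info struct, so both buffers must outlive their use.
    std::vector<char> model_bytes = LoadFile("model.onnx");       // hypothetical helper
    std::vector<char> ext_data    = LoadFile("model.onnx.data");  // hypothetical helper

    ProviderOptions opts;
    opts["trt_onnx_bytestream"]               = std::to_string(reinterpret_cast<size_t>(model_bytes.data()));
    opts["trt_onnx_bytestream_size"]          = std::to_string(model_bytes.size());
    opts["trt_external_data_bytestream"]      = std::to_string(reinterpret_cast<size_t>(ext_data.data()));
    opts["trt_external_data_bytestream_size"] = std::to_string(ext_data.size());

    TensorrtExecutionProviderInfo info = TensorrtExecutionProviderInfo::FromProviderOptions(opts);
    // info.onnx_bytestream / info.external_data_bytestream now point at the buffers above and
    // are later consumed by RefitEngine() when refitting a weight-stripped engine.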