Plugin TensorRT EP using ORT EP ABI #527
Open

chilo-ms wants to merge 60 commits into main from chi/plugin_trt_ep_impl
Commits (60 total; the diff below shows changes from 57 commits)
All commits are authored by chilo-ms.

- 36c0dc1 plugin TRT EP init
- ed65a9f clean up GetCapabilityImpl and make it pass compiler for now
- 3269f73 Clean up CompileImpl
- 4da9f90 update ep factory
- 1928767 update ep factory
- 4f5ffcb update ep factory
- bc64bdc clean up and add back onnx_ctx_model_helper.cc
- c4437a2 clean up
- a5a294e remove onnxruntime namespace
- f990a7b update
- 7851a1c Add TRTEpNodeComputeInfo
- be453b1 add allocator and data transfer
- 3d6fa57 fix a lot of compile errors
- c8e3d6f call EpDevice_AddAllocatorInfo in GetSupportedDevicesImpl
- 3c43029 temporary way to get provider option without proper API
- 549b29d Clean up cmake file to remove dependencies that built with ORT
- 3ad7736 Update CompileImpl
- 3ced4cf add ort_graph_to_proto.h and leverage OrtGraphToProto utilities
- 081de36 update EP context model helper
- 75240a4 Convert onnxruntime::Status to OrtStatus
- f73420f remove unused files
- 938a3fe use GetSessionOptionsConfigEntries to get provider options
- 731ed72 fix a bunch of compile errors
- 30e0f91 update memory info and data transfer in TRT EP's factor to accommodat…
- f443a33 update cuda/pinned allocator to make compiler happy
- 95dd71e add GetVersionImpl in factory
- 35b0cf1 update data transfer initialization in TRT EP
- a65908f Fix compile errors/issues
- c77391f fix to use correct API
- c5363e6 fix bug for gpu data transfer implementation
- 09138ee clean up
- a8dde45 remove unnecessary files
- b911754 Temporarily manually creates cudaStream to run
- 0c817ac Temporary make plugin TRT links against the protobuf, onnx, flatbuffe…
- da729f9 fix the issue of error LNK2038: mismatch detected for 'RuntimeLibrary…
- 6fd38c3 refactor memory info stored in factory
- 7467c65 update as onnxruntime_ep_c_api.h changes
- da0f9c6 Add support for dump and run EP Context model
- ccf20da update and sync with latest ep c api
- cca956d remove delete resource in TRTEpDataTransfer::ReleaseImpl
- 404cd4e update cmake file to force dynamic release CRT globally for all depen…
- c58130b use updated Value_GetMemoryDevice API
- 5828e10 update ort to graph util
- 832a7f4 Add EP API Stream support
- edd4b34 Update CMakeLists.txt
- 5f46b68 fix mem leak for OrtAllocator
- e81d395 add missing header file
- 1211cd6 fix build issue on Linux
- 0a8be0d lintrunner -a
- e4c2405 Update to use new API OpAttr_GetTensorAttributeAsOrtValue
- 2472a15 remove unnecessary files
- ab8cd70 Add default logger for TRT logger
- 12d2306 Add default logger for TRT EP
- c6ae7b6 update include path in utility function header
- 6b180a4 Add default logger for TRT EP (cont.)
- b3ac797 put code under namespace trt_ep
- 632d224 remove unnecessary files
- 4d32867 update GetCapabilityImpl()
- ae9686f Add code for updating cache path for EPContext node
- c8a6ae6 add onnx_external_data_bytestream support for refitting the engine
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
```cmake
# usage:
# cd build/
# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 -DCMAKE_POSITION_INDEPENDENT_CODE=ON (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits")
# cmake --build ./ --config Debug
cmake_minimum_required(VERSION 3.26)
project(TensorRTEp VERSION 1.0)
set(CMAKE_CXX_STANDARD 17)

enable_language(CUDA) # via nvcc to get the CUDA toolkit
file(TO_CMAKE_PATH "/usr/local/cuda" CUDAToolkit_ROOT)
find_package(CUDAToolkit REQUIRED)

# CMake config to force dynamic debug CRT or dynamic release CRT globally for all dependencies.
# This is to address the issue of:
# libprotobufd.lib(common.obj) : error LNK2038: mismatch detected for 'RuntimeLibrary': value 'MTd_StaticDebug' doesn't match value 'MDd_DynamicDebug' in unary_elementwise_ops_impl.obj
if (WIN32)
  if(CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebugDLL" CACHE STRING "" FORCE) # /MDd
    set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime
  endif()

  if(CMAKE_BUILD_TYPE STREQUAL "Release")
    set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDLL" CACHE STRING "" FORCE)
    set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime
  endif()
endif()

add_definitions(-DONNX_NAMESPACE=onnx)
add_definitions(-DONNX_ML)
add_definitions(-DNV_TENSORRT_MAJOR=10)
add_definitions(-DNOMINMAX)
file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu" "./*.h")
add_library(TensorRTEp SHARED ${tensorrt_src})

if (NOT ORT_HOME)
  message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/")
endif()

if (NOT TENSORRT_HOME)
  message(FATAL_ERROR "Please specify TENSORRT_HOME, e.g. -DTENSORRT_HOME=/path/to/trt/")
endif()

# Use release mode if not specified
if (NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE "Release")
endif()

# Add dependencies
include(FetchContent)

# Add protobuf
FetchContent_Declare(
  protobuf
  GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git
  GIT_TAG v21.12 # Use a specific tag or commit
)

if (WIN32)
  # Sometimes, protobuf ignores CMAKE_MSVC_RUNTIME_LIBRARY. To ensure it works:
  set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "" FORCE)
endif()

FetchContent_MakeAvailable(protobuf)

# Add ONNX
FetchContent_Declare(
  onnx
  GIT_REPOSITORY https://github.com/onnx/onnx.git
  GIT_TAG v1.18.0 # Use a specific tag or commit
)

FetchContent_MakeAvailable(onnx)

# Add GSL
FetchContent_Declare(
  gsl
  GIT_REPOSITORY https://github.com/microsoft/GSL.git
  GIT_TAG v4.0.0 # Use a specific tag or commit
)

FetchContent_MakeAvailable(gsl)

# Add flatbuffers
FetchContent_Declare(
  flatbuffers
  GIT_REPOSITORY https://github.com/google/flatbuffers.git
  GIT_TAG v23.5.26 # Use a specific tag or commit
)

FetchContent_MakeAvailable(flatbuffers)

set(DEPS_PATH "${CMAKE_BINARY_DIR}/_deps")

if (WIN32) # Windows
  set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.lib")
  set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib"
               "${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib"
               "${TENSORRT_HOME}/lib/nvonnxparser_10.lib")

  if(CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(DEPS_LIBS ${DEPS_LIBS}
        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib"
        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib")
  else()
    set(DEPS_LIBS ${DEPS_LIBS}
        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib"
        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib")
  endif()

  set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib"
                "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib"
                "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib")

  set(TRT_EP_LIB_LINK_FLAG
      "-DEF:${CMAKE_SOURCE_DIR}/tensorrt_execution_provider.def")

else() # Linux
  set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so")
  set(TRT_LIBS "${TENSORRT_HOME}/lib/libnvinfer.so"
               "${TENSORRT_HOME}/lib/libnvinfer_plugin.so"
               "${TENSORRT_HOME}/lib/libnvonnxparser.so")
  set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/libflatbuffers.a"
                "${DEPS_PATH}/onnx-build/libonnx.a"
                "${DEPS_PATH}/onnx-build/libonnx_proto.a")

  if(CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(DEPS_LIBS ${DEPS_LIBS}
        "${DEPS_PATH}/protobuf-build/libprotobufd.a"
        "${DEPS_PATH}/protobuf-build/libprotocd.a")
  else()
    set(DEPS_LIBS ${DEPS_LIBS}
        "${DEPS_PATH}/protobuf-build/libprotobuf.a"
        "${DEPS_PATH}/protobuf-build/libprotoc.a")
  endif()
endif()

MESSAGE(STATUS "Looking for following dependencies ...")
MESSAGE(STATUS "ORT lib  : ${ORT_LIB}")
MESSAGE(STATUS "TRT libs : ${TRT_LIBS}")
MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}")

set_property(TARGET TensorRTEp APPEND_STRING PROPERTY LINK_FLAGS
             ${TRT_EP_LIB_LINK_FLAG})

target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include"
                           "./utils"
                           "/usr/local/cuda/include"
                           "${TENSORRT_HOME}/include"
                           "${DEPS_PATH}/flatbuffers-src/include"
                           "${DEPS_PATH}/gsl-src/include" # GSL is header-only
                           "${DEPS_PATH}/onnx-src"
                           "${DEPS_PATH}/onnx-build"
                           "${DEPS_PATH}/protobuf-src/src"
)

target_link_libraries(TensorRTEp PUBLIC #${DEPS_LIBS}
                      protobuf::libprotobuf onnx flatbuffers
                      ${ORT_LIB}
                      ${TRT_LIBS}
                      CUDA::cudart
)
```
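
For context, here is a minimal host-side sketch of how an application might load the resulting TensorRTEp shared library through the plugin-EP registration entry points that accompany the ORT EP ABI. This is not part of the PR: the entry-point names (`RegisterExecutionProviderLibrary`, `UnregisterExecutionProviderLibrary`) and the registration name `"TensorRTEp"` are assumptions based on recent `onnxruntime_c_api.h` revisions, so verify them against the headers shipped in your `ORT_HOME`.

```cuda
// Hypothetical loader sketch (not part of this PR); error handling omitted.
#include <onnxruntime_c_api.h>

int main() {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtEnv* env = nullptr;
  ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "trt_plugin_demo", &env);

  // Register the plugin library. "TensorRTEp" is the name used to refer to it
  // later; the path points at the shared library produced by this CMakeLists.txt.
  ort->RegisterExecutionProviderLibrary(env, "TensorRTEp", ORT_TSTR("./libTensorRTEp.so"));

  // ... create session options, select the plugin's OrtEpDevice, run inference ...

  ort->UnregisterExecutionProviderLibrary(env, "TensorRTEp");
  ort->ReleaseEnv(env);
  return 0;
}
```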
plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh (new file, 76 additions)
```cuda
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <stdint.h>

namespace cuda {

// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer.
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

template <class INT, class INT2>
inline __host__ __device__ INT CeilDiv(INT a, INT2 b)  // ceil(a/b)
{
  return (INT)(((size_t)a + (size_t)b - 1) / (size_t)b);  // these size_t casts are necessary since b may be INT_MAX (for maxGridSize[])
}

struct GridDim {
  enum : CUDA_LONG {
    maxThreadsPerBlock = 256,  // max threads per block
    maxElementsPerThread = 4,  // max elements processed per thread
  };
};

template <typename InT, typename OutT, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _UnaryElementWise(
    const InT* input_data,
    OutT* output_data,
    const FuncT functor,
    CUDA_LONG N) {
  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
  InT value[NumElementsPerThread];

  CUDA_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      value[i] = input_data[id];
      id += NumThreadsPerBlock;
    }
  }

  id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      output_data[id] = functor(value[i]);
      id += NumThreadsPerBlock;
    }
  }
}

template <typename InT, typename OutT, typename FuncT>
void UnaryElementWiseImpl(
    cudaStream_t stream,
    const InT* input_data,
    OutT* output_data,
    const FuncT& func,
    size_t count) {
  if (count == 0)  // special case where there's a dim value of 0 in the shape
    return;

  int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
  CUDA_LONG N = static_cast<CUDA_LONG>(count);
  _UnaryElementWise<InT, OutT, FuncT, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread>
      <<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
          input_data,
          output_data,
          func,
          N);
}

}  // namespace cuda
```
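
`UnaryElementWiseImpl` is generic over the functor, so any unary op can reuse the same launch logic. A minimal caller sketch follows; the `OP_Square` functor and `LaunchSquare` wrapper are hypothetical, made up for illustration (the PR itself only instantiates casts):

```cuda
#include <cuda_runtime.h>
#include "cu_inc/unary_elementwise_impl.cuh"

// Hypothetical functor, not part of the PR: squares each element on the device.
struct OP_Square {
  __device__ __inline__ float operator()(const float& a) const { return a * a; }
};

// Squares `count` floats on `stream`. Grid sizing is handled inside
// UnaryElementWiseImpl via CeilDiv and GridDim, so the caller only supplies
// the buffers, the functor, and the element count.
void LaunchSquare(cudaStream_t stream, const float* in, float* out, size_t count) {
  cuda::UnaryElementWiseImpl(stream, in, out, OP_Square(), count);
}
```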
plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu (new file, 90 additions)
```cuda
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cuda_runtime.h>
#include <type_traits>  // for std::is_same / std::conditional used in OP_Cast
#include "cu_inc/unary_elementwise_impl.cuh"

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
#include "cuda_fp8.h"
#endif
#include <cuda_fp16.h>

namespace cuda {

// The postfix denotes the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
// X: BFloat16

// When casting, half needs to be converted via float type from most other types
template <typename T>
struct ViaTypeMap {
  typedef T ViaT;
};

template <>
struct ViaTypeMap<half> {
  typedef float ViaT;
};

template <typename InT, typename OutT>
struct OP_Cast {
  __device__ __inline__ OutT operator()(const InT& a) const {
    const bool any_float16 = std::is_same<half, InT>::value || std::is_same<half, OutT>::value;
    typedef typename std::conditional<any_float16, half, OutT>::type T;
    typedef typename ViaTypeMap<T>::ViaT ViaT;
    return (OutT)((ViaT)a);
  }
};

#define IMPL_CAST_IMPL(InT, OutT)                                                                         \
  void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) {  \
    UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count);                   \
  }

#define IMPL_CAST_IMPL_THROW(InT, OutT)                                                               \
  void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/,  \
                          size_t /*count*/) {                                                         \
    ORT_THROW("Cast from " #InT " to " #OutT " must define saturate.");                               \
  }

#define IMPL_CAST_IMPL_FROM(T)  \
  IMPL_CAST_IMPL(T, half)       \
  IMPL_CAST_IMPL(T, float)      \
  IMPL_CAST_IMPL(T, double)     \
  IMPL_CAST_IMPL(T, int8_t)     \
  IMPL_CAST_IMPL(T, int16_t)    \
  IMPL_CAST_IMPL(T, int32_t)    \
  IMPL_CAST_IMPL(T, int64_t)    \
  IMPL_CAST_IMPL(T, uint8_t)    \
  IMPL_CAST_IMPL(T, uint16_t)   \
  IMPL_CAST_IMPL(T, uint32_t)   \
  IMPL_CAST_IMPL(T, uint64_t)   \
  IMPL_CAST_IMPL(T, bool)       \
  // IMPL_CAST_IMPL(T, BFloat16)

IMPL_CAST_IMPL_FROM(half)
IMPL_CAST_IMPL_FROM(float)
IMPL_CAST_IMPL_FROM(double)
IMPL_CAST_IMPL_FROM(int8_t)
IMPL_CAST_IMPL_FROM(int16_t)
IMPL_CAST_IMPL_FROM(int32_t)
IMPL_CAST_IMPL_FROM(int64_t)
IMPL_CAST_IMPL_FROM(uint8_t)
IMPL_CAST_IMPL_FROM(uint16_t)
IMPL_CAST_IMPL_FROM(uint32_t)
IMPL_CAST_IMPL_FROM(uint64_t)
IMPL_CAST_IMPL_FROM(bool)
// IMPL_CAST_IMPL_FROM(BFloat16)

}  // namespace cuda
```
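
The `IMPL_CAST_IMPL_FROM` table above stamps out one `Explicit_Impl_Cast` overload per (source, destination) type pair. Written out by hand, a single instantiation such as `IMPL_CAST_IMPL(half, float)` is just:

```cuda
// Hand-expanded form of IMPL_CAST_IMPL(half, float), shown for illustration
// only; in the PR this overload is generated by the macro table above.
void Explicit_Impl_Cast(cudaStream_t stream, const half* input_data, float* output_data, size_t count) {
  // Launches the grid-stride kernel from unary_elementwise_impl.cuh with a
  // functor that converts half to float (via float, per ViaTypeMap).
  UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<half, float>(), count);
}
```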
plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h (new file, 51 additions)
```cuda
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <stdint.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

namespace cuda {

// Cast

#define DECL_IMPL_CAST(InT, OutT) \
  void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count);

#define DECL_IMPL_CAST_FROM(T)  \
  DECL_IMPL_CAST(T, half)       \
  DECL_IMPL_CAST(T, float)      \
  DECL_IMPL_CAST(T, double)     \
  DECL_IMPL_CAST(T, int8_t)     \
  DECL_IMPL_CAST(T, int16_t)    \
  DECL_IMPL_CAST(T, int32_t)    \
  DECL_IMPL_CAST(T, int64_t)    \
  DECL_IMPL_CAST(T, uint8_t)    \
  DECL_IMPL_CAST(T, uint16_t)   \
  DECL_IMPL_CAST(T, uint32_t)   \
  DECL_IMPL_CAST(T, uint64_t)   \
  DECL_IMPL_CAST(T, bool)       \
  // DECL_IMPL_CAST(T, BFloat16)

DECL_IMPL_CAST_FROM(half)
DECL_IMPL_CAST_FROM(float)
DECL_IMPL_CAST_FROM(double)
DECL_IMPL_CAST_FROM(int8_t)
DECL_IMPL_CAST_FROM(int16_t)
DECL_IMPL_CAST_FROM(int32_t)
DECL_IMPL_CAST_FROM(int64_t)
DECL_IMPL_CAST_FROM(uint8_t)
DECL_IMPL_CAST_FROM(uint16_t)
DECL_IMPL_CAST_FROM(uint32_t)
DECL_IMPL_CAST_FROM(uint64_t)
DECL_IMPL_CAST_FROM(bool)
// DECL_IMPL_CAST_FROM(BFloat16)

template <typename InT, typename OutT>
void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) {
  Explicit_Impl_Cast(stream, input_data, output_data, count);
}

}  // namespace cuda
```
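
Callers go through the `Impl_Cast` template, which simply forwards to the matching `Explicit_Impl_Cast` overload generated in the `.cu` file. A usage sketch, with a hypothetical wrapper name and no error handling:

```cuda
#include <cuda_fp16.h>
#include "unary_elementwise_ops_impl.h"

// Hypothetical caller, not part of the PR: cast a float buffer to half on `stream`.
void CastFloatToHalf(cudaStream_t stream, const float* src, half* dst, size_t count) {
  // Resolves to the Explicit_Impl_Cast(cudaStream_t, const float*, half*, size_t)
  // overload stamped out by IMPL_CAST_IMPL_FROM(float) in unary_elementwise_ops_impl.cu.
  cuda::Impl_Cast(stream, src, dst, count);
}
```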
Review comments

- nit: perhaps should replace `lochi` with a generic `user` or something like it
- Could it be put in the c_cxx folder along with other C/C++ examples?