[TRTLLM-7028][feat] Enable guided decoding with speculative decoding (part 2: one-model engine) #6948
Conversation
📝 Walkthrough

Adds CUDA host-function interop via pybind (launch/free), exposes a Torch helper module, and integrates a new batch-oriented GuidedWorker into guided decoding across the executor, model engine, and speculative components. Adjusts imports/APIs, propagates guided metadata, updates the speculative drafter and Eagle3 to drive the worker, and fixes dataclass defaults.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor Py as Python (Torch)
participant TR as pybind runtime
participant CU as CUDA Stream
participant CB as HostFunc Callback
Py->>TR: launch_hostfunc(stream_ptr, py_hostfunc, *args, **kwargs)
TR->>CU: cudaLaunchHostFunc(trampoline, user_data)
Note over CU,CB: Callback scheduled on stream
CU-->>CB: invoke trampoline(user_data)
CB->>CB: Acquire GIL, call py_hostfunc(*args, **kwargs)
CB-->>Py: Python callable executes (errors printed, if any)
Py->>TR: free_hostfunc_user_data(handle)
TR-->>Py: user-data freed
```
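For orientation, here is a minimal usage sketch of the host-function interop shown above. The function names (`launch_hostfunc`, `free_hostfunc_user_data`) are taken from the diagram and the module path from the `hostfunc.py` imports; the final API surface in the PR may differ.

```python
import torch
from tensorrt_llm.bindings.internal import runtime  # assumed module path (see hostfunc.py imports)

def log_step(tag: str) -> None:
    # Runs on the host once the preceding work on the stream has finished.
    print(f"host callback fired: {tag}")

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # enqueue some device work first
    # Enqueue the Python callback on the same stream; the returned handle
    # refers to the user data held by the binding.
    handle = runtime.launch_hostfunc(stream.cuda_stream, log_step, "after matmul")

stream.synchronize()                      # the callback has executed by now
runtime.free_hostfunc_user_data(handle)   # release the user data held by the binding
```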
```mermaid
sequenceDiagram
autonumber
participant Eng as PyTorchModelEngine
participant Cr as PyExecutorCreator
participant GW as GuidedWorker
participant Dra as ModelDrafter/Eagle3
participant Ex as PyExecutor
participant Dev as Device (Kernels)
Cr->>Eng: set_guided_worker(GuidedWorker(...))
Eng-->>Cr: bool
loop Per batch
Ex->>GW: add_batch(scheduled_requests[, new_tensors])
Dra->>GW: add_draft_batch(...) (draft mode)
Ex->>Dev: run model forward
Dev-->>Ex: logits (and d2t if any)
Ex->>GW: execute(logits[, d2t])
Dra->>GW: execute_draft_batch(logits, d2t, step flags)
opt Rejections
GW->>GW: rollback_rejected_tokens()
end
end
```
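As a rough illustration of the per-batch flow in the second diagram, a driver loop might look like the sketch below. The `GuidedWorker` method names (`add_batch`, `execute`, `add_draft_batch`, `execute_draft_batch`, `rollback_rejected_tokens`) come from the diagram; everything else (the model/drafter interfaces and the step flag) is hypothetical.

```python
def run_iteration(model, worker, drafter, scheduled_requests, new_tensors=None):
    """One executor iteration with a (hypothetical) GuidedWorker in the loop."""
    # Register the target-model batch with the worker before the forward pass.
    worker.add_batch(scheduled_requests, new_tensors)

    # Target model forward; logits (and an optional d2t mapping) come back.
    outputs = model.forward(scheduled_requests)   # hypothetical model interface
    logits = outputs["logits"]
    d2t = outputs.get("d2t")

    # Apply the grammar bitmask to the target logits.
    worker.execute(logits, d2t)

    # Draft loop: the drafter registers its own batch and drives per-step masking.
    # drafter.draft_steps and the is_first_step flag are illustrative names only.
    for step, draft_logits in enumerate(drafter.draft_steps(scheduled_requests)):
        worker.add_draft_batch(scheduled_requests)
        worker.execute_draft_batch(draft_logits, d2t, is_first_step=(step == 0))

    # Roll back any draft tokens the target model rejected.
    worker.rollback_rejected_tokens()
```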
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Actionable comments posted: 12
🧹 Nitpick comments (11)
cpp/tensorrt_llm/pybind/runtime/hostfunc.h (1)

1-27: Add include guard per repo guidelines; prefer explicit guard over pragma once

The coding guidelines require header guards named TRTLLM_<NAME>_H. Replace #pragma once with an explicit guard.
```diff
-#pragma once
+#ifndef TRTLLM_HOSTFUNC_H
+#define TRTLLM_HOSTFUNC_H
@@
-} // namespace tensorrt_llm::pybind::runtime
+} // namespace tensorrt_llm::pybind::runtime
+
+#endif // TRTLLM_HOSTFUNC_H
```

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
342-346: Consider refactoring the TODO comment implementation

The commented-out runtime guard for guided decoding has been removed, and the guided decoder assignment pattern with `TODO: fix it` seems like a temporary workaround. The guided decoder is instantiated but then immediately nullified after being assigned to the spec worker.

Consider a cleaner approach that avoids the intermediate assignment and nullification:
```diff
-# TODO: fix it
-model_engine.model.spec_worker.guided_decoder = guided_decoder
-guided_decoder = None
+# Assign guided decoder to spec worker for guided decoding in speculative mode
+if hasattr(model_engine.model, 'spec_worker'):
+    model_engine.model.spec_worker.guided_decoder = guided_decoder
+    guided_decoder = None  # Ownership transferred to spec worker
+else:
+    logger.warning("Spec worker not found for guided decoding assignment")
```

tensorrt_llm/_torch/speculative/eagle3.py (1)
335-336: Consider implementing guided decoder for draft token generation

The comment "Insert guided decoder" indicates a placeholder where guided decoding logic should be applied during draft token generation. This seems incomplete compared to the target model path.
Would you like me to help implement the guided decoder logic for draft token generation? This would ensure consistency between target and draft model guided decoding behavior.
tensorrt_llm/_torch/hostfunc.py (1)
8-8: Consider thread-safety for global registry

The `HOSTFUNC_USER_DATA_HANDLES` set is used globally without synchronization, which could lead to race conditions if multiple threads call the host function APIs concurrently.

Consider adding thread-safety:
```diff
+import threading
 import atexit
 import torch
 from ..bindings.internal import runtime as bindings
 from ..logger import logger
 HOSTFUNC_USER_DATA_HANDLES = set()
+_HOSTFUNC_LOCK = threading.Lock()
```

Then protect all accesses to the set with the lock in the relevant functions.
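A minimal sketch of what lock-protected access could look like; the helper names below are illustrative, not the module's actual functions:

```python
import threading

HOSTFUNC_USER_DATA_HANDLES = set()
_HOSTFUNC_LOCK = threading.Lock()

def _register_handle(handle: int) -> None:
    # Track a live user-data handle so it can be freed at exit.
    with _HOSTFUNC_LOCK:
        HOSTFUNC_USER_DATA_HANDLES.add(handle)

def _release_handle(handle: int) -> None:
    # Drop a handle once its user data has been freed.
    with _HOSTFUNC_LOCK:
        HOSTFUNC_USER_DATA_HANDLES.discard(handle)

def _drain_handles() -> list:
    # Atomically snapshot and clear, e.g. for an atexit cleanup path.
    with _HOSTFUNC_LOCK:
        handles = list(HOSTFUNC_USER_DATA_HANDLES)
        HOSTFUNC_USER_DATA_HANDLES.clear()
    return handles
```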
cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp (1)
44-58: Consider adding exception catching for all Python exceptions

The current error handling in `cudaHostFuncTrampoline` only catches `py::error_already_set`. Consider catching all exceptions to prevent crashes from unexpected Python exceptions.

Apply this diff to improve exception handling:

```diff
 static void cudaHostFuncTrampoline(void* userData)
 {
     auto* hostFuncUserData = static_cast<HostFuncUserData*>(userData);
     // Acquire the GIL since we are calling Python code from a CUDA stream.
     py::gil_scoped_acquire gil;
     try
     {
         hostFuncUserData->pyHostFunc(*hostFuncUserData->pyArgs, **hostFuncUserData->pyKwargs);
     }
     catch (py::error_already_set& e)
     {
         e.restore();
         PyErr_Print();
     }
+    catch (std::exception& e)
+    {
+        py::print("Error in host function:", e.what());
+    }
+    catch (...)
+    {
+        py::print("Unknown error in host function");
+    }
 }
```

examples/guided_decoding/try_pytorch_bindings_v3.py (2)
24-25: Consider using `torch.no_grad()` for inference

The forward pass in the DummyModel should be wrapped with `torch.no_grad()` to avoid unnecessary gradient computation during inference.

```diff
 def forward(self, x: torch.Tensor):
     # Simulate some GPU computation.
-    for i in range(10):
-        torch.matmul(self.a, self.b)
+    with torch.no_grad():
+        for i in range(10):
+            torch.matmul(self.a, self.b)
```
96-97: Add error handling for decoding failures

The assertion and decoding could fail without proper error handling.

```diff
-assert len(guided_decoder.token_ids) == 20
-print(tokenizer.decode(guided_decoder.token_ids))
+if len(guided_decoder.token_ids) != 20:
+    logger.error(f"Expected 20 tokens, got {len(guided_decoder.token_ids)}")
+try:
+    decoded_text = tokenizer.decode(guided_decoder.token_ids)
+    print(f"Decoded text: {decoded_text}")
+except Exception as e:
+    logger.error(f"Failed to decode tokens: {e}")
```

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)
170-171: Line length exceeds coding guideline limit

Lines 170-171 exceed the 120-character limit specified in the coding guidelines.

```diff
-# The last new token must be acceptable unless the matcher is terminated:
-# 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-# 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+# The last new token must be acceptable unless the matcher is terminated:
+# 1. For the main model loop, when overlap scheduler is enabled, the matcher may have
+#    accepted the EOS token in the draft tokens at the previous iteration.
+# 2. For the draft model loop, the matcher may have accepted the EOS token at the
+#    previous drafting iteration.
```
285-285: Line length exceeds coding guideline limit

Line 285 exceeds the 120-character limit.

```diff
-f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"
+f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, "
+f"num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"
```
328-337: V2 API wrappers lack input validation

The v2 wrapper methods don't validate that `guided_metadata` contains the required fields before forwarding.

```diff
 def build_v2(self, guided_metadata: GuidedMetadata):
+    if guided_metadata.scheduled_requests is None:
+        raise ValueError("guided_metadata.scheduled_requests cannot be None")
     self.build(guided_metadata.scheduled_requests, guided_metadata.gathered_input_ids)

 def execute_v2(self, guided_metadata: GuidedMetadata, logits: torch.Tensor):
+    if guided_metadata.scheduled_requests is None:
+        raise ValueError("guided_metadata.scheduled_requests cannot be None")
     self.execute(guided_metadata.scheduled_requests, logits)

 def rollback_rejected_tokens_v2(self, guided_metadata: GuidedMetadata):
+    if guided_metadata.scheduled_requests is None:
+        raise ValueError("guided_metadata.scheduled_requests cannot be None")
     self.rollback_rejected_tokens(guided_metadata.scheduled_requests, guided_metadata.new_tokens_lens)
```

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
403-409: Consider lazy initialization for guided decoding tensors

The `gathered_input_ids` and `new_tokens_lens` tensors are always allocated even when guided decoding is not used, consuming pinned memory unnecessarily.

Consider moving the tensor initialization to `_set_up_guided_metadata()` to allocate them only when needed:

```diff
     self.guided_metadata: Optional[GuidedMetadata] = None
-    self.gathered_input_ids = torch.empty(self.max_num_tokens,
-                                          dtype=torch.int,
-                                          pin_memory=True)
-    self.new_tokens_lens = torch.empty(self.batch_size,
-                                       dtype=torch.int,
-                                       pin_memory=True)
+    self.gathered_input_ids = None
+    self.new_tokens_lens = None

 def _set_up_guided_metadata(self):
     if self.guided_metadata is None:
         self.guided_metadata = GuidedMetadata()
+    if self.gathered_input_ids is None:
+        self.gathered_input_ids = torch.empty(self.max_num_tokens,
+                                              dtype=torch.int,
+                                              pin_memory=True)
+    if self.new_tokens_lens is None:
+        self.new_tokens_lens = torch.empty(self.batch_size,
+                                           dtype=torch.int,
+                                           pin_memory=True)
     return self.guided_metadata
```
📒 Files selected for processing (14)

- cpp/tensorrt_llm/pybind/CMakeLists.txt (1 hunks)
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp (2 hunks)
- cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp (1 hunks)
- cpp/tensorrt_llm/pybind/runtime/hostfunc.h (1 hunks)
- examples/guided_decoding/try_pytorch_bindings_v2.py (1 hunks)
- examples/guided_decoding/try_pytorch_bindings_v3.py (1 hunks)
- tensorrt_llm/_torch/hostfunc.py (1 hunks)
- tensorrt_llm/_torch/models/modeling_speculative.py (4 hunks)
- tensorrt_llm/_torch/pyexecutor/guided_decoder.py (8 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (3 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...
Files:
cpp/tensorrt_llm/pybind/runtime/hostfunc.h
cpp/tensorrt_llm/pybind/runtime/bindings.cpp
cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Parameter names must be consistent between declarations and definitions
Files:
cpp/tensorrt_llm/pybind/runtime/hostfunc.h
cpp/tensorrt_llm/pybind/runtime/bindings.cpp
cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
**/*.{h,hpp,hxx,hh,cuh}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Header files must use include guards named TRTLLM_<NAME>_H without leading/trailing underscores (e.g., TRTLLM_FOO_BAR_HELLO_H)
Files:
cpp/tensorrt_llm/pybind/runtime/hostfunc.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
cpp/tensorrt_llm/pybind/runtime/hostfunc.h
tensorrt_llm/_torch/speculative/mtp.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
examples/guided_decoding/try_pytorch_bindings_v2.py
examples/guided_decoding/try_pytorch_bindings_v3.py
tensorrt_llm/_torch/hostfunc.py
tensorrt_llm/_torch/speculative/interface.py
cpp/tensorrt_llm/pybind/runtime/bindings.cpp
tensorrt_llm/_torch/speculative/eagle3.py
cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/models/modeling_speculative.py
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/mtp.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
examples/guided_decoding/try_pytorch_bindings_v2.py
examples/guided_decoding/try_pytorch_bindings_v3.py
tensorrt_llm/_torch/hostfunc.py
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/models/modeling_speculative.py
**/*.{cpp,cxx,cc,cu}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu}
: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)
Files:
cpp/tensorrt_llm/pybind/runtime/bindings.cpp
cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
344-344: Line too long (132 > 120)
(E501)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
170-170: Line too long (171 > 120)
(E501)
171-171: Line too long (126 > 120)
(E501)
285-285: Line too long (181 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (11)
tensorrt_llm/_torch/speculative/interface.py (1)
124-128: Fix removes accidental tuple defaults; aligns runtime types with annotations

Removing the trailing commas ensures these defaults are the intended scalar values (enum/None), not single-element tuples. This resolves subtle bugs when comparing or checking truthiness and matches the annotated types.
cpp/tensorrt_llm/pybind/runtime/bindings.cpp (2)
19-19: Include hostfunc.h to wire HostFunc bindings

The new include is correct and the local-include path is appropriate given this file's directory. This unblocks registration of HostFunc bindings.

465-466: Initialize HostFunc bindings in module init

Calling initHostFuncBindings(m) after initMoeBindings(m) cleanly registers the new Python-facing APIs. No ordering issues apparent.
cpp/tensorrt_llm/pybind/CMakeLists.txt (1)
17-20: Add hostfunc.cpp to pybind SRCS

Including runtime/hostfunc.cpp ensures the new bindings are compiled and linked. Moving testing/modelSpecBinding.cpp doesn't materially affect linkage; order here is fine.
tensorrt_llm/_torch/speculative/eagle3.py (2)
269-269: LGTM! Clean integration of guided decoder

The addition of the optional `guided_decoder` member is well-structured and follows best practices for optional features.

273-274: LGTM! Forward signature properly extended

The forward method signature correctly adds the `guided_metadata` parameter in the expected position, maintaining backwards compatibility.

tensorrt_llm/_torch/models/modeling_speculative.py (3)
7-7: LGTM! Clean import organization

The import path adjustment to use relative imports is appropriate and follows Python best practices.

372-372: LGTM! Proper API extension for guided decoding

The addition of the optional `guided_metadata` parameter to the forward signature is well-designed, maintaining backward compatibility while enabling guided decoding functionality.

399-399: LGTM! Correct propagation of guided metadata

The guided_metadata is properly threaded through to the spec_worker, enabling guided decoding in the speculative path.

cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp (1)
cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp (1)
60-75: Missing null pointer validation

The function doesn't validate that `streamPtr` is non-null before casting and using it.

Add null pointer validation:

```diff
 uintptr_t launchHostFunc(uintptr_t streamPtr, py::function pyHostFunc, py::args pyArgs, py::kwargs pyKwargs)
 {
+    if (streamPtr == 0)
+    {
+        throw std::invalid_argument("Stream pointer cannot be null");
+    }
     auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);
     auto hostFuncUserData = std::make_unique<HostFuncUserData>(pyHostFunc, py::tuple(pyArgs), py::dict(pyKwargs));
     cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
     if (err != cudaSuccess)
     {
-        throw std::runtime_error("Failed to launch host function.");
+        throw std::runtime_error(std::string("Failed to launch host function: ") + cudaGetErrorString(err));
     }
```

Likely an incorrect or invalid review comment.
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1624-1642: Missing null check before accessing guided_decoder

The code checks if `guided_metadata` is not None but doesn't verify that `self.model.spec_worker.guided_decoder` exists before calling methods on it.

```python
if guided_metadata is not None:
    guided_metadata.scheduled_requests = scheduled_requests
    if new_tensors_device is None:
        guided_metadata.gathered_input_ids = None
        guided_metadata.new_tokens_lens = None
    else:
        num_gathered = len(gather_ids)
        gathered_input_ids_cuda = self.input_ids_cuda[
            self.gather_ids_cuda[:num_gathered]]
        self.gathered_input_ids[:num_gathered].copy_(
            gathered_input_ids_cuda, non_blocking=True)
        guided_metadata.gathered_input_ids = self.gathered_input_ids[:num_gathered]
        self.new_tokens_lens.copy_(new_tokens_lens_device, non_blocking=True)
        guided_metadata.new_tokens_lens = self.new_tokens_lens
    guided_metadata.token_event.record()
    inputs['guided_metadata'] = guided_metadata
```

Likely an incorrect or invalid review comment.
Actionable comments posted: 0
🔭 Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
346-359: Ensure guided_decoder is propagated into the speculative drafter

Right now you attach the newly created GuidedDecoder to model_engine.model.spec_worker and then do `guided_decoder = None`. But later you call

```python
drafter = get_spec_drafter(
    model_engine,
    draft_model_engine,
    sampler,
    spec_resource_manager=spec_resource_manager,
    guided_decoder=guided_decoder,
)
```

with `guided_decoder=None`, so the resulting ModelDrafter never sees the guided decoder you just built. You need to either:

- Pass the real decoder instance into get_spec_drafter instead of nulling it out: in _torch/pyexecutor/py_executor_creator.py at lines 356–358, remove `guided_decoder = None` and always forward the decoder.
- Or make get_spec_drafter fall back to `model_engine.model.spec_worker.guided_decoder` when its guided_decoder arg is None.

Pinpointed changes:

- _torch/pyexecutor/py_executor_creator.py@356–358
- _torch/speculative/utils.py@get_spec_drafter: add a guard like

```python
if guided_decoder is None and hasattr(model_engine.model, "spec_worker"):
    guided_decoder = model_engine.model.spec_worker.guided_decoder
```

These refactors are required to ensure guided decoding actually flows into your drafter.
♻️ Duplicate comments (5)
tensorrt_llm/_torch/speculative/mtp.py (2)
103-105: Fix tuple default: mtp_num_modules mistakenly a tuple due to trailing comma

The trailing comma makes mtp_num_modules a tuple (1,) instead of int 1, which will break arithmetic and torch ops (e.g., torch.arange(self.mtp_num_modules)).

```diff
-    mtp_num_modules: int = 1,
+    mtp_num_modules: int = 1
```
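For context on why the trailing comma matters, a tiny self-contained example (the `SpecConfig` class name is made up for illustration):

```python
from dataclasses import dataclass

@dataclass
class SpecConfig:
    mtp_num_modules: int = 1,   # trailing comma: the default becomes the tuple (1,), not the int 1
    draft_len: int = 0

cfg = SpecConfig()
print(type(cfg.mtp_num_modules))  # <class 'tuple'> -- e.g. torch.arange(cfg.mtp_num_modules) would fail
print(type(cfg.draft_len))        # <class 'int'>
```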
268-270: Accepted-count and rewind length have correct int semantics (good); keep it that way

num_new_tokens here is a Python int (from list/tensor.item), so both py_num_accepted_draft_tokens and py_rewind_len are plain ints. This addresses previous tensor/int mixing concerns downstream.
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
342-346: Fix long comment line; keep under 120 chars

Ruff flagged the long line in this commented-out guard. Wrap it to satisfy line-length limits.

```diff
-# raise ValueError(
-#     "Guided decoding is only supported with speculative decoding that has a dedicated drafter (two-model engine)."
-# )
+# raise ValueError(
+#     "Guided decoding is only supported with speculative decoding that has a "
+#     "dedicated drafter (two-model engine)."
+# )
```

tensorrt_llm/_torch/speculative/eagle3.py (1)
281-294: Guard guided-decoder sync/execute to prevent hangs or hard failures

token_event.synchronize()/waits and guided-decoder calls can raise or hang; wrap with try/except and log, so the engine can continue without guidance on failure.

```diff
-        if self.guided_decoder is not None and guided_metadata is not None:
-            with torch.cuda.stream(self.guided_decoder._stream):
-                torch.cuda.current_stream().wait_event(
-                    guided_metadata.token_event)
-                self.guided_decoder.rollback_rejected_tokens_v2(guided_metadata)
-                self.guided_decoder.build_v2(guided_metadata)
-                self.guided_decoder.copy_bitmask(
-                    guided_metadata.scheduled_requests)
-                guided_metadata.bitmask_event.record()
-
-            torch.cuda.current_stream().wait_event(
-                guided_metadata.bitmask_event)
-            self.guided_decoder.execute_v2(guided_metadata, logits)
+        if self.guided_decoder is not None and guided_metadata is not None:
+            try:
+                with torch.cuda.stream(self.guided_decoder._stream):
+                    torch.cuda.current_stream().wait_event(guided_metadata.token_event)
+                    self.guided_decoder.rollback_rejected_tokens_v2(guided_metadata)
+                    self.guided_decoder.build_v2(guided_metadata)
+                    self.guided_decoder.copy_bitmask(guided_metadata.scheduled_requests)
+                    guided_metadata.bitmask_event.record()
+                torch.cuda.current_stream().wait_event(guided_metadata.bitmask_event)
+                self.guided_decoder.execute_v2(guided_metadata, logits)
+            except Exception as e:
+                logger.error(f"Guided decoder failed; continuing without guidance: {e}")
+                import traceback as _tb; _tb.print_exc()
```

Add missing imports to support logging and traceback:

```diff
@@
-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.logger import logger
+import traceback  # used in failure path (aliased locally above to avoid linter complaints)
```

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1)
345-355: Add bounds checks and exception handling in host-callable run()

run() assumes grammar_matchers[0] exists and token_ids has at least one element. Add checks and handle exceptions to avoid hard crashes when called from host.

```diff
 @hostfunc
 def run(self, token_ids: torch.Tensor):
-    self.grammar_matchers[0].accept_token(token_ids[0].item())
-    self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0)
-    if not hasattr(self, "token_ids"):
-        self.token_ids = []
-    self.token_ids.append(token_ids[0].item())
+    if token_ids.numel() == 0:
+        logger.warning("GuidedDecoder.run received empty token_ids")
+        return
+    if not self.grammar_matchers or self.grammar_matchers[0] is None:
+        logger.warning("GuidedDecoder.run called before matcher initialization")
+        return
+    try:
+        tid0 = token_ids[0].item()
+        self.grammar_matchers[0].accept_token(tid0)
+        self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0)
+    except Exception as e:
+        logger.error(f"Error in GuidedDecoder.run: {e}")
+        return
+    if not hasattr(self, "token_ids"):
+        self.token_ids = []
+    self.token_ids.append(tid0)
```
🧹 Nitpick comments (12)
tensorrt_llm/_torch/speculative/mtp.py (2)
1-3: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

Apply at file top:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```

250-252: Avoid materializing full Python lists for lengths; fetch per-request as int

new_tokens_lens is already a CPU tensor; creating a full Python list is unnecessary. Index the tensor per request and cast to int. This reduces overhead for large batches.

```diff
-new_tokens_lens_list = state.host.new_tokens_lens.tolist()
+# Avoid building a full Python list; index per request and cast to int below
```

And:

```diff
-num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
+num_new_tokens = int(state.host.new_tokens_lens[req.py_seq_slot])
```

Also applies to: 263-264
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)
1-3: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```

356-359: Remove or ticket-ize the TODO for production readiness

The `TODO: fix it` suggests unfinished integration. Convert to a tracked issue or clarify the intended ownership model in a comment.
tensorrt_llm/_torch/speculative/eagle3.py (2)
1-3: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```

283-284: Avoid reaching into GuidedDecoder private attribute _stream

Accessing a private attribute is brittle. Expose a public property on GuidedDecoder and use it here.

In GuidedDecoder (see file tensorrt_llm/_torch/pyexecutor/guided_decoder.py), add:

```diff
@@ class GuidedDecoder:
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)

+    @property
+    def stream(self) -> torch.cuda.Stream:
+        return self._stream
```

Then update this call site:

```diff
-        try:
-            with torch.cuda.stream(self.guided_decoder._stream):
+        try:
+            with torch.cuda.stream(self.guided_decoder.stream):
```

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (6)
1-4: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import math
 from dataclasses import dataclass, field
 from typing import Iterable, List, Optional, Tuple
```
114-126: Wrap long comment and clarify condition

Break the long comment lines to satisfy linters; keep the rationale succinct.

```diff
-# Fix race condition: The following host code may change the state of requests and in turn affect the result of hostfunc.
-# Where the hostfunc is launched, but not executed yet.
-# if llm_req.is_generation_in_progress_state:
+# Fix race condition: host code may change request state while a hostfunc is
+# queued but not yet executed, altering the expected offsets. Use a broader
+# condition than is_generation_in_progress_state here.
+# if llm_req.is_generation_in_progress_state:
```
180-182: Wrap long comment lines under 120 characters

Line length exceeds 120. Reflow for readability and linting.

```diff
-# 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-# 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+# 1. Main model loop: with overlap scheduler, the matcher may have accepted EOS in the
+#    draft tokens during the previous iteration.
+# 2. Draft model loop: the matcher may have accepted EOS at the previous drafting step.
```
288-289: Wrap long error message string

The f-string exceeds 120 chars. Wrap to satisfy linters.

```diff
-raise ValueError(
-    f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"
-)
+raise ValueError(
+    "Failed to rollback: "
+    f"num_advanced_tokens={self.num_advanced_tokens[slot]}, "
+    f"num_accepted_tokens={num_accepted_tokens}, "
+    f"num_rollback_tokens={num_rollback_tokens}"
+)
```
331-335: Mark execute_v2 as hostfunc (optional) and keep v2 API symmetry

build_v2 and rollback_rejected_tokens_v2 are hostfuncs; execute_v2 could remain device-side, but for symmetry and scheduling guarantees you may want to expose it as hostfunc too. If intentional to keep it non-hostfunc, ignore.

Option A (make it hostfunc):

```diff
-    def execute_v2(self, guided_metadata: GuidedMetadata, logits: torch.Tensor):
+    @hostfunc
+    def execute_v2(self, guided_metadata: GuidedMetadata, logits: torch.Tensor):
         self.execute(guided_metadata.scheduled_requests, logits)
```

Also applies to: 339-343
89-92: Expose a public stream property to avoid private attribute access

Provide a public property so callers (e.g., Eagle3) don't reach into _stream directly.

```diff
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)

+    @property
+    def stream(self) -> torch.cuda.Stream:
+        return self._stream
```
📒 Files selected for processing (4)

- tensorrt_llm/_torch/pyexecutor/guided_decoder.py (7 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (3 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/speculative/mtp.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/speculative/mtp.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
120-120: Line too long (133 > 120)
(E501)
180-180: Line too long (171 > 120)
(E501)
181-181: Line too long (126 > 120)
(E501)
288-288: Line too long (181 > 120)
(E501)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
344-344: Line too long (132 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
Actionable comments posted: 1
♻️ Duplicate comments (4)
tensorrt_llm/_torch/speculative/mtp.py (2)
263-263: Add int cast to maintain consistent typing.

`num_new_tokens` is extracted from a Python list but should be cast to `int` for consistency with the previous type-safe approach from the past review.

Apply this diff:

```diff
-num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
+num_new_tokens = int(new_tokens_lens_list[req.py_seq_slot])
```

268-269: Apply int cast to maintain type consistency.

Following the established pattern from the past review, `num_new_tokens` should be cast to `int` before arithmetic operations to ensure `py_num_accepted_draft_tokens` and `py_rewind_len` are plain Python integers.

Apply this diff:

```diff
-req.py_num_accepted_draft_tokens = num_new_tokens - 1
-req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens
+req.py_num_accepted_draft_tokens = int(num_new_tokens) - 1
+req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens
```

tensorrt_llm/_torch/speculative/eagle3.py (1)
281-296: Add error handling around guided-decoder operations.

The guided-decoder calls are unprotected and may raise exceptions or hang (e.g., `token_event.synchronize` has no timeout). Add defensive error handling to ensure execution continues even if guided decoding fails.

Apply this diff:

```diff
         if self.guided_decoder is not None:
-            with torch.cuda.stream(self.guided_decoder._stream):
-                torch.cuda.current_stream().wait_event(
-                    guided_metadata.token_event)
-                # Fix it.
-                guided_metadata.next_batch_hostfunc()
-                self.guided_decoder.rollback_rejected_tokens_hostfunc(
-                    guided_metadata)
-                self.guided_decoder.build_hostfunc(guided_metadata)
-                self.guided_decoder.bitmask_copy(guided_metadata)
-                guided_metadata.bitmask_event.record()
-
-            torch.cuda.current_stream().wait_event(
-                guided_metadata.bitmask_event)
-            self.guided_decoder.execute(guided_metadata, logits)
+            try:
+                with torch.cuda.stream(self.guided_decoder._stream):
+                    torch.cuda.current_stream().wait_event(
+                        guided_metadata.token_event)
+                    guided_metadata.next_batch_hostfunc()
+                    self.guided_decoder.rollback_rejected_tokens_hostfunc(
+                        guided_metadata)
+                    self.guided_decoder.build_hostfunc(guided_metadata)
+                    self.guided_decoder.bitmask_copy(guided_metadata)
+                    guided_metadata.bitmask_event.record()
+
+                torch.cuda.current_stream().wait_event(
+                    guided_metadata.bitmask_event)
+                self.guided_decoder.execute(guided_metadata, logits)
+            except Exception as e:
+                import traceback
+                traceback.print_exc()
+                logger.error(f"Guided decoder failed; continuing without guidance: {e}")
```
2197-2200
: Fix unsafe attribute access chain.The code directly accesses
self.model.spec_worker.guided_decoder
without ensuring all intermediate attributes exist, which can raiseAttributeError
.Apply this diff:
guided_metadata = None - if self.model.spec_worker.guided_decoder: + if (hasattr(self.model, 'spec_worker') and + hasattr(self.model.spec_worker, 'guided_decoder') and + self.model.spec_worker.guided_decoder): guided_metadata = self._set_up_guided_metadata()
📒 Files selected for processing (5)

- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/guided_decoder.py (6 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (16 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/speculative/mtp.py
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/speculative/mtp.py
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
108-108: Line too long (129 > 120)
(E501)
246-246: Line too long (171 > 120)
(E501)
247-247: Line too long (126 > 120)
(E501)
255-255: Line too long (122 > 120)
(E501)
348-348: Line too long (181 > 120)
(E501)
🔇 Additional comments (42)
tensorrt_llm/_torch/speculative/mtp.py (2)
103-103: LGTM! Fix applied from past review.

The trailing comma in the default value has been correctly removed, preventing `mtp_num_modules` from becoming a single-element tuple.

250-250: LGTM! Improved list conversion approach.

Converting `new_tokens_lens` to a Python list with `tolist()` is cleaner than tensor indexing for per-request access patterns.

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (3)
36-36: LGTM! Guided metadata parameter added.

The constructor now accepts a `guided_metadata` parameter to support guided decoding integration while maintaining backward compatibility with a `None` default.

75-75: LGTM! Store guided metadata for CUDA graph usage.

Proper storage of the `guided_metadata` instance for later use in graph capture and execution.

95-95: LGTM! Guided metadata propagated to forward function.

The `guided_metadata` is correctly included in the inputs dictionary passed to the forward function during graph capture and execution.

tensorrt_llm/_torch/speculative/eagle3.py (4)
10-10: LGTM! Import added for guided decoder integration.

The import of `GuidedDecoder` enables Eagle3 to integrate with the guided decoding framework.

269-269: LGTM! Guided decoder attribute added.

The `guided_decoder` attribute allows Eagle3OneModelWorker to optionally integrate with guided decoding while maintaining compatibility when not used.

273-274: LGTM! Forward signature extended for guided decoding.

The `guided_metadata` parameter properly extends the forward method signature to support guided decoding integration.

345-345: LGTM! Guided decoder integration point marked.

The comment appropriately marks where guided decoding steps are integrated into the draft-generation flow.
tensorrt_llm/_torch/pyexecutor/model_engine.py (15)
44-53: LGTM! Import reorganization and new imports added.

The imports have been properly reorganized with new guided decoding components (`GuidedMetadata`, `SampleStateTensors`, `SampleStateTensorsMTP`) and moved speculative utilities to their appropriate locations.

58-64: LGTM! Additional imports for guided decoding support.

The imports of `GuidedMetadata` and related components properly enable guided decoding integration in the model engine.

403-410: LGTM! Guided metadata state initialization.

The guided metadata fields and host tensor buffers are properly initialized to support the guided decoding workflow.

871-876: LGTM! Lazy guided metadata initialization.

The `_set_up_guided_metadata` method properly initializes `GuidedMetadata` on-demand, following the established pattern for other metadata initialization.

1019-1021: LGTM! Guided metadata integration in CUDA graph runner.

The `guided_metadata` is correctly passed to the `DecodingCUDAGraphRunner` constructor to enable guided decoding support in CUDA graphs.

1215-1217: LGTM! Token event recording for guided metadata.

The token event is properly recorded when guided metadata is present, enabling proper synchronization in the guided decoding workflow.

1226-1226: LGTM! Guided metadata parameter added to input preparation methods.

The `guided_metadata` parameter is properly added to the input preparation method signatures.

1331-1331: LGTM! Request ID tracking maintained.

The `request_ids.append(request.py_request_id)` calls are properly positioned to maintain request ID tracking for guided decoding.

1629-1646: LGTM! Guided metadata population in input preparation.

The guided metadata is properly populated with the scheduled requests and tensor data when provided, enabling the guided decoding workflow to access the necessary information.

1681-1682: LGTM! Guided metadata parameter consistency.

The `guided_metadata` parameter is consistently added to the no-cache input preparation method.

1779-1782: LGTM! Guided metadata handling in no-cache path.

Guided metadata is properly handled in the no-cache input preparation path, maintaining consistency across different execution modes.

2151-2151: LGTM! Guided metadata parameter added to main preparation method.

The `guided_metadata` parameter is consistently added to the main `_prepare_inputs` method signature.

2164-2164: LGTM! Guided metadata forwarding in input preparation.

The `guided_metadata` parameter is properly forwarded to the appropriate input preparation method.

2212-2214: LGTM! Guided metadata parameter forwarding.

The `guided_metadata` parameter is properly forwarded to the no-cache input preparation method.

2233-2235: LGTM! Guided metadata parameter forwarding in main path.

The `guided_metadata` parameter is correctly forwarded in the main forward execution path.

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (18)
2-4: LGTM! New imports for guided decoding data structures.

The imports of `dataclass`, `field`, `Queue`, and typing components properly support the new guided decoding data structures.

7-7: LGTM! XGrammar import added.

The `xgrammar` import enables the `apply_token_bitmask_inplace` functionality used in the execute method.

10-10: LGTM! Guided decoding configuration imports.

The imports of `GuidedDecodingConfig` and `GuidedDecodingParams` provide the necessary configuration types.

12-12: LGTM! Host function decorator import.

The `hostfunc` import enables the host-callable method decorators used throughout the guided decoder.

13-15: LGTM! Grammar matcher imports updated.

The import statement properly includes all necessary grammar matcher components.

19-74: LGTM! Well-designed GuidedRequest dataclass.

The `GuidedRequest` dataclass provides a clean abstraction for request state with proper method logic for determining when matcher initialization and advancement are required. The `from_request` class method provides a clean conversion from `LlmRequest`.
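For readers unfamiliar with the pattern, here is a rough sketch of what such a per-request snapshot can look like; the class and field names below are illustrative only and are not the PR's actual `GuidedRequest` definition.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GuidedRequestSketch:
    """Illustrative snapshot of one request's guided-decoding state."""
    request_id: int
    seq_slot: int
    guide_type: Optional[str] = None   # e.g. a JSON schema or regex guide, if any
    new_token: Optional[int] = None    # last sampled token to feed the matcher
    is_context_init: bool = False      # first chunk of a context request
    is_generating: bool = False        # request is in a generation step

    def requires_matcher_init(self) -> bool:
        # A matcher is created when a guided request enters its first context chunk.
        return self.guide_type is not None and self.is_context_init

    def requires_matcher_advance(self) -> bool:
        # The matcher advances whenever a guided request produced a new token.
        return (self.guide_type is not None and self.is_generating
                and self.new_token is not None)
```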
76-101: LGTM! Well-structured GuidedRequestBatch class.

The `GuidedRequestBatch` class provides good encapsulation of batch processing logic with proper offset calculation for guided request processing.

103-141: LGTM! Comprehensive GuidedMetadata class.

The `GuidedMetadata` dataclass provides proper structure for coordinating the host-device guided decoding workflow with appropriate CUDA events and queue management.

171-182: LGTM! Proper bitmask buffer initialization.

The grammar matchers list and bitmask buffers (both host and device) are properly initialized with appropriate sizing for draft tokens support.
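As a rough sketch of the sizing involved (assuming, as elsewhere in this review, that `bitmask_size = ceil(vocab_size_padded / 32)`; the buffer shapes and variable names here are illustrative, not the PR's actual allocation code):

```python
import math
import torch

vocab_size_padded = 128256        # example vocabulary size after padding
max_num_sequences = 64            # illustrative batch capacity
max_draft_len = 3                 # draft tokens per sequence in speculative decoding

bitmask_size = math.ceil(vocab_size_padded / 32)      # one bit per vocab entry, packed into int32
num_rows = max_num_sequences * (max_draft_len + 1)    # one row per (sequence, token position)

# Host buffer is pinned so a hostfunc can fill it and the device copy can be asynchronous.
bitmask_host = torch.empty(num_rows, bitmask_size, dtype=torch.int32, pin_memory=True)
bitmask_device = torch.empty_like(bitmask_host, device="cuda")

# Later, after the matchers fill bitmask_host, push it to the GPU without blocking:
bitmask_device.copy_(bitmask_host, non_blocking=True)
```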
200-200: LGTM! Method signature updated for guided metadata.

The `build` method signature properly accepts `GuidedMetadata` instead of `ScheduledRequests`, aligning with the new guided decoding architecture.

208-232: LGTM! Flexible input handling for build method.

The method properly handles both `gathered_input_ids` from guided metadata and a fallback to request-specific tokens, providing flexibility for different execution modes.

233-283: LGTM! Robust matcher processing logic.

The matcher initialization and advancement logic properly handles various request states and draft tokens with appropriate error handling and termination conditions.

284-289: LGTM! Proper bitmask copy implementation.

The `bitmask_copy` method correctly transfers the host bitmask to device memory using a non-blocking copy for performance.

293-293: LGTM! Execute method signature updated.

The `execute` method signature properly accepts `GuidedMetadata` for consistency with the new architecture.

310-312: LGTM! Proper logits masking implementation.

The method correctly applies the token bitmask using XGrammar's `apply_token_bitmask_inplace` with proper token count limiting.
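For reference, a minimal standalone sketch of in-place bitmask masking with xgrammar; this assumes the commonly documented `xgrammar.apply_token_bitmask_inplace(logits, bitmask)` signature and is not the PR's code:

```python
import torch
import xgrammar as xgr

vocab_size = 1000
batch_size = 2

# Allocate a packed int32 bitmask (one bit per vocabulary entry). Normally the
# grammar matchers fill it; here we fake it by allowing every token.
bitmask = xgr.allocate_token_bitmask(batch_size, vocab_size)
bitmask.fill_(-1)  # all bits set -> every token allowed

logits = torch.randn(batch_size, vocab_size, device="cuda")

# Disallowed tokens get their logits set to -inf, in place.
xgr.apply_token_bitmask_inplace(logits, bitmask.to(logits.device))
```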
318-318: LGTM! Rollback method signature updated.

The `rollback_rejected_tokens` method signature properly accepts `GuidedMetadata` for consistency.

328-351: LGTM! Flexible rollback token counting.

The rollback logic properly handles both an explicit `new_tokens_lens` from metadata and a fallback to request-specific accepted token counts, providing robustness across different execution modes.

391-411: LGTM! Host-callable method wrappers.

The host-callable method wrappers properly enable device-side execution of guided decoding operations. The methods provide appropriate delegation to the main implementation methods.
Actionable comments posted: 2
♻️ Duplicate comments (3)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
342-345: Fix the formatting issue

The commented-out lines exceed the maximum line length and should be properly formatted.

Apply this diff to properly format the commented lines:

```diff
-# if spec_config is not None and not has_spec_drafter:
-#     raise ValueError(
-#         "Guided decoding is only supported with speculative decoding that has a dedicated drafter (two-model engine)."
-#     )
+# if spec_config is not None and not has_spec_drafter:
+#     raise ValueError(
+#         "Guided decoding is only supported with speculative decoding that has a "
+#         "dedicated drafter (two-model engine)."
+#     )
```

tensorrt_llm/_torch/speculative/eagle3.py (1)
451-452: Fix the incorrect attribute name from "dt2" to "d2t"

The draft-to-token offset mapping is defined as `d2t` throughout the codebase, but this code incorrectly checks for `dt2`.

Apply this diff to fix the attribute name:

```diff
-        if (d2t := getattr(draft_model.model, 'dt2', None)):
+        if (d2t := getattr(draft_model.model, 'd2t', None)):
             draft_tokens = d2t[draft_tokens] + draft_tokens
```

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1)
358-363: Add bounds checking for grammar matcher access

The `run` method accesses `self.grammar_matchers[0]` without checking if it exists or if the index is valid.

Apply this diff to add proper error handling:

```diff
 @hostfunc
 def run(self, token_ids: torch.Tensor):
+    if token_ids.numel() == 0:
+        logger.warning("Empty token_ids tensor provided to run method")
+        return
+    if self.grammar_matchers[0] is None:
+        logger.warning("Grammar matcher at index 0 is not initialized")
+        return
     self.grammar_matchers[0].accept_token(token_ids[0].item())
     self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0)
     if not hasattr(self, "token_ids"):
         self.token_ids = []
     self.token_ids.append(token_ids[0].item())
```
🧹 Nitpick comments (4)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1211-1213: Consider adding error handling for event recording

The `token_event.record()` operation could potentially fail. Consider adding error handling to ensure the forward pass can continue even if guided decoding setup fails.

Apply this diff to add error handling:

```diff
 if self.guided_worker is not None:
-    self.guided_worker.token_event.record()
+    try:
+        self.guided_worker.token_event.record()
+    except Exception as e:
+        logger.warning(f"Failed to record guided worker token event: {e}")
```

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)
176-178: Add TODO tracking for the "Fix it" comment

The comment "Fix it" on line 176 is vague and should be more descriptive about what needs to be fixed.
Would you like me to help clarify what needs to be fixed here or create an issue to track this TODO?
198-199: Fix line length violations

Lines 198-199 exceed the maximum line length of 120 characters.

Apply this diff to fix the line length:

```diff
-# The last new token must be acceptable unless the matcher is terminated:
-# 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-# 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+# The last new token must be acceptable unless the matcher is terminated:
+# 1. For the main model loop, when overlap scheduler is enabled, the matcher may have
+#    accepted the EOS token in the draft tokens at the previous iteration.
+# 2. For the draft model loop, the matcher may have accepted the EOS token at the
+#    previous drafting iteration.
```
207-207: Fix additional line length violations

Lines 207 and 295 also exceed the maximum line length.

Apply this diff to fix the line lengths:

```diff
-f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
+f"Draft request {req.request_id} at slot {slot} failed to accept "
+f"last new token: {req.new_token}."
```

```diff
-f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"
+f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, "
+f"num_accepted_tokens={num_accepted_tokens}, "
+f"num_rollback_tokens={num_rollback_tokens}"
```

Also applies to: 295-295
📒 Files selected for processing (5)

- tensorrt_llm/_torch/models/modeling_speculative.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (7 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/_torch/models/modeling_speculative.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (6)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (2)
GuidedDecoder
(113-363)GuidedWorker
(371-445)tensorrt_llm/_torch/models/modeling_utils.py (1)
vocab_size_padded
(504-505)tensorrt_llm/_torch/modules/embedding.py (1)
vocab_size_padded
(76-80)tensorrt_llm/_torch/models/modeling_speculative.py (1)
set_guided_worker
(423-426)tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
set_guided_worker
(472-478)tensorrt_llm/_torch/speculative/eagle3.py (1)
set_guided_worker
(502-504)
tensorrt_llm/_torch/speculative/eagle3.py (3)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)
GuidedWorker
(371-445)execute
(263-279)execute
(433-445)tensorrt_llm/_torch/models/modeling_speculative.py (1)
set_guided_worker
(423-426)tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
set_guided_worker
(472-478)
tensorrt_llm/_torch/pyexecutor/model_engine.py (6)
tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py (1)
BaseCheckpointLoader
(19-87)tensorrt_llm/_torch/speculative/interface.py (1)
SpecMetadata
(111-199)tensorrt_llm/_torch/speculative/utils.py (3)
get_num_extra_kv_tokens
(166-176)get_spec_metadata
(16-59)update_spec_config_from_model_config
(179-184)tensorrt_llm/_torch/speculative/mtp.py (1)
SampleStateTensorsMTP
(22-24)tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)
GuidedWorker
(371-445)SampleStateTensors
(366-368)add_batch
(395-406)tensorrt_llm/_torch/pyexecutor/sampler.py (1)
SampleStateTensors
(31-37)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)
tensorrt_llm/_torch/hostfunc.py (1)
hostfunc
(19-24)tensorrt_llm/_torch/pyexecutor/grammar_matcher.py (18)
GrammarMatcher
(13-30)LLGuidanceMatcherFactory
(174-222)XGrammarMatcherFactory
(61-128)create
(36-38)create
(93-128)create
(195-222)is_terminated
(29-30)is_terminated
(57-58)is_terminated
(165-166)accept_token
(16-17)accept_token
(47-48)accept_token
(139-150)fill_next_token_bitmask
(24-26)fill_next_token_bitmask
(53-55)fill_next_token_bitmask
(159-163)rollback
(20-21)rollback
(50-51)rollback
(152-157)tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
LlmRequest
(271-411)get
(99-108)tensorrt_llm/_torch/pyexecutor/scheduler.py (2)
ScheduledRequests
(18-39)all_requests
(38-39)tensorrt_llm/_torch/pyexecutor/sampler.py (1)
SampleStateTensors
(31-37)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
344-344: Line too long (132 > 120)
(E501)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
198-198: Line too long (171 > 120)
(E501)
199-199: Line too long (126 > 120)
(E501)
207-207: Line too long (126 > 120)
(E501)
269-269: Line too long (126 > 120)
(E501)
295-295: Line too long (181 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (13)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)
28-28
: LGTM! The new import for
GuidedWorker
alongsideGuidedDecoder
is appropriate for the guided decoding framework changes.
347-359
: Spec_worker attribute availability across models verified. I confirmed that all speculative-decoding model classes (in
tensorrt_llm/_torch/models/modeling_speculative.py
) assign self.spec_worker
in their constructor via get_spec_worker
(around line 360). No other model classes define spec_worker
, so the hasattr(model_engine.model, "spec_worker")
check in py_executor_creator.py
will only be true for those that support speculative decoding, and the fallback to GuidedDecoder
correctly covers the rest. No changes required here.
tensorrt_llm/_torch/speculative/eagle3.py (2)
269-269
: LGTM! The addition of the
guided_worker
attribute with proper typing is well-structured.
502-504
: LGTM!The
set_guided_worker
method implementation is clean and follows the expected pattern.
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)
44-52
: LGTM! The imports have been properly reorganized to match the new module structure, with the addition of
GuidedWorker
alongside the existing imports.
403-403
: LGTM! The initialization of
guided_worker
as an optional attribute is properly typed and follows the existing pattern.
472-478
: LGTM!The
set_guided_worker
method correctly delegates to the model's method if available and updates the local reference on success.
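For reference, the delegation this comment describes can be sketched as follows. This is a simplified illustration of the pattern only (the real method lives on the model engine and its exact body appears in diffs later in this review), not the actual implementation.

```python
# Simplified sketch of the delegation pattern, assuming a model that may or may not
# expose set_guided_worker(); names mirror the PR but the body is illustrative only.
from typing import Any, Optional


class EngineSketch:
    def __init__(self, model: Any):
        self.model = model
        self.guided_worker: Optional[Any] = None  # GuidedWorker in the real code

    def set_guided_worker(self, guided_worker: Any) -> bool:
        # Delegate to the model only if it supports guided decoding; cache on success.
        if hasattr(self.model, "set_guided_worker"):
            success = self.model.set_guided_worker(guided_worker)
            if success:
                self.guided_worker = guided_worker
            return success
        return False
```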
1624-1626
: LGTM! The guided worker batch addition is properly integrated into the input preparation flow, correctly passing both the scheduled requests and the new tensors when available.
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)
19-74
: LGTM! Well-structured dataclass for guided request handling. The
GuidedRequest
dataclass is well-designed with clear field definitions and logical helper methods for determining when matcher initialization and advancement are required.
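As a rough illustration of the shape of such a request record (field names are taken from snippets quoted elsewhere in this review; the real class covers more fields and helpers):

```python
# Illustrative only: a trimmed-down stand-in for the GuidedRequest described above.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class GuidedRequestSketch:
    request_id: int
    seq_slot: Optional[int] = None        # decoder slot assigned to this request
    new_token: Optional[int] = None       # last sampled token, if any
    draft_tokens: List[int] = field(default_factory=list)
    is_draft: bool = False                # True when produced by the draft-model loop

    def requires_matcher_advance(self) -> bool:
        # Hypothetical helper: only advance the grammar matcher once a token exists
        # and the request has been placed in a slot.
        return self.seq_slot is not None and self.new_token is not None
```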
76-110
: LGTM! Clean batch processing abstraction. The
GuidedRequests
dataclass provides a clean abstraction for batch processing with useful properties and iteration methods.
366-369
: LGTM! Clean protocol definition. The
SampleStateTensors
protocol provides a clean interface for the expected tensor structure.
371-394
: LGTM! Well-structured GuidedWorker initialization. The
GuidedWorker
class properly extends GuidedDecoder
with additional queue-based batch processing and CUDA event synchronization.
433-445
: LGTM! Clean event-based synchronization. The
execute
method properly uses CUDA events for synchronization between the host functions and device operations.
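The underlying idea can be shown with a small, generic PyTorch sketch (not the GuidedWorker code itself): a side stream waits on an event recorded by the producer stream, does its work, and records an event that the consumer stream waits on before continuing.

```python
# Generic event-ordering sketch under the stated assumptions; stream/event names are
# placeholders, not the attributes used by GuidedWorker.
import torch


def run_on_side_stream(ready, side, work):
    done = torch.cuda.Event()
    with torch.cuda.stream(side):
        torch.cuda.current_stream().wait_event(ready)  # wait for produced tokens
        work()                                         # e.g. copy the token bitmask
        done.record()
    return done


if torch.cuda.is_available():
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()
    x = torch.zeros(8, device="cuda")

    ready = torch.cuda.Event()
    x += 1                 # produced on the main stream
    ready.record(main)

    done = run_on_side_stream(ready, side, lambda: x.mul_(2))
    main.wait_event(done)  # main stream resumes only after the side-stream work
    torch.cuda.synchronize()
```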
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1)
336-349
: Wrong method invoked; draft rollback uses rejected-token rollback. rollback_draft_tokens calls _rollback_rejected_tokens instead of _rollback_draft_tokens.
Apply:
- requests = GuidedRequests.from_scheduled_requests(
-     scheduled_requests, self.max_num_draft_tokens)
- self._rollback_rejected_tokens(requests)
+ requests = GuidedRequests.from_scheduled_requests(
+     scheduled_requests, self.max_num_draft_tokens)
+ self._rollback_draft_tokens(requests)
♻️ Duplicate comments (4)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (2)
371-377
: Harden hostfunc run against empty inputs and matcher state. Apply:
@hostfunc def run(self, token_ids: torch.Tensor): - self.grammar_matchers[0].accept_token(token_ids[0].item()) - self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0) - if not hasattr(self, "token_ids"): - self.token_ids = [] - self.token_ids.append(token_ids[0].item()) + if token_ids.numel() == 0: + logger.warning("GuidedDecoder.run: empty token_ids") + return + matcher = self.grammar_matchers[0] + if matcher is None: + logger.warning("GuidedDecoder.run: matcher[0] is not initialized") + return + try: + matcher.accept_token(token_ids[0].item()) + matcher.fill_next_token_bitmask(self.bitmask_host, 0) + except Exception as e: + logger.error(f"GuidedDecoder.run failed: {e}") + return + self.token_ids = getattr(self, "token_ids", []) + self.token_ids.append(token_ids[0].item())
413-426
: Bounds and shape checks for next_batch slot indexingApply:
@hostfunc def next_batch(self) -> None: - # Fix it. if self.queue.empty(): return self.requests_hostfunc, has_new_tensors = self.queue.get() if not has_new_tensors: return - for req in self.requests_hostfunc: + for req in self.requests_hostfunc: if (slot := req.seq_slot) is None: continue - req.new_token, *req.draft_tokens = self.new_tokens[:, slot].tolist() + if not isinstance(slot, int) or slot < 0 or slot >= self.max_num_sequences: + logger.warning(f"next_batch: invalid slot {slot}") + continue + col = self.new_tokens[:, slot] + if col.numel() == 0: + logger.warning(f"next_batch: empty new_tokens for slot {slot}") + continue + values = col.tolist() + req.new_token, *req.draft_tokens = valuestensorrt_llm/_torch/speculative/eagle3.py (2)
281-283
: Wrap guided-worker execution to avoid forward crashesApply:
- if self.guided_worker is not None: - self.guided_worker.execute(logits) + if self.guided_worker is not None: + try: + self.guided_worker.execute(logits) + except Exception as e: + import traceback + traceback.print_exc() + # Continue without guidance + print(f"GuidedWorker.execute failed; proceeding without guidance: {e}")
465-468
: Apply correct d2t mapping in draft_decoderSame “dt2” typo here; breaks offset mapping.
Apply:
- # Apply d2t (offsets between draft model dictionary and main model dictionary). - if (d2t := getattr(draft_model.model, 'dt2', None)): + # Apply d2t (offsets between draft model dictionary and main model dictionary). + if (d2t := getattr(draft_model.model, 'd2t', None)): draft_tokens = d2t[draft_tokens] + draft_tokens
🧹 Nitpick comments (10)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
1-1
: Add NVIDIA copyright headerPer repository guidelines, prepend the current-year NVIDIA copyright header to all source files.
Apply at file top:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
403-404
: Document and harden set_guided_worker APIAdd a short docstring and clarify return semantics. Also be explicit about accepted None to avoid accidental misuse.
Apply:
- def set_guided_worker(self, guided_worker: GuidedWorker) -> bool: + def set_guided_worker(self, guided_worker: GuidedWorker) -> bool: + """Attach a GuidedWorker to the underlying model (if supported) and cache it locally. + + Args: + guided_worker: The worker instance to attach. + Returns: + True if the underlying model accepted the worker; False otherwise. + """ if hasattr(self.model, "set_guided_worker"): success = self.model.set_guided_worker(guided_worker) if success: self.guided_worker = guided_worker return success return FalseAlso applies to: 472-479
1624-1627
: Guard against shape/dtype mismatches when passing new_tokensnew_tokens is copied into a pinned-int32 buffer inside GuidedWorker via squeeze(-1). If callers ever pass unexpected ranks/dtypes, it will fail at runtime.
Consider validating before enqueue:
- if self.guided_worker is not None: - self.guided_worker.add_batch(scheduled_requests, - new_tokens=new_tokens_device) + if self.guided_worker is not None: + if new_tokens_device is not None: + assert new_tokens_device.dtype in (torch.int32, torch.int64), "new_tokens dtype must be int" + assert new_tokens_device.dim() in (2, 3), "new_tokens must be [draft+1, batch] or [draft+1, batch, 1]" + self.guided_worker.add_batch(scheduled_requests, new_tokens=new_tokens_device)tensorrt_llm/_torch/pyexecutor/guided_decoder.py (4)
1-1
: Add NVIDIA copyright headerPer repository guidelines, prepend the current-year NVIDIA copyright header to all source files.
Apply at file top:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
445-455
: Add None-guards for rollback hostfuncsAvoid crashes when called before next_batch populated requests_hostfunc.
Apply:
def rollback_rejected_tokens(self) -> None: if self.max_num_draft_tokens <= 0: return - self._rollback_rejected_tokens(self.requests_hostfunc) + if self.requests_hostfunc is not None: + self._rollback_rejected_tokens(self.requests_hostfunc) @hostfunc def rollback_draft_tokens(self) -> None: if self.max_num_draft_tokens <= 0: return - self._rollback_draft_tokens(self.requests_hostfunc) + if self.requests_hostfunc is not None: + self._rollback_draft_tokens(self.requests_hostfunc)
456-467
: Defensive check in add_draft_batchValidate that a batch exists before recording the event; helps during warmup or misordered calls.
Apply:
def add_draft_batch(self, new_tokens: torch.Tensor, num_accepted_tokens: torch.Tensor, is_first_step: bool = False) -> None: - batch_size = len(self.requests) + if self.requests is None: + logger.warning("GuidedWorker.add_draft_batch: no active batch; skipping") + return + batch_size = len(self.requests)
208-209
: Long lines flagged by Ruff (E501)Several lines exceed 120 chars. Consider wrapping to satisfy style checks.
Locations:
- Line 208
- Line 217
- Line 279
- Line 304
Also applies to: 217-217, 279-279, 304-304
tensorrt_llm/_torch/speculative/eagle3.py (3)
1-1
: Add NVIDIA copyright headerPer repository guidelines, prepend the current-year NVIDIA copyright header to all source files.
Apply at file top:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
322-327
: Guard add_draft_batch against runtime failuresMatch the defensive pattern used for execute to prevent guided path failures from aborting drafting.
Apply:
- if self.guided_worker is not None: - new_tokens = inputs["input_ids"][gather_ids] - self.guided_worker.add_draft_batch(new_tokens, - num_accepted_tokens, - is_first_step=(i == 0)) + if self.guided_worker is not None: + try: + new_tokens = inputs["input_ids"][gather_ids] + self.guided_worker.add_draft_batch(new_tokens, + num_accepted_tokens, + is_first_step=(i == 0)) + except Exception as e: + import traceback + traceback.print_exc() + print(f"GuidedWorker.add_draft_batch failed; skipping this step: {e}")
517-519
: Add docstring to set_guided_workerPublic method should document behavior and return value.
Apply:
def set_guided_worker(self, guided_worker: GuidedWorker) -> bool: + """Attach a GuidedWorker for guided decoding; always returns True for Eagle3 one-model.""" self.guided_worker = guided_worker return True
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
(6 hunks)tensorrt_llm/_torch/pyexecutor/model_engine.py
(7 hunks)tensorrt_llm/_torch/speculative/eagle3.py
(6 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
208-208: Line too long (171 > 120)
(E501)
209-209: Line too long (126 > 120)
(E501)
217-217: Line too long (126 > 120)
(E501)
279-279: Line too long (126 > 120)
(E501)
304-304: Line too long (181 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
44-64
: Import refactor looks goodThe new relative imports and surfaced symbols (BaseCheckpointLoader, SpecMetadata helpers, GuidedWorker, SampleStateTensors) are consistent with the module layout and improve dependency locality.
tensorrt_llm/_torch/speculative/eagle3.py (1)
269-270
: Initialize guided_worker to None is fineThe attribute introduction is clean and matches the new GuidedWorker API.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/speculative/eagle3.py (2)
1-1
: Add required NVIDIA copyright header (2025).Per the coding guidelines, all source files must prepend the NVIDIA copyright header. This file is missing it.
Apply this diff at the top of the file:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
200-205
: Fix out-of-bounds capture when num_layers == 1.For one-layer models you set
layers_to_capture = (1,)
which is out-of-range (valid index is 0). This will makemaybe_capture_hidden_states
silently skip capture or cause issues depending on caller expectations.Apply this diff:
- if self.num_layers == 1: - self.layers_to_capture = (1, ) + if self.num_layers == 1: + self.layers_to_capture = (0, )
♻️ Duplicate comments (3)
tensorrt_llm/_torch/speculative/eagle3.py (3)
281-283
: Guard guided path: wrapguided_worker.execute
in try/except to avoid crashing forward pass.Uncaught exceptions here would take down the whole iteration. This mirrors prior feedback on guarding guided-decoder paths.
Apply this diff:
- if self.guided_worker is not None: - self.guided_worker.execute(logits) + if self.guided_worker is not None: + try: + self.guided_worker.execute(logits) + except Exception: + # Continue without guided decoding; keep raw logits path alive + self.logger.exception("Guided worker execute() failed; continuing without guidance")
322-327
: Guardguided_worker.add_draft_batch
to prevent stateful guided failures from aborting drafting.State errors (e.g., batch shape mismatches) can be raised here; handle gracefully and keep generation going.
Apply this diff:
- if self.guided_worker is not None: - new_tokens = inputs["input_ids"][gather_ids] - self.guided_worker.add_draft_batch(new_tokens, - num_accepted_tokens, - is_first_step=(i == 0)) + if self.guided_worker is not None: + new_tokens = inputs["input_ids"][gather_ids] + try: + self.guided_worker.add_draft_batch( + new_tokens, + num_accepted_tokens, + is_first_step=(i == 0), + ) + except Exception: + self.logger.exception("Guided worker add_draft_batch() failed; ignoring this guided step")
340-347
: Guardguided_worker.execute_draft_batch
and confirm attribute name.Good to see the
d2t
attribute used (fixing priordt2
typo elsewhere). Please also guard this call to avoid guided failures aborting the loop.Apply this diff:
- if self.guided_worker is not None: - d2t = getattr(draft_model.model, "d2t", None) - self.guided_worker.execute_draft_batch( - logits, - d2t, - is_first_step=(i == 0), - is_last_step=(i == self.max_draft_len - 1)) + if self.guided_worker is not None: + d2t = getattr(draft_model.model, "d2t", None) + try: + self.guided_worker.execute_draft_batch( + logits, + d2t, + is_first_step=(i == 0), + is_last_step=(i == self.max_draft_len - 1), + ) + except Exception: + self.logger.exception("Guided worker execute_draft_batch() failed; skipping guided draft execution for this step")
🧹 Nitpick comments (4)
tensorrt_llm/_torch/speculative/eagle3.py (4)
10-10
: Prefer module-namespace import; also add logging import for upcoming error handling.Guidelines ask to maintain module namespace in imports. Import the module and reference
GuidedWorker
via the module. We'll also needlogging
for the error-handling changes below.Apply this diff:
-from ..pyexecutor.guided_decoder import GuidedWorker +import logging +from ..pyexecutor import guided_decoderAnd update annotations/usages accordingly (see other diffs in this review).
264-271
: Initialize a class logger; annotate guided_worker via module namespace.We'll need a logger for the guarded guided-decoding calls. Also align the type annotation to module-namespace import.
Apply this diff:
def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping): super().__init__() self.spec_config = spec_config self.max_draft_len = self.spec_config.max_draft_len self.mapping = mapping - self.guided_worker: Optional[GuidedWorker] = None + # Local logger for error reporting in guided-decoding path + self.logger = logging.getLogger(__name__) + self.guided_worker: Optional[guided_decoder.GuidedWorker] = None
466-468
: Ensured2t
device matchesdraft_tokens
before indexing (avoid device mismatch).If
draft_model.model.d2t
is not registered as a CUDA buffer on the same device,d2t[draft_tokens]
will error. Either guaranteed2t
is a buffer on the model or guard locally.Apply this diff if you cannot guarantee buffer placement:
- if (d2t := getattr(draft_model.model, "d2t", None)) is not None: - draft_tokens = d2t[draft_tokens] + draft_tokens + if (d2t := getattr(draft_model.model, "d2t", None)) is not None: + # Ensure device alignment for indexing + d2t_local = d2t.to(draft_tokens.device, non_blocking=True) if getattr(d2t, "device", None) != draft_tokens.device else d2t + draft_tokens = d2t_local[draft_tokens] + draft_tokensFollow-up:
- Prefer registering
d2t
as a buffer ondraft_model.model
so it moves with.to(device)
and avoids per-step copies.
516-519
: Docstring the API and align type with module-namespace import.Provide a concise docstring and use the module namespace in the type annotation for consistency with import style.
Apply this diff:
- def set_guided_worker(self, guided_worker: GuidedWorker) -> bool: - self.guided_worker = guided_worker - return True + def set_guided_worker(self, guided_worker: guided_decoder.GuidedWorker) -> bool: + """Attach a guided-decoding worker instance. + + Args: + guided_worker: guided_decoder.GuidedWorker used to guide decoding. + + Returns: + True if the worker was attached. + """ + self.guided_worker = guided_worker + return True
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
tensorrt_llm/_torch/speculative/eagle3.py
(6 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/eagle3.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/speculative/eagle3.py (1)
330-336
: LGTM: explicit toggle ofuse_spec_decoding
for nth steps.The intent and FIXME are clear; and
use_spec_decoding
is restored to True after the loop (Line 395). No action needed here.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1758-1781
: No-cache path: enqueue a guided batch and record the token event.When
kv_cache_manager is None
, we still need to keep the guided path consistent so downstream execute paths don't see empty/None state.Apply this diff just before the return:
- return inputs, None + # Keep guided path consistent in no-cache mode + if self.guided_worker is not None: + # No new_tokens available in no-cache path; enqueue a placeholder batch. + self.guided_worker.add_batch(scheduled_requests, new_tokens=None) + self.guided_worker.token_event.record() + return inputs, None
♻️ Duplicate comments (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
1211-1213
: Do not record token_event in _preprocess_inputs; it races with add_batch enqueue.Recording here can happen before any H2D/D2H work is enqueued for the batch, leading to out-of-order dependencies or later dereferencing of uninitialized guided state. This was flagged earlier and remains unresolved.
Apply this diff to remove premature recording:
- if self.guided_worker is not None: - self.guided_worker.token_event.record()Then, record immediately after
guided_worker.add_batch(...)
in_prepare_tp_inputs
and add a guarded enqueue+record in the no-cache path (see next comments for diffs).Run to verify ordering after applying the diffs:
#!/bin/bash # Expect: no record() in _preprocess_inputs; exactly two record() call sites: # 1) after add_batch in _prepare_tp_inputs # 2) in _prepare_tp_inputs_no_cache before return rg -nP 'guided_worker\.token_event\.record\(' tensorrt_llm/_torch/pyexecutor/model_engine.py -n -C2 # Confirm add_batch call sites rg -nP 'guided_worker\.add_batch\s*\(' tensorrt_llm/_torch/pyexecutor/model_engine.py -n -C3
1623-1626
: Record token_event immediately after enqueueing the guided batch to preserve dependency ordering.This ensures the event reflects the actual submission of the guided inputs.
Apply this diff:
if self.guided_worker is not None: self.guided_worker.add_batch(scheduled_requests, new_tokens=new_tokens_device) + # Record after enqueue to preserve proper dependency ordering. + self.guided_worker.token_event.record()
🧹 Nitpick comments (4)
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)
44-53
: Imports look consistent with the refactor; consider type-only import for GuidedWorker to avoid runtime deps/cycles.
- The added imports for checkpoints, speculative metadata, and sampler types look correct.
- For
GuidedWorker
, prefer a type-only import to prevent import-time side effects or circular dependencies when this module is imported.Apply this diff within the selected range to gate the GuidedWorker import:
-from .guided_decoder import GuidedWorker +# Avoid importing GuidedWorker at runtime to reduce import-time deps and cycles +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .guided_decoder import GuidedWorkerAdditionally, update type annotations at their use sites (examples shown in separate comments) to use forward references:
'GuidedWorker'
.Also applies to: 58-58, 63-63
1-2
: Missing NVIDIA copyright header (2025).Per the coding guidelines, prepend the current-year NVIDIA copyright header.
Add the standard header at the top of the file (outside the changed lines), for example:
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
403-404
: Initialize attribute with forward-referenced type and document intent.Use a forward reference to align with the type-only import above and add a short attribute docstring.
Apply this diff:
- self.guided_worker: Optional[GuidedWorker] = None + self.guided_worker: Optional['GuidedWorker'] = None # Guided decoding controller; set via set_guided_worker()
472-479
: Public API: add a concise docstring and use a forward reference in the signature.The method is part of the external surface for enabling guided decoding; add a docstring for discoverability and keep type import lazy.
Apply this diff:
- def set_guided_worker(self, guided_worker: GuidedWorker) -> bool: + def set_guided_worker(self, guided_worker: 'GuidedWorker') -> bool: + """Attach a GuidedWorker to this engine and propagate it to the model if supported. + + Returns: + bool: True if the model accepted the worker and the engine stored it; False otherwise. + """ if hasattr(self.model, "set_guided_worker"): success = self.model.set_guided_worker(guided_worker) if success: self.guided_worker = guided_worker return success return False
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py
(7 hunks)tensorrt_llm/_torch/pyexecutor/py_executor.py
(0 hunks)
💤 Files with no reviewable changes (1)
- tensorrt_llm/_torch/pyexecutor/py_executor.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
Actionable comments posted: 1
♻️ Duplicate comments (5)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)
106-113
: Potential index error on empty batches.Accessing
self.requests[0]
will raise an IndexError if the batch is empty. Add a check for empty requests before accessing the first element.@property def num_bitmask_tokens(self) -> int: + if not self.requests: + return 0 if self.requests[0].is_draft: return len(self.requests) else: return self.num_contexts + self.num_generations * ( self.max_num_draft_tokens + 1)
194-253
: Incorrect advancement accounting and unsafe draft token access.The code has two issues:
num_advanced_tokens
is incremented unconditionally even when only initializing the matcher (line 233)req.draft_tokens
may be None but is used without checking (line 238)These issues can lead to incorrect rollback calculations and potential AttributeError.
if matcher_advance: matcher = self.grammar_matchers[slot] # The last new token must be acceptable unless the matcher is terminated: # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration. # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration. if matcher.is_terminated() or self.is_draft_terminated[slot]: continue accepted = matcher.accept_token(req.new_token) if not accepted: if req.is_draft: self.is_draft_terminated[slot] = True logger.debug( f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}." ) continue # TODO: Make this an error response. raise ValueError( f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}." ) + # Only count actually accepted tokens + self.num_advanced_tokens[slot] += 1 - self.num_advanced_tokens[slot] += 1 if not matcher.is_terminated(): matcher.fill_next_token_bitmask(self.bitmask_host, offset) self.num_guided_tokens[slot] += 1 # Process draft tokens - for i, tid in enumerate(req.draft_tokens, 1): + for i, tid in enumerate((req.draft_tokens or []), 1): accepted = matcher.accept_token(tid) if not accepted: break self.num_advanced_tokens[slot] += 1 if matcher.is_terminated(): break matcher.fill_next_token_bitmask(self.bitmask_host, offset + i) self.num_guided_tokens[slot] += 1
433-435
: Add bounds checking for tensor access.The
fetch_batch
method accesses tensors by slot index without validating that the slot is within bounds of the tensor dimensions.for req in self.requests_hostfunc: if (slot := req.seq_slot) is None: continue + if slot >= self.max_num_sequences: + logger.warning(f"Slot {slot} exceeds max_num_sequences {self.max_num_sequences}") + continue req.new_token, *req.draft_tokens = self.new_tokens[:, slot].tolist()
441-452
: Guard execute() when no batch is enqueued.When
add_batch
hasn't been called or if the queue is empty,self.requests_hostfunc
could be None, leading to a potential AttributeError in_build
. Add an early return to handle this case gracefully.def execute(self, logits: torch.Tensor, d2t: Optional[torch.Tensor] = None) -> None: + if self.requests_hostfunc is None: + logger.debug("GuidedWorker.execute: no requests enqueued; skipping guidance") + return with torch.cuda.stream(self._stream): torch.cuda.current_stream().wait_event(self.token_event) self.fetch_batch() self.build() self.copy_bitmask() self.bitmask_event.record() torch.cuda.current_stream().wait_event(self.bitmask_event) self.apply_bitmask(logits, d2t=d2t)
249-252
: Invalid assertion for draft requests.The assertion that draft tokens must be empty for draft requests is incorrect. Draft requests can have draft tokens from previous iterations that need to be processed.
if req.is_draft: - assert len(req.draft_tokens) == 0 self.num_advanced_draft_tokens[ slot] += self.num_advanced_tokens[slot]
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
(7 hunks)tensorrt_llm/_torch/pyexecutor/py_executor.py
(3 hunks)tensorrt_llm/_torch/speculative/model_drafter.py
(3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (6)
add_batch
(285-287)add_batch
(412-421)execute
(303-312)execute
(441-452)rollback_draft_tokens
(366-367)rollback_draft_tokens
(459-460)
tensorrt_llm/_torch/pyexecutor/py_executor.py (4)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (7)
add_batch
(285-287)add_batch
(412-421)init_disagg_gen_requests
(384-385)execute
(303-312)execute
(441-452)rollback_rejected_tokens
(362-363)rollback_rejected_tokens
(455-456)tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
request_context
(114-138)tensorrt_llm/_torch/speculative/model_drafter.py (1)
prepare_draft_tokens
(324-410)tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
_forward_step
(2257-2278)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)
tensorrt_llm/sampling_params.py (1)
GuidedDecodingParams
(14-36)tensorrt_llm/_torch/hostfunc.py (1)
hostfunc
(19-24)tensorrt_llm/_torch/pyexecutor/grammar_matcher.py (16)
GrammarMatcher
(13-30)create
(36-38)create
(93-128)create
(195-222)is_terminated
(29-30)is_terminated
(57-58)is_terminated
(165-166)accept_token
(16-17)accept_token
(47-48)accept_token
(139-150)fill_next_token_bitmask
(24-26)fill_next_token_bitmask
(53-55)fill_next_token_bitmask
(159-163)rollback
(20-21)rollback
(50-51)rollback
(152-157)tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
LlmRequest
(271-411)get
(99-108)tensorrt_llm/_torch/pyexecutor/scheduler.py (3)
ScheduledRequests
(18-39)all_requests
(38-39)batch_size
(35-36)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
216-216: Line too long (171 > 120)
(E501)
217-217: Line too long (126 > 120)
(E501)
225-225: Line too long (126 > 120)
(E501)
336-336: Line too long (181 > 120)
(E501)
🔇 Additional comments (8)
tensorrt_llm/_torch/speculative/model_drafter.py (3)
357-360
: LGTM! Batch-oriented guided decoding integration looks good. The changes properly integrate the new batch-oriented guided decoding API by calling
add_batch
followed byexecute
after the initial forward pass, which aligns with the new GuidedDecoder interface design.
380-383
: LGTM! Consistent batch-oriented guided decoding in iteration loop.The implementation correctly applies the same batch-oriented pattern (
add_batch
+execute
) within the draft token generation loop, maintaining consistency with the new API.
403-404
: LGTM! Proper rollback handling with new API.The rollback logic correctly uses the new batch-oriented API by first calling
add_batch
with the scheduled requests, then invoking the no-argumentrollback_draft_tokens()
method. This aligns with the updated GuidedDecoder interface where rollback operates on internally stored requests.tensorrt_llm/_torch/pyexecutor/py_executor.py (5)
744-750
: LGTM! Proper batch-oriented guided decoding integration in PP mode.The implementation correctly integrates guided decoding for pipeline parallelism by calling
add_batch
followed by the conditionalinit_disagg_gen_requests()
for disaggregated serving, then executing with the logits. This maintains the proper sequencing for the new batch-oriented API.
934-937
: LGTM! Consistent guided decoding setup in main executor loop.The integration properly handles the batch-oriented guided decoding setup before draft token preparation, maintaining the same pattern as in the PP executor loop.
944-944
: LGTM! Proper rollback without arguments.The rollback call correctly uses the new no-argument signature, operating on the internally stored requests from the prior
add_batch
call.
949-950
: LGTM! Clean separation of guided decoding execution.The execute call properly passes only the logits tensor, following the new simplified interface where batch information is already stored internally from the
add_batch
call.
1065-1069
: LGTM! Consistent guided decoding in overlap scheduler.The overlap scheduler correctly implements the same batch-oriented guided decoding pattern, maintaining consistency across all executor loops.
1d2b09b to 7111c5a (Compare)
/bot run --disable-fail-fast |
PR_Github #17274 [ run ] triggered by Bot |
PR_Github #17274 [ run ] completed with state |
/bot run |
PR_Github #17328 [ run ] triggered by Bot |
I looked at the GIL releases and found two cases where the GIL release is not safe. Let me know what you think
PR_Github #17328 [ run ] completed with state |
Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
/bot run --disable-fail-fast |
PR_Github #17438 [ run ] triggered by Bot |
LGTM
PR_Github #17438 [ run ] completed with state |
/bot run --disable-fail-fast |
PR_Github #17490 [ run ] triggered by Bot |
PR_Github #17490 [ run ] completed with state |
…(part 2: one-model engine) (NVIDIA#6948) Signed-off-by: Enwei Zhu <[email protected]>
[TRTLLM-7028][feat] Enable guided decoding with speculative decoding (part 2: one-model engine)
Background
When multiple (draft and target) generation steps are launched by a single CUDA Graph, the CPU-GPU synchronization is difficult:
i
, there are two event waitsi-1
i
CUDA callback cudaLaunchHostFunc can launch a host function to a CUDA stream, so the host function can be captured and replayed by CUDA graph. Hence, we can launch the grammar compute as a CUDA callback. The CUDA graph should capture and replay multi-step model forwards and grammar compute all together.
See #6414 for more details on the CUDA callback integration to Python.
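For context, the capture/replay mechanics can be sketched with plain PyTorch; the host-callback part is what the new cudaLaunchHostFunc binding adds on top. The binding itself is not shown here, and its exact Python-side name may differ.

```python
# Conceptual sketch only: standard CUDA-graph capture/replay of device work. This PR's
# addition is that a Python host function enqueued on the capture stream (via a
# cudaLaunchHostFunc wrapper) would be captured as a host node and replayed together
# with the multi-step forwards and the grammar compute.
import torch

if torch.cuda.is_available():
    static_x = torch.randn(16, device="cuda")
    static_y = torch.empty_like(static_x)

    static_y.copy_(static_x * 2)          # warm-up before capture
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y.copy_(static_x * 2)      # device work captured into the graph

    static_x.fill_(3.0)
    graph.replay()                        # replays the captured work
    torch.cuda.synchronize()
    assert torch.allclose(static_y, torch.full_like(static_y, 6.0))
```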
Description
This PR supports guided decoding with one-model speculative decoding:
This PR introduces a variant of
GuidedDecoder
--CapturableGuidedDecoder
, which is embedded in the modeling code and capturable by CUDA graph.This PR enables GIL releases for some bindings which are likely to call CUDA APIs. This is to avoid potential deadlock with the CUDA callback. Without this context, releasing GIL should be a recommended practice -- for example, PyTorch almost always releases GIL.
TODO (Aug. 28)
Fix GIL issues of bindingsFix nanobindFree hostfunc handles automaticallyTest Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help]
to print this help message.See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id
(OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.--disable-reuse-test
(OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.--disable-fail-fast
(OPTIONAL) : Disable fail fast on build/tests/infra failures.--skip-test
(OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.--stage-list "A10-PyTorch-1, xxx"
(OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.--gpu-type "A30, H100_PCIe"
(OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.--test-backend "pytorch, cpp"
(OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.--only-multi-gpu-test
(OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.--disable-multi-gpu-test
(OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.--add-multi-gpu-test
(OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.--post-merge
(OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"
(OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".--detailed-log
(OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.--debug
(OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in thestage-list
parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.For guidance on mapping tests to stage names, see
docs/source/reference/ci-overview.md
and the
scripts/test_to_stage_mapping.py
helper.kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"
is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.