
Conversation


@syuoni syuoni commented Aug 15, 2025

[TRTLLM-7028][feat] Enable guided decoding with speculative decoding (part 2: one-model engine)

Background

When multiple (draft and target) generation steps are launched by a single CUDA graph, CPU-GPU synchronization becomes difficult:

  • For every step i, there are two event waits:
    • Before the grammar advance, the host waits for the CPU tokens from step i-1.
    • Before sampling, the main stream waits for the GPU bitmask from step i.
  • If multi-step forwards are launched by a single CUDA graph, the graph can no longer wait on events recorded from the host.

The CUDA callback API cudaLaunchHostFunc can enqueue a host function on a CUDA stream, so the host function can be captured and replayed by a CUDA graph. Hence, we can launch the grammar computation as a CUDA callback, and let the CUDA graph capture and replay the multi-step model forwards and the grammar computation together.
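The sketch below is a minimal standalone illustration of this mechanism (not the PR's implementation): a host function enqueued with cudaLaunchHostFunc during stream capture is recorded as a host node and re-executed on every graph replay, interleaved with the captured device work. The grammarAdvance callback and the omitted kernel launches are placeholders.

// Minimal sketch: capture a host callback into a CUDA graph alongside device work.
#include <cuda_runtime.h>
#include <cstdio>

static void CUDART_CB grammarAdvance(void* userData)
{
    // Host-side work (e.g., advancing the grammar matcher) goes here.
    // A host function must not call CUDA APIs.
    int* step = static_cast<int*>(userData);
    std::printf("grammar advance after step %d\n", *step);
}

int main()
{
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    int step = 0;
    cudaGraph_t graph;
    cudaGraphExec_t graphExec;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    // ... kernel launches for one forward step would be captured here ...
    cudaLaunchHostFunc(stream, grammarAdvance, &step); // captured as a host node
    cudaStreamEndCapture(stream, &graph);

    // CUDA 12 three-argument form; the callback runs at replay, not at capture.
    cudaGraphInstantiate(&graphExec, graph, 0);
    cudaGraphLaunch(graphExec, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graphExec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    return 0;
}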


See #6414 for more details on the CUDA callback integration into Python.

Description

This PR supports guided decoding with one-model speculative decoding:

  • MTP vanilla
  • MTP eagle
  • Eagle3 (one-model implementation)

This PR introduces a variant of GuidedDecoder -- CapturableGuidedDecoder -- which is embedded in the modeling code and can be captured by a CUDA graph.

This PR enables GIL release for some bindings that are likely to call CUDA APIs, to avoid a potential deadlock with the CUDA callback. Even without this context, releasing the GIL is a recommended practice -- for example, PyTorch releases the GIL for almost all of its bindings.
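The binding-side pattern is sketched below with made-up module and function names (not the PR's actual bindings): pybind11's py::call_guard<py::gil_scoped_release> releases the GIL while the C++ call blocks on CUDA work, so a queued host callback that needs the GIL to run Python code cannot deadlock against it.

// Hypothetical pybind11 sketch: release the GIL around a binding that blocks on CUDA work.
#include <pybind11/pybind11.h>
#include <cuda_runtime.h>
#include <cstdint>

namespace py = pybind11;

void syncStream(std::uintptr_t streamPtr)
{
    auto stream = reinterpret_cast<cudaStream_t>(streamPtr);
    // May block until all queued work, including host callbacks, has finished.
    // If the GIL were held here while a queued callback waited to acquire it,
    // the two would deadlock.
    cudaStreamSynchronize(stream);
}

PYBIND11_MODULE(example_bindings, m)
{
    // call_guard releases the GIL for the duration of the C++ call and
    // re-acquires it before returning to Python.
    m.def("sync_stream", &syncStream, py::call_guard<py::gil_scoped_release>());
}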

TODO (Aug. 28)

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
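For example, an illustrative invocation composed from the flags documented above:

/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"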

kill

kill

Kill all running builds associated with this pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.


coderabbitai bot commented Aug 15, 2025

📝 Walkthrough

Adds CUDA host-function interop via pybind (launch/free), exposes a Torch helper module, and integrates a new batch-oriented GuidedWorker into guided decoding across executor, model engine, and speculative components. Adjusts imports/APIs, propagates guided metadata, updates speculative drafter and Eagle3 to drive the worker, and fixes dataclass defaults.

Changes

Cohort / File(s) / Summary

  • PyBind runtime: hostfunc + wiring
    Files: cpp/tensorrt_llm/pybind/CMakeLists.txt, cpp/tensorrt_llm/pybind/runtime/bindings.cpp, cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp, cpp/tensorrt_llm/pybind/runtime/hostfunc.h
    Adds hostfunc bindings (launch/free) to call Python from CUDA via cudaLaunchHostFunc; wires init into bindings; updates build sources/order.
  • Torch hostfunc helper
    Files: tensorrt_llm/_torch/hostfunc.py
    Provides Python-side launch/decorator/cleanup for hostfunc; tracks and frees user-data handles; atexit cleanup.
  • Guided decoding core (worker + batching)
    Files: tensorrt_llm/_torch/pyexecutor/guided_decoder.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/speculative/model_drafter.py
    Introduces GuidedWorker (batch-oriented), new request structs and bitmask layout, hostfunc-driven control, add_batch/execute flow, rollback signature changes, creator wiring, and model engine propagation of guided metadata. Removes old helper paths.
  • Speculative model integration
    Files: tensorrt_llm/_torch/models/modeling_speculative.py, tensorrt_llm/_torch/speculative/eagle3.py
    Adds set_guided_worker, invokes GuidedWorker during main and draft logits, passes draft tokens and d2t to worker; import path adjustments.
  • Speculative interfaces defaults
    Files: tensorrt_llm/_torch/speculative/interface.py, tensorrt_llm/_torch/speculative/mtp.py
    Fixes dataclass trailing commas to correct defaults; MTP updates to use Python list for lens, adds per-request bookkeeping; minor signature default formatting.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Py as Python (Torch)
  participant TR as pybind runtime
  participant CU as CUDA Stream
  participant CB as HostFunc Callback

  Py->>TR: launch_hostfunc(stream_ptr, py_hostfunc, *args, **kwargs)
  TR->>CU: cudaLaunchHostFunc(trampoline, user_data)
  Note over CU,CB: Callback scheduled on stream
  CU-->>CB: invoke trampoline(user_data)
  CB->>CB: Acquire GIL, call py_hostfunc(*args, **kwargs)
  CB-->>Py: Python callable executes (errors printed, if any)
  Py->>TR: free_hostfunc_user_data(handle)
  TR-->>Py: user-data freed
sequenceDiagram
  autonumber
  participant Eng as PyTorchModelEngine
  participant Cr as PyExecutorCreator
  participant GW as GuidedWorker
  participant Dra as ModelDrafter/Eagle3
  participant Ex as PyExecutor
  participant Dev as Device (Kernels)

  Cr->>Eng: set_guided_worker(GuidedWorker(...))
  Eng-->>Cr: bool

  loop Per batch
    Ex->>GW: add_batch(scheduled_requests[, new_tensors])
    Dra->>GW: add_draft_batch(...) (draft mode)
    Ex->>Dev: run model forward
    Dev-->>Ex: logits (and d2t if any)
    Ex->>GW: execute(logits[, d2t])
    Dra->>GW: execute_draft_batch(logits, d2t, step flags)
    opt Rejections
      GW->>GW: rollback_rejected_tokens()
    end
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes


Suggested reviewers

  • Tabrizian
  • mikeiovine
  • Shixiaowei02
  • nv-guomingz
  • QiJune
  • yweng0828

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 12

🧹 Nitpick comments (11)
cpp/tensorrt_llm/pybind/runtime/hostfunc.h (1)

1-27: Add include guard per repo guidelines; prefer explicit guard over pragma once

The coding guidelines require header guards named TRTLLM_<NAME>_H. Replace #pragma once with an explicit guard.

-#pragma once
+#ifndef TRTLLM_HOSTFUNC_H
+#define TRTLLM_HOSTFUNC_H
@@
-} // namespace tensorrt_llm::pybind::runtime
+} // namespace tensorrt_llm::pybind::runtime
+
+#endif // TRTLLM_HOSTFUNC_H
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

342-346: Consider refactoring the TODO comment implementation

The commented-out runtime guard for guided decoding has been removed, and the guided decoder assignment pattern with TODO: fix it seems like a temporary workaround. The guided decoder is instantiated but then immediately nullified after being assigned to the spec worker.

Consider a cleaner approach that avoids the intermediate assignment and nullification:

-            # TODO: fix it
-            model_engine.model.spec_worker.guided_decoder = guided_decoder
-            guided_decoder = None
+            # Assign guided decoder to spec worker for guided decoding in speculative mode
+            if hasattr(model_engine.model, 'spec_worker'):
+                model_engine.model.spec_worker.guided_decoder = guided_decoder
+                guided_decoder = None  # Ownership transferred to spec worker
+            else:
+                logger.warning("Spec worker not found for guided decoding assignment")
tensorrt_llm/_torch/speculative/eagle3.py (1)

335-336: Consider implementing guided decoder for draft token generation

The comment "Insert guided decoder" indicates a placeholder where guided decoding logic should be applied during draft token generation. This seems incomplete compared to the target model path.

Would you like me to help implement the guided decoder logic for draft token generation? This would ensure consistency between target and draft model guided decoding behavior.

tensorrt_llm/_torch/hostfunc.py (1)

8-8: Consider thread-safety for global registry

The HOSTFUNC_USER_DATA_HANDLES set is used globally without synchronization, which could lead to race conditions if multiple threads call the host function APIs concurrently.

Consider adding thread-safety:

+import threading
 import atexit

 import torch

 from ..bindings.internal import runtime as bindings
 from ..logger import logger

 HOSTFUNC_USER_DATA_HANDLES = set()
+_HOSTFUNC_LOCK = threading.Lock()

Then protect all accesses to the set with the lock in the relevant functions.

cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp (1)

44-58: Consider adding exception catching for all Python exceptions

The current error handling in cudaHostFuncTrampoline only catches py::error_already_set. Consider catching all exceptions to prevent crashes from unexpected Python exceptions.

Apply this diff to improve exception handling:

 static void cudaHostFuncTrampoline(void* userData)
 {
     auto* hostFuncUserData = static_cast<HostFuncUserData*>(userData);
     // Acquire the GIL since we are calling Python code from a CUDA stream.
     py::gil_scoped_acquire gil;
     try
     {
         hostFuncUserData->pyHostFunc(*hostFuncUserData->pyArgs, **hostFuncUserData->pyKwargs);
     }
     catch (py::error_already_set& e)
     {
         e.restore();
         PyErr_Print();
     }
+    catch (std::exception& e)
+    {
+        py::print("Error in host function:", e.what());
+    }
+    catch (...)
+    {
+        py::print("Unknown error in host function");
+    }
 }
examples/guided_decoding/try_pytorch_bindings_v3.py (2)

24-25: Consider using torch.no_grad() for inference

The forward pass in the DummyModel should be wrapped with torch.no_grad() to avoid unnecessary gradient computation during inference.

     def forward(self, x: torch.Tensor):
         # Simulate some GPU computation.
-        for i in range(10):
-            torch.matmul(self.a, self.b)
+        with torch.no_grad():
+            for i in range(10):
+                torch.matmul(self.a, self.b)

96-97: Add error handling for decoding failures

The assertion and decoding could fail without proper error handling.

-    assert len(guided_decoder.token_ids) == 20
-    print(tokenizer.decode(guided_decoder.token_ids))
+    if len(guided_decoder.token_ids) != 20:
+        logger.error(f"Expected 20 tokens, got {len(guided_decoder.token_ids)}")
+    try:
+        decoded_text = tokenizer.decode(guided_decoder.token_ids)
+        print(f"Decoded text: {decoded_text}")
+    except Exception as e:
+        logger.error(f"Failed to decode tokens: {e}")
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)

170-171: Line length exceeds coding guideline limit

Lines 170-171 exceed the 120-character limit specified in the coding guidelines.

-                # The last new token must be acceptable unless the matcher is terminated:
-                # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-                # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+                # The last new token must be acceptable unless the matcher is terminated:
+                # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have
+                #    accepted the EOS token in the draft tokens at the previous iteration.
+                # 2. For the draft model loop, the matcher may have accepted the EOS token at the
+                #    previous drafting iteration.

285-285: Line length exceeds coding guideline limit

Line 285 exceeds the 120-character limit.

-                    f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"
+                    f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, "
+                    f"num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"

328-337: V2 API wrappers lack input validation

The v2 wrapper methods don't validate that guided_metadata contains the required fields before forwarding.

     def build_v2(self, guided_metadata: GuidedMetadata):
+        if guided_metadata.scheduled_requests is None:
+            raise ValueError("guided_metadata.scheduled_requests cannot be None")
         self.build(guided_metadata.scheduled_requests,
                    guided_metadata.gathered_input_ids)
 
     def execute_v2(self, guided_metadata: GuidedMetadata, logits: torch.Tensor):
+        if guided_metadata.scheduled_requests is None:
+            raise ValueError("guided_metadata.scheduled_requests cannot be None")
         self.execute(guided_metadata.scheduled_requests, logits)
 
     def rollback_rejected_tokens_v2(self, guided_metadata: GuidedMetadata):
+        if guided_metadata.scheduled_requests is None:
+            raise ValueError("guided_metadata.scheduled_requests cannot be None")
         self.rollback_rejected_tokens(guided_metadata.scheduled_requests,
                                       guided_metadata.new_tokens_lens)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

403-409: Consider lazy initialization for guided decoding tensors

The gathered_input_ids and new_tokens_lens tensors are always allocated even when guided decoding is not used, consuming pinned memory unnecessarily.

Consider moving the tensor initialization to _set_up_guided_metadata() to allocate them only when needed:

         self.guided_metadata: Optional[GuidedMetadata] = None
-        self.gathered_input_ids = torch.empty(self.max_num_tokens,
-                                              dtype=torch.int,
-                                              pin_memory=True)
-        self.new_tokens_lens = torch.empty(self.batch_size,
-                                           dtype=torch.int,
-                                           pin_memory=True)
+        self.gathered_input_ids = None
+        self.new_tokens_lens = None

     def _set_up_guided_metadata(self):
         if self.guided_metadata is None:
             self.guided_metadata = GuidedMetadata()
+            if self.gathered_input_ids is None:
+                self.gathered_input_ids = torch.empty(self.max_num_tokens,
+                                                      dtype=torch.int,
+                                                      pin_memory=True)
+            if self.new_tokens_lens is None:
+                self.new_tokens_lens = torch.empty(self.batch_size,
+                                                   dtype=torch.int,
+                                                   pin_memory=True)
         return self.guided_metadata
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between f7dbc14 and 52fcc56.

📒 Files selected for processing (14)
  • cpp/tensorrt_llm/pybind/CMakeLists.txt (1 hunks)
  • cpp/tensorrt_llm/pybind/runtime/bindings.cpp (2 hunks)
  • cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp (1 hunks)
  • cpp/tensorrt_llm/pybind/runtime/hostfunc.h (1 hunks)
  • examples/guided_decoding/try_pytorch_bindings_v2.py (1 hunks)
  • examples/guided_decoding/try_pytorch_bindings_v3.py (1 hunks)
  • tensorrt_llm/_torch/hostfunc.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_speculative.py (4 hunks)
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py (8 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (3 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/tensorrt_llm/pybind/runtime/hostfunc.h
  • cpp/tensorrt_llm/pybind/runtime/bindings.cpp
  • cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/tensorrt_llm/pybind/runtime/hostfunc.h
  • cpp/tensorrt_llm/pybind/runtime/bindings.cpp
  • cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
**/*.{h,hpp,hxx,hh,cuh}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Header files must use include guards named TRTLLM_<NAME>_H without underscore prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)

Files:

  • cpp/tensorrt_llm/pybind/runtime/hostfunc.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/tensorrt_llm/pybind/runtime/hostfunc.h
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • examples/guided_decoding/try_pytorch_bindings_v2.py
  • examples/guided_decoding/try_pytorch_bindings_v3.py
  • tensorrt_llm/_torch/hostfunc.py
  • tensorrt_llm/_torch/speculative/interface.py
  • cpp/tensorrt_llm/pybind/runtime/bindings.cpp
  • tensorrt_llm/_torch/speculative/eagle3.py
  • cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • examples/guided_decoding/try_pytorch_bindings_v2.py
  • examples/guided_decoding/try_pytorch_bindings_v3.py
  • tensorrt_llm/_torch/hostfunc.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
**/*.{cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu}: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)

Files:

  • cpp/tensorrt_llm/pybind/runtime/bindings.cpp
  • cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

344-344: Line too long (132 > 120)

(E501)

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

170-170: Line too long (171 > 120)

(E501)


171-171: Line too long (126 > 120)

(E501)


285-285: Line too long (181 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (11)
tensorrt_llm/_torch/speculative/interface.py (1)

124-128: Fix removes accidental tuple defaults; aligns runtime types with annotations

Removing the trailing commas ensures these defaults are the intended scalar values (enum/None), not single-element tuples. This resolves subtle bugs when comparing or checking truthiness and matches the annotated types.

cpp/tensorrt_llm/pybind/runtime/bindings.cpp (2)

19-19: Include hostfunc.h to wire HostFunc bindings

The new include is correct and local-include path is appropriate given this file’s directory. This unblocks registration of HostFunc bindings.


465-466: Initialize HostFunc bindings in module init

Calling initHostFuncBindings(m) after initMoeBindings(m) cleanly registers the new Python-facing APIs. No ordering issues apparent.

cpp/tensorrt_llm/pybind/CMakeLists.txt (1)

17-20: Add hostfunc.cpp to pybind SRCS

Including runtime/hostfunc.cpp ensures the new bindings are compiled and linked. Moving testing/modelSpecBinding.cpp doesn’t materially affect linkage; order here is fine.

tensorrt_llm/_torch/speculative/eagle3.py (2)

269-269: LGTM! Clean integration of guided decoder

The addition of the optional guided_decoder member is well-structured and follows best practices for optional features.


273-274: LGTM! Forward signature properly extended

The forward method signature correctly adds the guided_metadata parameter in the expected position, maintaining backwards compatibility.

tensorrt_llm/_torch/models/modeling_speculative.py (3)

7-7: LGTM! Clean import organization

The import path adjustment to use relative imports is appropriate and follows Python best practices.


372-372: LGTM! Proper API extension for guided decoding

The addition of the optional guided_metadata parameter to the forward signature is well-designed, maintaining backward compatibility while enabling guided decoding functionality.


399-399: LGTM! Correct propagation of guided metadata

The guided_metadata is properly threaded through to the spec_worker, enabling guided decoding in the speculative path.

cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp (1)

60-75: Missing null pointer validation

The function doesn't validate that streamPtr is non-null before casting and using it.

Add null pointer validation:

 uintptr_t launchHostFunc(uintptr_t streamPtr, py::function pyHostFunc, py::args pyArgs, py::kwargs pyKwargs)
 {
+    if (streamPtr == 0)
+    {
+        throw std::invalid_argument("Stream pointer cannot be null");
+    }
     auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);
 
     auto hostFuncUserData = std::make_unique<HostFuncUserData>(pyHostFunc, py::tuple(pyArgs), py::dict(pyKwargs));
 
     cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
     if (err != cudaSuccess)
     {
-        throw std::runtime_error("Failed to launch host function.");
+        throw std::runtime_error(std::string("Failed to launch host function: ") + cudaGetErrorString(err));
     }

Likely an incorrect or invalid review comment.

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

1624-1642: Missing null check before accessing guided_decoder

The code checks if guided_metadata is not None but doesn't verify that self.model.spec_worker.guided_decoder exists before calling methods on it.

         if guided_metadata is not None:
             guided_metadata.scheduled_requests = scheduled_requests
             if new_tensors_device is None:
                 guided_metadata.gathered_input_ids = None
                 guided_metadata.new_tokens_lens = None
             else:
                 num_gathered = len(gather_ids)
                 gathered_input_ids_cuda = self.input_ids_cuda[
                     self.gather_ids_cuda[:num_gathered]]
                 self.gathered_input_ids[:num_gathered].copy_(
                     gathered_input_ids_cuda, non_blocking=True)
                 guided_metadata.gathered_input_ids = self.gathered_input_ids[:
                                                                              num_gathered]
                 self.new_tokens_lens.copy_(new_tokens_lens_device,
                                            non_blocking=True)
                 guided_metadata.new_tokens_lens = self.new_tokens_lens
             guided_metadata.token_event.record()
             inputs['guided_metadata'] = guided_metadata

Likely an incorrect or invalid review comment.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

346-359: Ensure guided_decoder is propagated into the speculative drafter

Right now you attach the newly created GuidedDecoder to model_engine.model.spec_worker and then do guided_decoder = None. But later you call

drafter = get_spec_drafter(
    model_engine,
    draft_model_engine,
    sampler,
    spec_resource_manager=spec_resource_manager,
    guided_decoder=guided_decoder,
)

with guided_decoder=None, so the resulting ModelDrafter never sees the guided decoder you just built. You need to either:

• Pass the real decoder instance into get_spec_drafter instead of nulling it out: in _torch/pyexecutor/py_executor_creator.py at lines 356–358, remove guided_decoder = None and always forward the decoder.
• Or make get_spec_drafter fall back to model_engine.model.spec_worker.guided_decoder when its guided_decoder arg is None.

Pinpointed changes:

  • _torch/pyexecutor/py_executor_creator.py@356–358
  • _torch/speculative/utils.py@get_spec_drafter: add a guard like
    if guided_decoder is None and hasattr(model_engine.model, "spec_worker"):
        guided_decoder = model_engine.model.spec_worker.guided_decoder

These refactors are required to ensure guided decoding actually flows into your drafter.

♻️ Duplicate comments (5)
tensorrt_llm/_torch/speculative/mtp.py (2)

103-105: Fix tuple default: mtp_num_modules mistakenly a tuple due to trailing comma

The trailing comma makes mtp_num_modules a tuple (1,) instead of int 1, which will break arithmetic and torch ops (e.g., torch.arange(self.mtp_num_modules)).

-    mtp_num_modules: int = 1,
+    mtp_num_modules: int = 1

268-270: Accepted-count and rewind length have correct int semantics (good); keep it that way

num_new_tokens here is a Python int (from list/tensor.item), so both py_num_accepted_draft_tokens and py_rewind_len are plain ints. This addresses previous tensor/int mixing concerns downstream.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

342-346: Fix long comment line; keep under 120 chars

Ruff flagged the long line in this commented-out guard. Wrap it to satisfy line-length limits.

-            #     raise ValueError(
-            #         "Guided decoding is only supported with speculative decoding that has a dedicated drafter (two-model engine)."
-            #     )
+            #     raise ValueError(
+            #         "Guided decoding is only supported with speculative decoding that has a "
+            #         "dedicated drafter (two-model engine)."
+            #     )
tensorrt_llm/_torch/speculative/eagle3.py (1)

281-294: Guard guided-decoder sync/execute to prevent hangs or hard failures

token_event.synchronize()/waits and guided-decoder calls can raise or hang; wrap with try/except and log, so the engine can continue without guidance on failure.

-        if self.guided_decoder is not None and guided_metadata is not None:
-            with torch.cuda.stream(self.guided_decoder._stream):
-                torch.cuda.current_stream().wait_event(
-                    guided_metadata.token_event)
-                self.guided_decoder.rollback_rejected_tokens_v2(guided_metadata)
-                self.guided_decoder.build_v2(guided_metadata)
-                self.guided_decoder.copy_bitmask(
-                    guided_metadata.scheduled_requests)
-                guided_metadata.bitmask_event.record()
-
-            torch.cuda.current_stream().wait_event(
-                guided_metadata.bitmask_event)
-            self.guided_decoder.execute_v2(guided_metadata, logits)
+        if self.guided_decoder is not None and guided_metadata is not None:
+            try:
+                with torch.cuda.stream(self.guided_decoder._stream):
+                    torch.cuda.current_stream().wait_event(guided_metadata.token_event)
+                    self.guided_decoder.rollback_rejected_tokens_v2(guided_metadata)
+                    self.guided_decoder.build_v2(guided_metadata)
+                    self.guided_decoder.copy_bitmask(guided_metadata.scheduled_requests)
+                    guided_metadata.bitmask_event.record()
+                torch.cuda.current_stream().wait_event(guided_metadata.bitmask_event)
+                self.guided_decoder.execute_v2(guided_metadata, logits)
+            except Exception as e:
+                logger.error(f"Guided decoder failed; continuing without guidance: {e}")
+                import traceback as _tb; _tb.print_exc()

Add missing imports to support logging and traceback:

@@
-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.logger import logger
+import traceback  # used in failure path (aliased locally above to avoid linter complaints)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1)

345-355: Add bounds checks and exception handling in host-callable run()

run() assumes grammar_matchers[0] exists and token_ids has at least one element. Add checks and handle exceptions to avoid hard crashes when called from host.

     @hostfunc
     def run(self, token_ids: torch.Tensor):
-        self.grammar_matchers[0].accept_token(token_ids[0].item())
-        self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0)
-        if not hasattr(self, "token_ids"):
-            self.token_ids = []
-        self.token_ids.append(token_ids[0].item())
+        if token_ids.numel() == 0:
+            logger.warning("GuidedDecoder.run received empty token_ids")
+            return
+        if not self.grammar_matchers or self.grammar_matchers[0] is None:
+            logger.warning("GuidedDecoder.run called before matcher initialization")
+            return
+        try:
+            tid0 = token_ids[0].item()
+            self.grammar_matchers[0].accept_token(tid0)
+            self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0)
+        except Exception as e:
+            logger.error(f"Error in GuidedDecoder.run: {e}")
+            return
+        if not hasattr(self, "token_ids"):
+            self.token_ids = []
+        self.token_ids.append(tid0)
🧹 Nitpick comments (12)
tensorrt_llm/_torch/speculative/mtp.py (2)

1-3: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

250-252: Avoid materializing full Python lists for lengths; fetch per-request as int

new_tokens_lens is already a CPU tensor; creating a full Python list is unnecessary. Index the tensor per request and cast to int. This reduces overhead for large batches.

-        new_tokens_lens_list = state.host.new_tokens_lens.tolist()
+        # Avoid building a full Python list; index per request and cast to int below

And:

-            num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
+            num_new_tokens = int(state.host.new_tokens_lens[req.py_seq_slot])

Also applies to: 263-264

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)

1-3: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

356-359: Remove or ticket-ize the TODO for production readiness

The TODO: fix it suggests unfinished integration. Convert to a tracked issue or clarify the intended ownership model in a comment.

tensorrt_llm/_torch/speculative/eagle3.py (2)

1-3: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

283-284: Avoid reaching into GuidedDecoder private attribute _stream

Accessing a private attribute is brittle. Expose a public property on GuidedDecoder and use it here.

In GuidedDecoder (see file tensorrt_llm/_torch/pyexecutor/guided_decoder.py), add:

@@ class GuidedDecoder:
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
+    @property
+    def stream(self) -> torch.cuda.Stream:
+        return self._stream

Then update this call site:

-            try:
-                with torch.cuda.stream(self.guided_decoder._stream):
+            try:
+                with torch.cuda.stream(self.guided_decoder.stream):
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (6)

1-4: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 import math
 from dataclasses import dataclass, field
 from typing import Iterable, List, Optional, Tuple

114-126: Wrap long comment and clarify condition

Break the long comment lines to satisfy linters; keep the rationale succinct.

-            # Fix race condition: The following host code may change the state of requests and in turn affect the result of hostfunc.
-            # Where the hostfunc is launched, but not executed yet.
-            # if llm_req.is_generation_in_progress_state:
+            # Fix race condition: host code may change request state while a hostfunc is
+            # queued but not yet executed, altering the expected offsets. Use a broader
+            # condition than is_generation_in_progress_state here.
+            # if llm_req.is_generation_in_progress_state:

180-182: Wrap long comment lines under 120 characters

Line length exceeds 120. Reflow for readability and linting.

-                # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-                # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+                # 1. Main model loop: with overlap scheduler, the matcher may have accepted EOS in the
+                #    draft tokens during the previous iteration.
+                # 2. Draft model loop: the matcher may have accepted EOS at the previous drafting step.

288-289: Wrap long error message string

The f-string exceeds 120 chars. Wrap to satisfy linters.

-                raise ValueError(
-                    f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"
-                )
+                raise ValueError(
+                    "Failed to rollback: "
+                    f"num_advanced_tokens={self.num_advanced_tokens[slot]}, "
+                    f"num_accepted_tokens={num_accepted_tokens}, "
+                    f"num_rollback_tokens={num_rollback_tokens}"
+                )

331-335: Mark execute_v2 as hostfunc (optional) and keep v2 API symmetry

build_v2 and rollback_rejected_tokens_v2 are hostfuncs; execute_v2 could remain device-side, but for symmetry and scheduling guarantees you may want to expose it as hostfunc too. If intentional to keep it non-hostfunc, ignore.

Option A (make it hostfunc):

-    def execute_v2(self, guided_metadata: GuidedMetadata, logits: torch.Tensor):
+    @hostfunc
+    def execute_v2(self, guided_metadata: GuidedMetadata, logits: torch.Tensor):
         self.execute(guided_metadata.scheduled_requests, logits)

Also applies to: 339-343


89-92: Expose a public stream property to avoid private attribute access

Provide a public property so callers (e.g., Eagle3) don't reach into _stream directly.

     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
+    @property
+    def stream(self) -> torch.cuda.Stream:
+        return self._stream
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 52fcc56 and da1275a.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py (7 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (3 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py

120-120: Line too long (133 > 120)

(E501)


180-180: Line too long (171 > 120)

(E501)


181-181: Line too long (126 > 120)

(E501)


288-288: Line too long (181 > 120)

(E501)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

344-344: Line too long (132 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
tensorrt_llm/_torch/speculative/mtp.py (2)

263-263: Add int cast to maintain consistent typing.

num_new_tokens is extracted from a Python list but should be cast to int for consistency with the previous type-safe approach from the past review.

Apply this diff:

-            num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
+            num_new_tokens = int(new_tokens_lens_list[req.py_seq_slot])

268-269: Apply int cast to maintain type consistency.

Following the established pattern from the past review, num_new_tokens should be cast to int before arithmetic operations to ensure py_num_accepted_draft_tokens and py_rewind_len are plain Python integers.

Apply this diff:

-            req.py_num_accepted_draft_tokens = num_new_tokens - 1
-            req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens
+            req.py_num_accepted_draft_tokens = int(num_new_tokens) - 1
+            req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens
tensorrt_llm/_torch/speculative/eagle3.py (1)

281-296: Add error handling around guided-decoder operations.

The guided-decoder calls are unprotected and may raise exceptions or hang (e.g., token_event.synchronize has no timeout). Add defensive error handling to ensure execution continues even if guided decoding fails.

Apply this diff:

         if self.guided_decoder is not None:
-            with torch.cuda.stream(self.guided_decoder._stream):
-                torch.cuda.current_stream().wait_event(
-                    guided_metadata.token_event)
-                # Fix it.
-                guided_metadata.next_batch_hostfunc()
-                self.guided_decoder.rollback_rejected_tokens_hostfunc(
-                    guided_metadata)
-                self.guided_decoder.build_hostfunc(guided_metadata)
-                self.guided_decoder.bitmask_copy(guided_metadata)
-                guided_metadata.bitmask_event.record()
-
-            torch.cuda.current_stream().wait_event(
-                guided_metadata.bitmask_event)
-            self.guided_decoder.execute(guided_metadata, logits)
+            try:
+                with torch.cuda.stream(self.guided_decoder._stream):
+                    torch.cuda.current_stream().wait_event(
+                        guided_metadata.token_event)
+                    guided_metadata.next_batch_hostfunc()
+                    self.guided_decoder.rollback_rejected_tokens_hostfunc(
+                        guided_metadata)
+                    self.guided_decoder.build_hostfunc(guided_metadata)
+                    self.guided_decoder.bitmask_copy(guided_metadata)
+                    guided_metadata.bitmask_event.record()
+
+                torch.cuda.current_stream().wait_event(
+                    guided_metadata.bitmask_event)
+                self.guided_decoder.execute(guided_metadata, logits)
+            except Exception as e:
+                import traceback
+                traceback.print_exc()
+                logger.error(f"Guided decoder failed; continuing without guidance: {e}")
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

2197-2200: Fix unsafe attribute access chain.

The code directly accesses self.model.spec_worker.guided_decoder without ensuring all intermediate attributes exist, which can raise AttributeError.

Apply this diff:

         guided_metadata = None
-        if self.model.spec_worker.guided_decoder:
+        if (hasattr(self.model, 'spec_worker') and 
+            hasattr(self.model.spec_worker, 'guided_decoder') and 
+            self.model.spec_worker.guided_decoder):
             guided_metadata = self._set_up_guided_metadata()
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between da1275a and 566a849.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py (6 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (16 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py

108-108: Line too long (129 > 120)

(E501)


246-246: Line too long (171 > 120)

(E501)


247-247: Line too long (126 > 120)

(E501)


255-255: Line too long (122 > 120)

(E501)


348-348: Line too long (181 > 120)

(E501)

🔇 Additional comments (42)
tensorrt_llm/_torch/speculative/mtp.py (2)

103-103: LGTM! Fix applied from past review.

The trailing comma in the default value has been correctly removed, preventing mtp_num_modules from becoming a single-element tuple.


250-250: LGTM! Improved list conversion approach.

Converting new_tokens_lens to a Python list with tolist() is cleaner than tensor indexing for per-request access patterns.

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (3)

36-36: LGTM! Guided metadata parameter added.

The constructor now accepts guided_metadata parameter to support guided decoding integration while maintaining backward compatibility with None default.


75-75: LGTM! Store guided metadata for CUDA graph usage.

Proper storage of the guided_metadata instance for later use in graph capture and execution.


95-95: LGTM! Guided metadata propagated to forward function.

The guided_metadata is correctly included in the inputs dictionary passed to the forward function during graph capture and execution.

tensorrt_llm/_torch/speculative/eagle3.py (4)

10-10: LGTM! Import added for guided decoder integration.

The import of GuidedDecoder enables Eagle3 to integrate with the guided decoding framework.


269-269: LGTM! Guided decoder attribute added.

The guided_decoder attribute allows Eagle3OneModelWorker to optionally integrate with guided decoding while maintaining compatibility when not used.


274-274: LGTM! Forward signature extended for guided decoding.

The guided_metadata parameter properly extends the forward method signature to support guided decoding integration.


345-345: LGTM! Guided decoder integration point marked.

The comment appropriately marks where guided decoding steps are integrated into the draft-generation flow.

tensorrt_llm/_torch/pyexecutor/model_engine.py (15)

44-53: LGTM! Import reorganization and new imports added.

The imports have been properly reorganized with new guided decoding components (GuidedMetadata, SampleStateTensors, SampleStateTensorsMTP) and moved speculative utilities to their appropriate locations.


58-64: LGTM! Additional imports for guided decoding support.

The imports of GuidedMetadata and related components properly enable guided decoding integration in the model engine.


403-410: LGTM! Guided metadata state initialization.

The guided metadata fields and host tensor buffers are properly initialized to support the guided decoding workflow.


871-876: LGTM! Lazy guided metadata initialization.

The _set_up_guided_metadata method properly initializes GuidedMetadata on-demand, following the established pattern for other metadata initialization.


1019-1021: LGTM! Guided metadata integration in CUDA graph runner.

The guided_metadata is correctly passed to the DecodingCUDAGraphRunner constructor to enable guided decoding support in CUDA graphs.


1215-1217: LGTM! Token event recording for guided metadata.

The token event is recorded when guided metadata is present, enabling correct synchronization in the guided decoding workflow.


1226-1226: LGTM! Guided metadata parameter added to input preparation methods.

The guided_metadata parameter is properly added to the input preparation method signatures.


1331-1331: LGTM! Request ID tracking maintained.

The request_ids.append(request.py_request_id) calls are properly positioned to maintain request ID tracking for guided decoding.


1629-1646: LGTM! Guided metadata population in input preparation.

The guided metadata is properly populated with the scheduled requests and tensor data when provided, enabling the guided decoding workflow to access the necessary information.


1681-1682: LGTM! Guided metadata parameter consistency.

The guided_metadata parameter is consistently added to the no-cache input preparation method.


1779-1782: LGTM! Guided metadata handling in no-cache path.

Guided metadata is properly handled in the no-cache input preparation path, maintaining consistency across different execution modes.


2151-2151: LGTM! Guided metadata parameter added to main preparation method.

The guided_metadata parameter is consistently added to the main _prepare_inputs method signature.


2164-2164: LGTM! Guided metadata forwarding in input preparation.

The guided_metadata parameter is properly forwarded to the appropriate input preparation method.


2212-2214: LGTM! Guided metadata parameter forwarding.

The guided_metadata parameter is properly forwarded to the no-cache input preparation method.


2233-2235: LGTM! Guided metadata parameter forwarding in main path.

The guided_metadata parameter is correctly forwarded in the main forward execution path.

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (18)

2-4: LGTM! New imports for guided decoding data structures.

The imports of dataclass, field, Queue, and typing components properly support the new guided decoding data structures.


7-7: LGTM! XGrammar import added.

The xgrammar import enables the apply_token_bitmask_inplace functionality used in the execute method.


10-10: LGTM! Guided decoding configuration imports.

The imports of GuidedDecodingConfig and GuidedDecodingParams provide the necessary configuration types.


12-12: LGTM! Host function decorator import.

The hostfunc import enables the host-callable method decorators used throughout the guided decoder.


13-15: LGTM! Grammar matcher imports updated.

The import statement properly includes all necessary grammar matcher components.


19-74: LGTM! Well-designed GuidedRequest dataclass.

The GuidedRequest dataclass provides a clean abstraction for request state, with helper methods that determine when matcher initialization and advancement are required. The from_request class method offers a straightforward conversion from LlmRequest.
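
As a rough mental model only (field and method names here are assumptions, not the actual definition), such a dataclass might look like the sketch below:

    from dataclasses import dataclass, field
    from typing import List, Optional


    @dataclass
    class GuidedRequestSketch:
        """Illustrative stand-in for the reviewed GuidedRequest."""
        request_id: int
        seq_slot: Optional[int] = None
        guide_type: Optional[str] = None   # e.g. "json" or "regex"; None means unguided
        new_token: Optional[int] = None    # last sampled token, if any
        draft_tokens: List[int] = field(default_factory=list)
        is_draft: bool = False

        def require_matcher_init(self) -> bool:
            # A grammar matcher is needed once a guided request has a sequence slot.
            return self.guide_type is not None and self.seq_slot is not None

        def require_matcher_advance(self) -> bool:
            # The matcher only advances after a new token has been produced.
            return self.require_matcher_init() and self.new_token is not None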


76-101: LGTM! Well-structured GuidedRequestBatch class.

The GuidedRequestBatch class provides good encapsulation of batch processing logic with proper offset calculation for guided request processing.


103-141: LGTM! Comprehensive GuidedMetadata class.

The GuidedMetadata dataclass provides proper structure for coordinating host-device guided decoding workflow with appropriate CUDA events and queue management.
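
A minimal sketch of the coordination structure being described (the fields are assumptions; only the queue-plus-events shape is taken from this review):

    from dataclasses import dataclass, field
    from queue import Queue
    from typing import Optional

    import torch


    @dataclass
    class GuidedMetadataSketch:
        """Illustrative stand-in for the reviewed GuidedMetadata."""
        # Batches handed from input preparation to the host callbacks.
        request_queue: Queue = field(default_factory=Queue)
        # Recorded once the sampled tokens for the step are available.
        token_event: torch.cuda.Event = field(default_factory=torch.cuda.Event)
        # Recorded once the bitmask has been copied to the device.
        bitmask_event: torch.cuda.Event = field(default_factory=torch.cuda.Event)
        # Flattened input ids gathered for the step, when available.
        gathered_input_ids: Optional[torch.Tensor] = None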


171-182: LGTM! Proper bitmask buffer initialization.

The grammar matchers list and the bitmask buffers (both host and device) are properly initialized, with sizing that accounts for draft tokens.
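
A hedged sizing sketch (all sizes below are assumptions): one int32 word masks 32 vocabulary entries, and each sequence needs one bitmask row for its target token plus one per draft token:

    import torch

    max_num_sequences = 8          # assumed
    max_num_draft_tokens = 3       # assumed
    vocab_size_padded = 128_256    # assumed

    num_rows = max_num_sequences * (max_num_draft_tokens + 1)
    num_words = (vocab_size_padded + 31) // 32   # 32 vocab entries per int32 word

    bitmask_host = torch.empty(num_rows, num_words, dtype=torch.int32, pin_memory=True)
    bitmask = torch.empty(num_rows, num_words, dtype=torch.int32, device="cuda")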


200-200: LGTM! Method signature updated for guided metadata.

The build method signature properly accepts GuidedMetadata instead of ScheduledRequests, aligning with the new guided decoding architecture.


208-232: LGTM! Flexible input handling for build method.

The method properly handles both gathered_input_ids from guided metadata and fallback to request-specific tokens, providing flexibility for different execution modes.


233-283: LGTM! Robust matcher processing logic.

The matcher initialization and advancement logic properly handles various request states and draft tokens with appropriate error handling and termination conditions.
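
For orientation, a hedged sketch of the per-request loop such logic typically follows (the matcher method names match the GrammarMatcher interface listed in the code graph below; everything else is an assumption):

    def advance_matcher_sketch(matcher, new_token, draft_tokens, bitmask_host, offset):
        """Accept the new token, then speculatively accept draft tokens,
        filling one bitmask row per token that remains guided."""
        num_advanced = 0
        if not matcher.accept_token(new_token):
            raise ValueError(f"Token {new_token} rejected by the grammar")
        num_advanced += 1
        if matcher.is_terminated():
            return num_advanced
        matcher.fill_next_token_bitmask(bitmask_host, offset)
        for i, tid in enumerate(draft_tokens, 1):
            if not matcher.accept_token(tid):
                break
            num_advanced += 1
            if matcher.is_terminated():
                break
            matcher.fill_next_token_bitmask(bitmask_host, offset + i)
        return num_advanced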


284-289: LGTM! Proper bitmask copy implementation.

The bitmask_copy method correctly transfers the host bitmask to device memory using non-blocking copy for performance.
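
A minimal sketch of such a copy (buffer names and shapes are assumptions); the asynchronous behavior of non_blocking=True relies on the host buffer being pinned:

    import torch

    num_words = 4008  # ceil(vocab_size / 32), assumed
    bitmask_host = torch.zeros(4, num_words, dtype=torch.int32, pin_memory=True)
    bitmask = torch.zeros(4, num_words, dtype=torch.int32, device="cuda")

    num_bitmask_tokens = 2  # rows filled for this step (assumed)
    bitmask[:num_bitmask_tokens].copy_(
        bitmask_host[:num_bitmask_tokens], non_blocking=True)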


293-293: LGTM! Execute method signature updated.

The execute method signature properly accepts GuidedMetadata for consistency with the new architecture.


310-312: LGTM! Proper logits masking implementation.

The method correctly applies token bitmask using XGrammar's apply_token_bitmask_inplace with proper token count limiting.
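
A hedged usage sketch of the XGrammar call referenced here (shapes, the bitmask layout, and the row count are assumptions; in the real flow the matcher populates the bitmask via fill_next_token_bitmask before this call):

    import torch
    import xgrammar as xgr

    vocab_size = 128  # assumed
    num_bitmask_words = (vocab_size + 31) // 32

    logits = torch.randn(4, vocab_size, device="cuda")
    # An all-zero mask simply disallows every token, which is fine for a sketch.
    bitmask = torch.zeros(4, num_bitmask_words, dtype=torch.int32, device="cuda")

    # Only the first rows carry valid masks this step; limit the call accordingly.
    num_bitmask_tokens = 2
    xgr.apply_token_bitmask_inplace(
        logits[:num_bitmask_tokens], bitmask[:num_bitmask_tokens])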


318-318: LGTM! Rollback method signature updated.

The rollback_rejected_tokens method signature properly accepts GuidedMetadata for consistency.


328-351: LGTM! Flexible rollback token counting.

The rollback logic properly handles both explicit new_tokens_lens from metadata and fallback to request-specific accepted token counts, providing robustness across different execution modes.
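
The bookkeeping being reviewed reduces to simple arithmetic; a hedged sketch (variable names mirror the log message quoted later in this review, values are made up):

    num_advanced_tokens = 4   # tokens the grammar matcher stepped through (assumed)
    num_accepted_tokens = 2   # tokens the target model actually accepted (assumed)

    # Roll the matcher back over every advanced token that was not accepted.
    num_rollback_tokens = num_advanced_tokens - num_accepted_tokens
    assert 0 <= num_rollback_tokens <= num_advanced_tokens
    # matcher.rollback(num_rollback_tokens)  # GrammarMatcher.rollback, per the code graph below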


391-411: LGTM! Host-callable method wrappers.

The host-callable method wrappers let guided decoding operations be launched onto the CUDA stream as host callbacks, and they delegate appropriately to the main implementation methods.
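
A rough sketch of the wrapper pattern (the hostfunc decorator comes from tensorrt_llm/_torch/hostfunc.py as listed in the code graph below; the class and method bodies are assumptions):

    from tensorrt_llm._torch.hostfunc import hostfunc


    class GuidedDecoderSketch:
        """Illustrative wrapper layout; not the actual class."""

        def _build(self) -> None:
            ...  # fill the host bitmask for the current batch

        @hostfunc
        def build(self) -> None:
            # Thin wrapper: launched onto a CUDA stream as a host callback, so it
            # can be captured and replayed by a CUDA graph along with the model.
            self._build()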

@coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

342-345: Fix the formatting issue

The commented-out lines exceed the maximum line length and should be properly formatted.

Apply this diff to properly format the commented lines:

-            # if spec_config is not None and not has_spec_drafter:
-            #     raise ValueError(
-            #         "Guided decoding is only supported with speculative decoding that has a dedicated drafter (two-model engine)."
-            #     )
+            # if spec_config is not None and not has_spec_drafter:
+            #     raise ValueError(
+            #         "Guided decoding is only supported with speculative decoding that has a "
+            #         "dedicated drafter (two-model engine)."
+            #     )
tensorrt_llm/_torch/speculative/eagle3.py (1)

451-452: Fix the incorrect attribute name from "dt2" to "d2t"

The draft-to-token offset mapping is defined as d2t throughout the codebase, but this code incorrectly checks for dt2.

Apply this diff to fix the attribute name:

-        if (d2t := getattr(draft_model.model, 'dt2', None)):
+        if (d2t := getattr(draft_model.model, 'd2t', None)):
             draft_tokens = d2t[draft_tokens] + draft_tokens
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1)

358-363: Add bounds checking for grammar matcher access

The run method accesses self.grammar_matchers[0] without checking if it exists or if the index is valid.

Apply this diff to add proper error handling:

     @hostfunc
     def run(self, token_ids: torch.Tensor):
+        if token_ids.numel() == 0:
+            logger.warning("Empty token_ids tensor provided to run method")
+            return
+        if self.grammar_matchers[0] is None:
+            logger.warning("Grammar matcher at index 0 is not initialized")
+            return
         self.grammar_matchers[0].accept_token(token_ids[0].item())
         self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0)
         if not hasattr(self, "token_ids"):
             self.token_ids = []
         self.token_ids.append(token_ids[0].item())
🧹 Nitpick comments (4)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

1211-1213: Consider adding error handling for event recording

The token_event.record() operation could potentially fail. Consider adding error handling to ensure the forward pass can continue even if guided decoding setup fails.

Apply this diff to add error handling:

         if self.guided_worker is not None:
-            self.guided_worker.token_event.record()
+            try:
+                self.guided_worker.token_event.record()
+            except Exception as e:
+                logger.warning(f"Failed to record guided worker token event: {e}")
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)

176-178: Add TODO tracking for the "Fix it" comment

The comment "Fix it" on line 176 is vague and should be more descriptive about what needs to be fixed.

Would you like me to help clarify what needs to be fixed here or create an issue to track this TODO?


198-199: Fix line length violations

Lines 198-199 exceed the maximum line length of 120 characters.

Apply this diff to fix the line length:

-                # The last new token must be acceptable unless the matcher is terminated:
-                # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-                # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+                # The last new token must be acceptable unless the matcher is terminated:
+                # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have
+                #    accepted the EOS token in the draft tokens at the previous iteration.
+                # 2. For the draft model loop, the matcher may have accepted the EOS token at the
+                #    previous drafting iteration.

207-207: Fix additional line length violations

Lines 207 and 295 also exceed the maximum line length.

Apply this diff to fix the line lengths:

-                            f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
+                            f"Draft request {req.request_id} at slot {slot} failed to accept "
+                            f"last new token: {req.new_token}."

-                    f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_tokens={num_accepted_tokens}, num_rollback_tokens={num_rollback_tokens}"
+                    f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, "
+                    f"num_accepted_tokens={num_accepted_tokens}, "
+                    f"num_rollback_tokens={num_rollback_tokens}"

Also applies to: 295-295

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 566a849 and 0f65562.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/models/modeling_speculative.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (7 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/models/modeling_speculative.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (6)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (2)
  • GuidedDecoder (113-363)
  • GuidedWorker (371-445)
tensorrt_llm/_torch/models/modeling_utils.py (1)
  • vocab_size_padded (504-505)
tensorrt_llm/_torch/modules/embedding.py (1)
  • vocab_size_padded (76-80)
tensorrt_llm/_torch/models/modeling_speculative.py (1)
  • set_guided_worker (423-426)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
  • set_guided_worker (472-478)
tensorrt_llm/_torch/speculative/eagle3.py (1)
  • set_guided_worker (502-504)
tensorrt_llm/_torch/speculative/eagle3.py (3)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)
  • GuidedWorker (371-445)
  • execute (263-279)
  • execute (433-445)
tensorrt_llm/_torch/models/modeling_speculative.py (1)
  • set_guided_worker (423-426)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
  • set_guided_worker (472-478)
tensorrt_llm/_torch/pyexecutor/model_engine.py (6)
tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py (1)
  • BaseCheckpointLoader (19-87)
tensorrt_llm/_torch/speculative/interface.py (1)
  • SpecMetadata (111-199)
tensorrt_llm/_torch/speculative/utils.py (3)
  • get_num_extra_kv_tokens (166-176)
  • get_spec_metadata (16-59)
  • update_spec_config_from_model_config (179-184)
tensorrt_llm/_torch/speculative/mtp.py (1)
  • SampleStateTensorsMTP (22-24)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)
  • GuidedWorker (371-445)
  • SampleStateTensors (366-368)
  • add_batch (395-406)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
  • SampleStateTensors (31-37)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)
tensorrt_llm/_torch/hostfunc.py (1)
  • hostfunc (19-24)
tensorrt_llm/_torch/pyexecutor/grammar_matcher.py (18)
  • GrammarMatcher (13-30)
  • LLGuidanceMatcherFactory (174-222)
  • XGrammarMatcherFactory (61-128)
  • create (36-38)
  • create (93-128)
  • create (195-222)
  • is_terminated (29-30)
  • is_terminated (57-58)
  • is_terminated (165-166)
  • accept_token (16-17)
  • accept_token (47-48)
  • accept_token (139-150)
  • fill_next_token_bitmask (24-26)
  • fill_next_token_bitmask (53-55)
  • fill_next_token_bitmask (159-163)
  • rollback (20-21)
  • rollback (50-51)
  • rollback (152-157)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • LlmRequest (271-411)
  • get (99-108)
tensorrt_llm/_torch/pyexecutor/scheduler.py (2)
  • ScheduledRequests (18-39)
  • all_requests (38-39)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
  • SampleStateTensors (31-37)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

344-344: Line too long (132 > 120)

(E501)

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

198-198: Line too long (171 > 120)

(E501)


199-199: Line too long (126 > 120)

(E501)


207-207: Line too long (126 > 120)

(E501)


269-269: Line too long (126 > 120)

(E501)


295-295: Line too long (181 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (13)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)

28-28: LGTM!

The new import for GuidedWorker alongside GuidedDecoder is appropriate for the guided decoding framework changes.


347-359: Spec_worker attribute availability across models verified

I confirmed that all speculative-decoding model classes (in tensorrt_llm/_torch/models/modeling_speculative.py) assign self.spec_worker in their constructor via get_spec_worker (around line 360). No other model classes define spec_worker, so the hasattr(model_engine.model, "spec_worker") check in py_executor_creator.py will only be true for those that support speculative decoding, and the fallback to GuidedDecoder correctly covers the rest. No changes required here.

tensorrt_llm/_torch/speculative/eagle3.py (2)

269-269: LGTM!

The addition of the guided_worker attribute with proper typing is well-structured.


502-504: LGTM!

The set_guided_worker method implementation is clean and follows the expected pattern.

tensorrt_llm/_torch/pyexecutor/model_engine.py (4)

44-52: LGTM!

The imports have been properly reorganized to match the new module structure, with the addition of GuidedWorker alongside the existing imports.


403-403: LGTM!

The initialization of guided_worker as an optional attribute is properly typed and follows the existing pattern.


472-478: LGTM!

The set_guided_worker method correctly delegates to the model's method if available and updates the local reference on success.


1624-1626: LGTM!

The guided worker batch addition is properly integrated into the input preparation flow, correctly passing both the scheduled requests and the new tensors when available.

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)

19-74: LGTM! Well-structured dataclass for guided request handling

The GuidedRequest dataclass is well-designed with clear field definitions and logical helper methods for determining when matcher initialization and advancement are required.


76-110: LGTM! Clean batch processing abstraction

The GuidedRequests dataclass provides a clean abstraction for batch processing with useful properties and iteration methods.


366-369: LGTM! Clean protocol definition

The SampleStateTensors protocol provides a clean interface for the expected tensor structure.


371-394: LGTM! Well-structured GuidedWorker initialization

The GuidedWorker class properly extends GuidedDecoder with additional queue-based batch processing and CUDA event synchronization.


433-445: LGTM! Clean event-based synchronization

The execute method properly uses CUDA events for synchronization between the host functions and device operations.
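
A hedged sketch of the event handoff pattern described here (stream and event names are assumptions): the work runs on a side stream after waiting on a token event, and the main stream waits on a bitmask event before masking the logits:

    import torch

    side_stream = torch.cuda.Stream()
    token_event = torch.cuda.Event()
    bitmask_event = torch.cuda.Event()

    logits = torch.randn(4, 128, device="cuda")
    token_event.record()  # on the main stream, once the step's tokens exist

    with torch.cuda.stream(side_stream):
        torch.cuda.current_stream().wait_event(token_event)
        # ... host callbacks build the bitmask and copy it to the device here ...
        bitmask_event.record()

    torch.cuda.current_stream().wait_event(bitmask_event)
    # ... apply the bitmask to `logits` on the main stream ...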

@coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1)

336-349: Wrong method invoked; draft rollback uses rejected-token rollback

rollback_draft_tokens calls _rollback_rejected_tokens instead of _rollback_draft_tokens.

Apply:

-        requests = GuidedRequests.from_scheduled_requests(
-            scheduled_requests, self.max_num_draft_tokens)
-        self._rollback_rejected_tokens(requests)
+        requests = GuidedRequests.from_scheduled_requests(
+            scheduled_requests, self.max_num_draft_tokens)
+        self._rollback_draft_tokens(requests)
♻️ Duplicate comments (4)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (2)

371-377: Harden hostfunc run against empty inputs and matcher state

Apply:

     @hostfunc
     def run(self, token_ids: torch.Tensor):
-        self.grammar_matchers[0].accept_token(token_ids[0].item())
-        self.grammar_matchers[0].fill_next_token_bitmask(self.bitmask_host, 0)
-        if not hasattr(self, "token_ids"):
-            self.token_ids = []
-        self.token_ids.append(token_ids[0].item())
+        if token_ids.numel() == 0:
+            logger.warning("GuidedDecoder.run: empty token_ids")
+            return
+        matcher = self.grammar_matchers[0]
+        if matcher is None:
+            logger.warning("GuidedDecoder.run: matcher[0] is not initialized")
+            return
+        try:
+            matcher.accept_token(token_ids[0].item())
+            matcher.fill_next_token_bitmask(self.bitmask_host, 0)
+        except Exception as e:
+            logger.error(f"GuidedDecoder.run failed: {e}")
+            return
+        self.token_ids = getattr(self, "token_ids", [])
+        self.token_ids.append(token_ids[0].item())

413-426: Bounds and shape checks for next_batch slot indexing

Apply:

     @hostfunc
     def next_batch(self) -> None:
-        # Fix it.
         if self.queue.empty():
             return
         self.requests_hostfunc, has_new_tensors = self.queue.get()
         if not has_new_tensors:
             return

-        for req in self.requests_hostfunc:
+        for req in self.requests_hostfunc:
             if (slot := req.seq_slot) is None:
                 continue
-            req.new_token, *req.draft_tokens = self.new_tokens[:, slot].tolist()
+            if not isinstance(slot, int) or slot < 0 or slot >= self.max_num_sequences:
+                logger.warning(f"next_batch: invalid slot {slot}")
+                continue
+            col = self.new_tokens[:, slot]
+            if col.numel() == 0:
+                logger.warning(f"next_batch: empty new_tokens for slot {slot}")
+                continue
+            values = col.tolist()
+            req.new_token, *req.draft_tokens = values
tensorrt_llm/_torch/speculative/eagle3.py (2)

281-283: Wrap guided-worker execution to avoid forward crashes

Apply:

-        if self.guided_worker is not None:
-            self.guided_worker.execute(logits)
+        if self.guided_worker is not None:
+            try:
+                self.guided_worker.execute(logits)
+            except Exception as e:
+                import traceback
+                traceback.print_exc()
+                # Continue without guidance
+                print(f"GuidedWorker.execute failed; proceeding without guidance: {e}")

465-468: Apply correct d2t mapping in draft_decoder

Same “dt2” typo here; breaks offset mapping.

Apply:

-        # Apply d2t (offsets between draft model dictionary and main model dictionary).
-        if (d2t := getattr(draft_model.model, 'dt2', None)):
+        # Apply d2t (offsets between draft model dictionary and main model dictionary).
+        if (d2t := getattr(draft_model.model, 'd2t', None)):
             draft_tokens = d2t[draft_tokens] + draft_tokens
🧹 Nitpick comments (10)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)

1-1: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

403-404: Document and harden set_guided_worker API

Add a short docstring and clarify the return semantics. Also be explicit about whether None is accepted, to avoid accidental misuse.

Apply:

-    def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
+    def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
+        """Attach a GuidedWorker to the underlying model (if supported) and cache it locally.
+        
+        Args:
+            guided_worker: The worker instance to attach.
+        Returns:
+            True if the underlying model accepted the worker; False otherwise.
+        """
         if hasattr(self.model, "set_guided_worker"):
             success = self.model.set_guided_worker(guided_worker)
             if success:
                 self.guided_worker = guided_worker
             return success
         return False

Also applies to: 472-479


1624-1627: Guard against shape/dtype mismatches when passing new_tokens

new_tokens is copied into a pinned-int32 buffer inside GuidedWorker via squeeze(-1). If callers ever pass unexpected ranks/dtypes, it will fail at runtime.

Consider validating before enqueue:

-        if self.guided_worker is not None:
-            self.guided_worker.add_batch(scheduled_requests,
-                                         new_tokens=new_tokens_device)
+        if self.guided_worker is not None:
+            if new_tokens_device is not None:
+                assert new_tokens_device.dtype in (torch.int32, torch.int64), "new_tokens dtype must be int"
+                assert new_tokens_device.dim() in (2, 3), "new_tokens must be [draft+1, batch] or [draft+1, batch, 1]"
+            self.guided_worker.add_batch(scheduled_requests, new_tokens=new_tokens_device)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (4)

1-1: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

445-455: Add None-guards for rollback hostfuncs

Avoid crashes when called before next_batch populated requests_hostfunc.

Apply:

     def rollback_rejected_tokens(self) -> None:
         if self.max_num_draft_tokens <= 0:
             return
-        self._rollback_rejected_tokens(self.requests_hostfunc)
+        if self.requests_hostfunc is not None:
+            self._rollback_rejected_tokens(self.requests_hostfunc)

     @hostfunc
     def rollback_draft_tokens(self) -> None:
         if self.max_num_draft_tokens <= 0:
             return
-        self._rollback_draft_tokens(self.requests_hostfunc)
+        if self.requests_hostfunc is not None:
+            self._rollback_draft_tokens(self.requests_hostfunc)

456-467: Defensive check in add_draft_batch

Validate that a batch exists before recording the event; helps during warmup or misordered calls.

Apply:

     def add_draft_batch(self,
                         new_tokens: torch.Tensor,
                         num_accepted_tokens: torch.Tensor,
                         is_first_step: bool = False) -> None:
-        batch_size = len(self.requests)
+        if self.requests is None:
+            logger.warning("GuidedWorker.add_draft_batch: no active batch; skipping")
+            return
+        batch_size = len(self.requests)

208-209: Long lines flagged by Ruff (E501)

Several lines exceed 120 chars. Consider wrapping to satisfy style checks.

Locations:

  • Line 208
  • Line 217
  • Line 279
  • Line 304

Also applies to: 217-217, 279-279, 304-304

tensorrt_llm/_torch/speculative/eagle3.py (3)

1-1: Add NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to all source files.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

322-327: Guard add_draft_batch against runtime failures

Match the defensive pattern used for execute to prevent guided path failures from aborting drafting.

Apply:

-            if self.guided_worker is not None:
-                new_tokens = inputs["input_ids"][gather_ids]
-                self.guided_worker.add_draft_batch(new_tokens,
-                                                   num_accepted_tokens,
-                                                   is_first_step=(i == 0))
+            if self.guided_worker is not None:
+                try:
+                    new_tokens = inputs["input_ids"][gather_ids]
+                    self.guided_worker.add_draft_batch(new_tokens,
+                                                       num_accepted_tokens,
+                                                       is_first_step=(i == 0))
+                except Exception as e:
+                    import traceback
+                    traceback.print_exc()
+                    print(f"GuidedWorker.add_draft_batch failed; skipping this step: {e}")

517-519: Add docstring to set_guided_worker

Public method should document behavior and return value.

Apply:

     def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
+        """Attach a GuidedWorker for guided decoding; always returns True for Eagle3 one-model."""
         self.guided_worker = guided_worker
         return True
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 0f65562 and bd852ae.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py (6 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (7 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py

208-208: Line too long (171 > 120)

(E501)


209-209: Line too long (126 > 120)

(E501)


217-217: Line too long (126 > 120)

(E501)


279-279: Line too long (126 > 120)

(E501)


304-304: Line too long (181 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

44-64: Import refactor looks good

The new relative imports and surfaced symbols (BaseCheckpointLoader, SpecMetadata helpers, GuidedWorker, SampleStateTensors) are consistent with the module layout and improve dependency locality.

tensorrt_llm/_torch/speculative/eagle3.py (1)

269-270: Initialize guided_worker to None is fine

The attribute introduction is clean and matches the new GuidedWorker API.

@coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/speculative/eagle3.py (2)

1-1: Add required NVIDIA copyright header (2025).

Per the coding guidelines, all source files must prepend the NVIDIA copyright header. This file is missing it.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

200-205: Fix out-of-bounds capture when num_layers == 1.

For one-layer models you set layers_to_capture = (1,) which is out-of-range (valid index is 0). This will make maybe_capture_hidden_states silently skip capture or cause issues depending on caller expectations.

Apply this diff:

-        if self.num_layers == 1:
-            self.layers_to_capture = (1, )
+        if self.num_layers == 1:
+            self.layers_to_capture = (0, )
♻️ Duplicate comments (3)
tensorrt_llm/_torch/speculative/eagle3.py (3)

281-283: Guard guided path: wrap guided_worker.execute in try/except to avoid crashing forward pass.

Uncaught exceptions here would take down the whole iteration. This mirrors prior feedback on guarding guided-decoder paths.

Apply this diff:

-        if self.guided_worker is not None:
-            self.guided_worker.execute(logits)
+        if self.guided_worker is not None:
+            try:
+                self.guided_worker.execute(logits)
+            except Exception:
+                # Continue without guided decoding; keep raw logits path alive
+                self.logger.exception("Guided worker execute() failed; continuing without guidance")

322-327: Guard guided_worker.add_draft_batch to prevent stateful guided failures from aborting drafting.

State errors (e.g., batch shape mismatches) can be raised here; handle gracefully and keep generation going.

Apply this diff:

-            if self.guided_worker is not None:
-                new_tokens = inputs["input_ids"][gather_ids]
-                self.guided_worker.add_draft_batch(new_tokens,
-                                                   num_accepted_tokens,
-                                                   is_first_step=(i == 0))
+            if self.guided_worker is not None:
+                new_tokens = inputs["input_ids"][gather_ids]
+                try:
+                    self.guided_worker.add_draft_batch(
+                        new_tokens,
+                        num_accepted_tokens,
+                        is_first_step=(i == 0),
+                    )
+                except Exception:
+                    self.logger.exception("Guided worker add_draft_batch() failed; ignoring this guided step")

340-347: Guard guided_worker.execute_draft_batch and confirm attribute name.

Good to see the d2t attribute used (fixing prior dt2 typo elsewhere). Please also guard this call to avoid guided failures aborting the loop.

Apply this diff:

-            if self.guided_worker is not None:
-                d2t = getattr(draft_model.model, "d2t", None)
-                self.guided_worker.execute_draft_batch(
-                    logits,
-                    d2t,
-                    is_first_step=(i == 0),
-                    is_last_step=(i == self.max_draft_len - 1))
+            if self.guided_worker is not None:
+                d2t = getattr(draft_model.model, "d2t", None)
+                try:
+                    self.guided_worker.execute_draft_batch(
+                        logits,
+                        d2t,
+                        is_first_step=(i == 0),
+                        is_last_step=(i == self.max_draft_len - 1),
+                    )
+                except Exception:
+                    self.logger.exception("Guided worker execute_draft_batch() failed; skipping guided draft execution for this step")
🧹 Nitpick comments (4)
tensorrt_llm/_torch/speculative/eagle3.py (4)

10-10: Prefer module-namespace import; also add logging import for upcoming error handling.

Guidelines ask to maintain module namespace in imports. Import the module and reference GuidedWorker via the module. We'll also need logging for the error-handling changes below.

Apply this diff:

-from ..pyexecutor.guided_decoder import GuidedWorker
+import logging
+from ..pyexecutor import guided_decoder

And update annotations/usages accordingly (see other diffs in this review).


264-271: Initialize a class logger; annotate guided_worker via module namespace.

We'll need a logger for the guarded guided-decoding calls. Also align the type annotation to module-namespace import.

Apply this diff:

     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
         self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_worker: Optional[GuidedWorker] = None
+        # Local logger for error reporting in guided-decoding path
+        self.logger = logging.getLogger(__name__)
+        self.guided_worker: Optional[guided_decoder.GuidedWorker] = None

466-468: Ensure d2t device matches draft_tokens before indexing (avoid device mismatch).

If draft_model.model.d2t is not registered as a CUDA buffer on the same device, d2t[draft_tokens] will error. Either guarantee d2t is a buffer on the model or guard locally.

Apply this diff if you cannot guarantee buffer placement:

-        if (d2t := getattr(draft_model.model, "d2t", None)) is not None:
-            draft_tokens = d2t[draft_tokens] + draft_tokens
+        if (d2t := getattr(draft_model.model, "d2t", None)) is not None:
+            # Ensure device alignment for indexing
+            d2t_local = d2t.to(draft_tokens.device, non_blocking=True) if getattr(d2t, "device", None) != draft_tokens.device else d2t
+            draft_tokens = d2t_local[draft_tokens] + draft_tokens

Follow-up:

  • Prefer registering d2t as a buffer on draft_model.model so it moves with .to(device) and avoids per-step copies.

516-519: Docstring the API and align type with module-namespace import.

Provide a concise docstring and use the module namespace in the type annotation for consistency with import style.

Apply this diff:

-    def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
-        self.guided_worker = guided_worker
-        return True
+    def set_guided_worker(self, guided_worker: guided_decoder.GuidedWorker) -> bool:
+        """Attach a guided-decoding worker instance.
+
+        Args:
+            guided_worker: guided_decoder.GuidedWorker used to guide decoding.
+
+        Returns:
+            True if the worker was attached.
+        """
+        self.guided_worker = guided_worker
+        return True
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between bd852ae and 827b747.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/speculative/eagle3.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/speculative/eagle3.py (1)

330-336: LGTM: explicit toggle of use_spec_decoding for nth steps.

The intent and FIXME are clear; and use_spec_decoding is restored to True after the loop (Line 395). No action needed here.

@coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

1758-1781: No-cache path: enqueue a guided batch and record the token event.

When kv_cache_manager is None, we still need to keep the guided path consistent so downstream execute paths don't see empty/None state.

Apply this diff just before the return:

-        return inputs, None
+        # Keep guided path consistent in no-cache mode
+        if self.guided_worker is not None:
+            # No new_tokens available in no-cache path; enqueue a placeholder batch.
+            self.guided_worker.add_batch(scheduled_requests, new_tokens=None)
+            self.guided_worker.token_event.record()
+        return inputs, None
♻️ Duplicate comments (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

1211-1213: Do not record token_event in _preprocess_inputs; it races with add_batch enqueue.

Recording here can happen before any H2D/D2H work is enqueued for the batch, leading to out-of-order dependencies or later dereferencing of uninitialized guided state. This was flagged earlier and remains unresolved.

Apply this diff to remove premature recording:

-        if self.guided_worker is not None:
-            self.guided_worker.token_event.record()

Then, record immediately after guided_worker.add_batch(...) in _prepare_tp_inputs and add a guarded enqueue+record in the no-cache path (see next comments for diffs).

Run to verify ordering after applying the diffs:

#!/bin/bash
# Expect: no record() in _preprocess_inputs; exactly two record() call sites:
# 1) after add_batch in _prepare_tp_inputs
# 2) in _prepare_tp_inputs_no_cache before return
rg -nP 'guided_worker\.token_event\.record\(' tensorrt_llm/_torch/pyexecutor/model_engine.py -n -C2

# Confirm add_batch call sites
rg -nP 'guided_worker\.add_batch\s*\(' tensorrt_llm/_torch/pyexecutor/model_engine.py -n -C3

1623-1626: Record token_event immediately after enqueueing the guided batch to preserve dependency ordering.

This ensures the event reflects the actual submission of the guided inputs.

Apply this diff:

         if self.guided_worker is not None:
             self.guided_worker.add_batch(scheduled_requests,
                                          new_tokens=new_tokens_device)
+            # Record after enqueue to preserve proper dependency ordering.
+            self.guided_worker.token_event.record()
🧹 Nitpick comments (4)
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)

44-53: Imports look consistent with the refactor; consider type-only import for GuidedWorker to avoid runtime deps/cycles.

  • The added imports for checkpoints, speculative metadata, and sampler types look correct.
  • For GuidedWorker, prefer a type-only import to prevent import-time side effects or circular dependencies when this module is imported.

Apply this diff within the selected range to gate the GuidedWorker import:

-from .guided_decoder import GuidedWorker
+# Avoid importing GuidedWorker at runtime to reduce import-time deps and cycles
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from .guided_decoder import GuidedWorker

Additionally, update type annotations at their use sites (examples shown in separate comments) to use forward references: 'GuidedWorker'.

Also applies to: 58-58, 63-63


1-2: Missing NVIDIA copyright header (2025).

Per the coding guidelines, prepend the current-year NVIDIA copyright header.

Add the standard header at the top of the file (outside the changed lines), for example:

# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

403-404: Initialize attribute with forward-referenced type and document intent.

Use a forward reference to align with the type-only import above and add a short attribute docstring.

Apply this diff:

-        self.guided_worker: Optional[GuidedWorker] = None
+        self.guided_worker: Optional['GuidedWorker'] = None  # Guided decoding controller; set via set_guided_worker()

472-479: Public API: add a concise docstring and use a forward reference in the signature.

The method is part of the external surface for enabling guided decoding; add a docstring for discoverability and keep type import lazy.

Apply this diff:

-    def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
+    def set_guided_worker(self, guided_worker: 'GuidedWorker') -> bool:
+        """Attach a GuidedWorker to this engine and propagate it to the model if supported.
+
+        Returns:
+            bool: True if the model accepted the worker and the engine stored it; False otherwise.
+        """
         if hasattr(self.model, "set_guided_worker"):
             success = self.model.set_guided_worker(guided_worker)
             if success:
                 self.guided_worker = guided_worker
             return success
         return False
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 827b747 and 8f071e2.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (7 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (5)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)

106-113: Potential index error on empty batches.

Accessing self.requests[0] will raise an IndexError if the batch is empty. Add a check for empty requests before accessing the first element.

 @property
 def num_bitmask_tokens(self) -> int:
+    if not self.requests:
+        return 0
     if self.requests[0].is_draft:
         return len(self.requests)
     else:
         return self.num_contexts + self.num_generations * (
             self.max_num_draft_tokens + 1)

194-253: Incorrect advancement accounting and unsafe draft token access.

The code has two issues:

  1. num_advanced_tokens is incremented unconditionally even when only initializing the matcher (line 233)
  2. req.draft_tokens may be None but is used without checking (line 238)

These issues can lead to incorrect rollback calculations and potential AttributeError.

             if matcher_advance:
                 matcher = self.grammar_matchers[slot]
                 # The last new token must be acceptable unless the matcher is terminated:
                 # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
                 # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
                 if matcher.is_terminated() or self.is_draft_terminated[slot]:
                     continue
                 accepted = matcher.accept_token(req.new_token)
                 if not accepted:
                     if req.is_draft:
                         self.is_draft_terminated[slot] = True
                         logger.debug(
                             f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
                         )
                         continue
                     # TODO: Make this an error response.
                     raise ValueError(
                         f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
                     )
+                # Only count actually accepted tokens
+                self.num_advanced_tokens[slot] += 1
 
-            self.num_advanced_tokens[slot] += 1
             if not matcher.is_terminated():
                 matcher.fill_next_token_bitmask(self.bitmask_host, offset)
                 self.num_guided_tokens[slot] += 1
                 # Process draft tokens
-                for i, tid in enumerate(req.draft_tokens, 1):
+                for i, tid in enumerate((req.draft_tokens or []), 1):
                     accepted = matcher.accept_token(tid)
                     if not accepted:
                         break
                     self.num_advanced_tokens[slot] += 1
                     if matcher.is_terminated():
                         break
                     matcher.fill_next_token_bitmask(self.bitmask_host,
                                                     offset + i)
                     self.num_guided_tokens[slot] += 1

433-435: Add bounds checking for tensor access.

The fetch_batch method accesses tensors by slot index without validating that the slot is within bounds of the tensor dimensions.

         for req in self.requests_hostfunc:
             if (slot := req.seq_slot) is None:
                 continue
+            if slot >= self.max_num_sequences:
+                logger.warning(f"Slot {slot} exceeds max_num_sequences {self.max_num_sequences}")
+                continue
             req.new_token, *req.draft_tokens = self.new_tokens[:, slot].tolist()

441-452: Guard execute() when no batch is enqueued.

When add_batch hasn't been called or if the queue is empty, self.requests_hostfunc could be None, leading to a potential AttributeError in _build. Add an early return to handle this case gracefully.

 def execute(self,
             logits: torch.Tensor,
             d2t: Optional[torch.Tensor] = None) -> None:
+    if self.requests_hostfunc is None:
+        logger.debug("GuidedWorker.execute: no requests enqueued; skipping guidance")
+        return
     with torch.cuda.stream(self._stream):
         torch.cuda.current_stream().wait_event(self.token_event)
         self.fetch_batch()
         self.build()
         self.copy_bitmask()
         self.bitmask_event.record()
 
     torch.cuda.current_stream().wait_event(self.bitmask_event)
     self.apply_bitmask(logits, d2t=d2t)

249-252: Invalid assertion for draft requests.

The assertion that draft tokens must be empty for draft requests is incorrect. Draft requests can have draft tokens from previous iterations that need to be processed.

             if req.is_draft:
-                assert len(req.draft_tokens) == 0
                 self.num_advanced_draft_tokens[
                     slot] += self.num_advanced_tokens[slot]
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 8f071e2 and 1d2b09b.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py (7 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py (3 hunks)
  • tensorrt_llm/_torch/speculative/model_drafter.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (6)
  • add_batch (285-287)
  • add_batch (412-421)
  • execute (303-312)
  • execute (441-452)
  • rollback_draft_tokens (366-367)
  • rollback_draft_tokens (459-460)
tensorrt_llm/_torch/pyexecutor/py_executor.py (4)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (7)
  • add_batch (285-287)
  • add_batch (412-421)
  • init_disagg_gen_requests (384-385)
  • execute (303-312)
  • execute (441-452)
  • rollback_rejected_tokens (362-363)
  • rollback_rejected_tokens (455-456)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
  • request_context (114-138)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
  • prepare_draft_tokens (324-410)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
  • _forward_step (2257-2278)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)
tensorrt_llm/sampling_params.py (1)
  • GuidedDecodingParams (14-36)
tensorrt_llm/_torch/hostfunc.py (1)
  • hostfunc (19-24)
tensorrt_llm/_torch/pyexecutor/grammar_matcher.py (16)
  • GrammarMatcher (13-30)
  • create (36-38)
  • create (93-128)
  • create (195-222)
  • is_terminated (29-30)
  • is_terminated (57-58)
  • is_terminated (165-166)
  • accept_token (16-17)
  • accept_token (47-48)
  • accept_token (139-150)
  • fill_next_token_bitmask (24-26)
  • fill_next_token_bitmask (53-55)
  • fill_next_token_bitmask (159-163)
  • rollback (20-21)
  • rollback (50-51)
  • rollback (152-157)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • LlmRequest (271-411)
  • get (99-108)
tensorrt_llm/_torch/pyexecutor/scheduler.py (3)
  • ScheduledRequests (18-39)
  • all_requests (38-39)
  • batch_size (35-36)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py

216-216: Line too long (171 > 120)

(E501)


217-217: Line too long (126 > 120)

(E501)


225-225: Line too long (126 > 120)

(E501)


336-336: Line too long (181 > 120)

(E501)

🔇 Additional comments (8)
tensorrt_llm/_torch/speculative/model_drafter.py (3)

357-360: LGTM! Batch-oriented guided decoding integration looks good.

The changes properly integrate the new batch-oriented guided decoding API by calling add_batch followed by execute after the initial forward pass, which aligns with the new GuidedDecoder interface design.


380-383: LGTM! Consistent batch-oriented guided decoding in iteration loop.

The implementation correctly applies the same batch-oriented pattern (add_batch + execute) within the draft token generation loop, maintaining consistency with the new API.


403-404: LGTM! Proper rollback handling with new API.

The rollback logic correctly uses the new batch-oriented API by first calling add_batch with the scheduled requests, then invoking the no-argument rollback_draft_tokens() method. This aligns with the updated GuidedDecoder interface where rollback operates on internally stored requests.
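
Taken together, these three comments describe one call pattern. The sketch below abstracts it with a hypothetical engine object and loop structure; only the guided-decoder calls (add_batch, execute, rollback_draft_tokens) come from the reviewed interface.

    def draft_with_guidance(guided_decoder, scheduled_requests, engine, num_draft_steps):
        # Initial forward pass, then register the batch and mask the first sampling step.
        logits = engine.forward(scheduled_requests)
        guided_decoder.add_batch(scheduled_requests)
        guided_decoder.execute(logits)

        # Same add_batch + execute pattern inside the draft-token generation loop.
        for _ in range(num_draft_steps):
            logits = engine.forward(scheduled_requests)
            guided_decoder.add_batch(scheduled_requests)
            guided_decoder.execute(logits)

        # Draft tokens advanced the matchers speculatively; register the batch and
        # roll the matchers back before target-model verification.
        guided_decoder.add_batch(scheduled_requests)
        guided_decoder.rollback_draft_tokens()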

tensorrt_llm/_torch/pyexecutor/py_executor.py (5)

744-750: LGTM! Proper batch-oriented guided decoding integration in PP mode.

The implementation correctly integrates guided decoding for pipeline parallelism by calling add_batch followed by the conditional init_disagg_gen_requests() for disaggregated serving, then executing with the logits. This maintains the proper sequencing for the new batch-oriented API.
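
As a sketch of that sequencing (the function signature and flag name are hypothetical; the three guided-decoder calls are the ones referenced above):

    def guided_step_pp(guided_decoder, scheduled_requests, logits, is_disagg_gen_init: bool):
        guided_decoder.add_batch(scheduled_requests)
        if is_disagg_gen_init:
            # Disaggregated serving: initialize the gen-side requests' matchers
            # before the first masked sampling step.
            guided_decoder.init_disagg_gen_requests()
        guided_decoder.execute(logits)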


934-937: LGTM! Consistent guided decoding setup in main executor loop.

The integration properly handles the batch-oriented guided decoding setup before draft token preparation, maintaining the same pattern as in the PP executor loop.


944-944: LGTM! Proper rollback without arguments.

The rollback call correctly uses the new no-argument signature, operating on the internally stored requests from the prior add_batch call.


949-950: LGTM! Clean separation of guided decoding execution.

The execute call properly passes only the logits tensor, following the new simplified interface where batch information is already stored internally from the add_batch call.


1065-1069: LGTM! Consistent guided decoding in overlap scheduler.

The overlap scheduler correctly implements the same batch-oriented guided decoding pattern, maintaining consistency across all executor loops.

@syuoni syuoni force-pushed the guided-with-spec-part2 branch from 1d2b09b to 7111c5a on August 22, 2025 08:20
@syuoni syuoni marked this pull request as ready for review August 28, 2025 15:21
@syuoni syuoni requested review from a team as code owners August 28, 2025 15:21
@syuoni
Collaborator Author

syuoni commented Sep 2, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17274 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17274 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12980 completed with status: 'FAILURE'

@syuoni
Collaborator Author

syuoni commented Sep 2, 2025

/bot run

@tensorrt-cicd
Collaborator

PR_Github #17328 [ run ] triggered by Bot

Collaborator

@Linda-Stadter Linda-Stadter left a comment

I looked at the GIL releases and found two cases where releasing the GIL is not safe. Let me know what you think.

@tensorrt-cicd
Collaborator

PR_Github #17328 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13022 completed with status: 'FAILURE'

Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
@syuoni
Collaborator Author

syuoni commented Sep 3, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17438 [ run ] triggered by Bot

Collaborator

@QiJune QiJune left a comment

LGTM

@tensorrt-cicd
Collaborator

PR_Github #17438 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13106 completed with status: 'FAILURE'

@syuoni
Collaborator Author

syuoni commented Sep 3, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17490 [ run ] triggered by Bot

@syuoni syuoni enabled auto-merge (squash) September 3, 2025 13:38
@tensorrt-cicd
Collaborator

PR_Github #17490 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13147 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@syuoni syuoni merged commit 5ff3a65 into NVIDIA:main Sep 3, 2025
5 checks passed
greg-kwasniewski1 pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 4, 2025
@coderabbitai coderabbitai bot mentioned this pull request Sep 20, 2025