@kylesayrs kylesayrs commented Aug 26, 2025

Purpose

  • Enable transforms applied to attention, as required for the R3 rotations of SpinQuant

Prerequisites

Changes

  • Add hookable attention and KV cache implementations which are registered to the attention module as submodules
    • QuantizedAttentionImpl injects itself into the model by registering a new attention implementation called ct_hooked_attention and overriding model.config._attn_implementation to the new implementation name
    • QuantizedKVCache injects itself into the model by overriding the past_key_values input kwarg to attention and wrapping the functionality of the original cache
    • Calibration and transform hooks can be added to these modules via the hook functions
      • register_query_hook
      • register_key_hook
      • register_value_hook
    • These modules are responsible for initializing quantization parameters on the parent attention module (used in the next PR)
  • Implement transform hooks for Q_ATTN and K_CACHE locations
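The hook mechanism above can be sketched as follows. This is a minimal illustration of the pattern, not the PR's actual implementation: HookedQueryImpl and the hook signature (module, states) are assumptions for the sketch, with register_query_hook mirroring the hook function named in the changes.

```python
import torch
from torch import nn


# Hypothetical sketch: a small module attached to attention as a submodule;
# calibration/transform hooks are ordinary torch forward pre-hooks on it.
class HookedQueryImpl(nn.Module):
    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Hooks registered on this module see (and may replace) `query`
        return query


def register_query_hook(impl: HookedQueryImpl, hook):
    # Wrap the user hook as a forward pre-hook that may rewrite the input
    def _pre_hook(module, args):
        ret = hook(module, args[0])
        return (ret,) if ret is not None else None

    return impl.register_forward_pre_hook(_pre_hook)


impl = HookedQueryImpl()
# Example transform hook: an R3-style orthogonal rotation (identity here)
rotation = torch.eye(4)
register_query_hook(impl, lambda mod, q: q @ rotation)

q = torch.randn(2, 4)
out = impl(q)
assert torch.equal(out, q)  # identity rotation leaves query unchanged
```

The returned handle follows the usual torch hook convention, so a transform can be detached again with `handle.remove()`.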

Testing

  • Added test_correctness_attention_heads test, which simulates R3
  • Tested with R3 using this LC branch

@brian-dellabetta brian-dellabetta left a comment


This looks good, though I have a number of questions and minor suggestions.


def __init__(self, attn_module: Module):
super().__init__()
self.attn_module_container = [attn_module] # avoid circular reference
Contributor:

avoid circular reference by placing in a list? Can a weakref be used here?

Contributor Author:

Sure, but to be clear, the circular reference is a module circular reference (i.e., a module cannot be the child of its child), not a garbage-collection one.
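For readers unfamiliar with the distinction: assigning an nn.Module attribute registers it as a submodule, so a child module holding its parent directly would create a cycle in the module tree. A plain list bypasses that registration. A minimal sketch (Child and parent_container are illustrative names, not the PR's code):

```python
import torch
from torch import nn


class Child(nn.Module):
    def __init__(self, parent: nn.Module):
        super().__init__()
        # Direct assignment (self.parent = parent) would register `parent`
        # as a submodule of Child, creating parent -> child -> parent -> ...
        # Wrapping in a plain list sidesteps nn.Module.__setattr__, so the
        # reference is held without submodule registration.
        self.parent_container = [parent]


parent = nn.Linear(2, 2)
parent.child = Child(parent)  # child is a submodule of parent

# The module tree stays acyclic: the back-reference is not registered
names = [name for name, _ in parent.named_modules()]
assert names == ["", "child"]
```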

quant_args = getattr_chain(module, quant_args_attr, None)
quant_enabled = getattr(module, "quantization_enabled", True)
if quant_args is not None and quant_enabled and self._qparams_initialized:
query = forward_quantize(module, query, "q", quant_args)
Contributor:

Why is only query quantized and not key & value?

Contributor:

Never mind, I see the KV quantization implementation below. So QuantizedAttention only refers to the quantized query, and QuantizedKVCache always refers to the key/value states?

Contributor Author:

Yep! I can make a note, but this guards against misuse. I don't think there's any reason to add a quantization hook to attention but not to the KV cache.

def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
"""
Initialize attention quantization parameters if they have not already been
intialized. KV cache quantization parameters are initialized by the
Contributor:

Suggested change:
- intialized. KV cache quantization parameters are initialized by the
+ initialized. KV cache quantization parameters are initialized by the

Comment on lines +146 to +147
# assumes only one model at a time
global _original_impl
Contributor:

😬 I don't want to delay things, but we should briefly consider whether there are alternative solutions.

Contributor Author:

I spent 20 minutes exploring this; it requires creating specialized _ct_hooked_attention functions and a specialized QuantizedAttentionImpl, which is more complexity than value added, imho.


def __init__(self, attn_module: Module):
super().__init__()
self.attn_module_container = [attn_module] # avoid circular reference
Contributor:

same here, weakref?

return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(
Contributor:

There's a lot of equivalent code in register_key_hook and register_value_hook; can they call into the same logic with an id string that is either value_states or key_states?

Contributor Author:

The logic of "create a kwarg hook which uses the signature of the module child, but calls the hook with the parent module" is pretty specific to these use cases, and is a hard function to name 🙃.

I think it's better to be explicit in this case.
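For context, the pattern under discussion might be sketched roughly as below. DummyCache, the value_states kwarg name, and the exact signature of register_value_hook are assumptions for illustration, not the PR's actual code: a kwarg-aware forward pre-hook intercepts a named kwarg of the child module but calls the user hook with the parent attention module.

```python
import torch
from torch import nn


# Hypothetical stand-in for the wrapped cache module
class DummyCache(nn.Module):
    def forward(self, value_states: torch.Tensor) -> torch.Tensor:
        return value_states


def register_value_hook(attn_module: nn.Module, kv_cache: nn.Module, hook):
    # Forward pre-hook with kwargs: matches the child's signature
    # (value_states), but invokes the user hook with the parent module
    def _hook(cache, args, kwargs):
        if "value_states" in kwargs:
            ret = hook(attn_module, kwargs["value_states"])
            if ret is not None:
                kwargs["value_states"] = ret
        return args, kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


attn = nn.Identity()
cache = DummyCache()
register_value_hook(attn, cache, lambda mod, v: v + 1)
out = cache(value_states=torch.zeros(3))
assert torch.equal(out, torch.ones(3))
```

A register_key_hook twin would differ only in the kwarg name it intercepts, which is the duplication the thread is weighing against explicitness.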

@dsikka dsikka left a comment


If the goal is to use this generally for kv_cache and attention quantization, can we move initialize_hooked_attention and initialize_hooked_kv_cache to initialize.py?

I understand we haven't hooked them in yet for those workflows, but I think these belong there.

Signed-off-by: Kyle Sayers <[email protected]>
dsikka previously approved these changes Sep 2, 2025
@dsikka dsikka left a comment


Do a pass through on any missing docstrings; otherwise LGTM.
Nice work!

Base automatically changed from kylesayrs/transform-simplify-key to main September 8, 2025 18:46
@dsikka dsikka dismissed stale reviews from brian-dellabetta and themself September 8, 2025 18:46

The base branch was changed.
