[Transform] Attention/Cache transforms #436
Conversation
This looks good, though I have a number of questions and minor suggestions.
def __init__(self, attn_module: Module):
    super().__init__()
    self.attn_module_container = [attn_module]  # avoid circular reference
Avoid circular reference by placing it in a list? Can a weakref be used here?
Sure, but to be clear, the circular reference is a module circular reference (i.e., a module cannot be the child of its own child), not a garbage collection issue.
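For context, a minimal sketch of the pattern (the Wrapper name is illustrative): PyTorch's Module.__setattr__ registers any Module-valued attribute as a child, so hiding the reference in a plain list keeps it out of the module tree.

from torch.nn import Linear, Module

class Wrapper(Module):
    # illustrative stand-in for QuantizedAttentionImpl / QuantizedKVCache
    def __init__(self, attn_module: Module):
        super().__init__()
        # self.attn_module = attn_module would register attn_module as a child
        # of this wrapper; since the wrapper is later attached onto attn_module
        # itself, that would create a cycle in the module tree. A plain list
        # holds the reference without registering it.
        self.attn_module_container = [attn_module]

attn = Linear(8, 8)        # stand-in for an attention module
attn.impl = Wrapper(attn)  # attach the wrapper onto its own target
assert list(attn.impl.children()) == []  # attn is not a submodule of the wrapper
assert "impl" in dict(attn.named_children())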
quant_args = getattr_chain(module, quant_args_attr, None)
quant_enabled = getattr(module, "quantization_enabled", True)
if quant_args is not None and quant_enabled and self._qparams_initialized:
    query = forward_quantize(module, query, "q", quant_args)
Why is only the query quantized, and not key & value?
Never mind, I see the KV quantization implementation below. So QuantizedAttentionImpl only refers to the quantized query, and QuantizedKVCache always refers to the key/value states?
Yep! I can make a note, but this guards against misuse. I don't think there's any reason to add a quantization hook to attention but not to the KV cache.
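To make the division of labor concrete, a rough sketch (not the PR's code; fake_quantize is a placeholder for forward_quantize with the module's quantization args) of how the wrapped cache could own key/value quantization while the attention impl only touches the query:

from typing import Any
from torch import Tensor

def fake_quantize(x: Tensor, scale: float = 0.1) -> Tensor:
    # placeholder for forward_quantize(module, x, "k"/"v", quant_args)
    return (x / scale).round().clamp(-128, 127) * scale

class WrappedKVCache:
    # illustrative stand-in for QuantizedKVCache: wraps the original cache
    # passed through the past_key_values kwarg
    def __init__(self, original_cache: Any):
        self.original_cache = original_cache

    def update(self, key_states: Tensor, value_states: Tensor, *args, **kwargs):
        # key/value are quantized here, so the attention impl never needs to
        key_states = fake_quantize(key_states)
        value_states = fake_quantize(value_states)
        return self.original_cache.update(key_states, value_states, *args, **kwargs)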
def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
    """
    Initialize attention quantization parameters if they have not already been
    intialized. KV cache quantization parameters are initialized by the
Suggested change:
- intialized. KV cache quantization parameters are initialized by the
+ initialized. KV cache quantization parameters are initialized by the
# assumes only one model at a time
global _original_impl
😬 I don't want to delay things, but we should briefly consider whether there are alternative solutions.
I spent 20 minutes exploring this; it requires creating specialized _ct_hooked_attention functions and a specialized QuantizedAttentionImpl, which is more complexity than value added, IMHO.
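For reference, the single-model pattern under discussion looks roughly like this sketch (the enable/disable function names are illustrative; ct_hooked_attention and model.config._attn_implementation come from the PR):

from typing import Optional
from transformers import PreTrainedModel

_original_impl: Optional[str] = None  # assumes only one model at a time

def enable_hooked_attention(model: PreTrainedModel) -> None:
    # remember the original implementation so the hooked one can dispatch to it
    global _original_impl
    _original_impl = model.config._attn_implementation
    model.config._attn_implementation = "ct_hooked_attention"

def disable_hooked_attention(model: PreTrainedModel) -> None:
    global _original_impl
    if _original_impl is not None:
        model.config._attn_implementation = _original_impl
        _original_impl = None

A per-model variant would stash the original implementation name on the model or its config rather than in a global, but, per the comment above, the shared hooked-attention function would then need to be specialized per instance.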
def __init__(self, attn_module: Module):
    super().__init__()
    self.attn_module_container = [attn_module]  # avoid circular reference
Same question here: weakref?
return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(
There's a lot of equivalent code in register_key_hook and register_value_hook; can they call into the same logic with an id string that is either value_states or key_states?
The logic of "create a kwarg hook which uses the signature of the module child, but calls the hook with the parent module" is pretty specific to these use cases, and is a hard function to name 🙃.
I think it's better to be explicit in this case.
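For concreteness, the suggested factoring might look something like this sketch (the helper name and hook plumbing are illustrative; only the keyword path is handled here, whereas the real code works with the child's signature):

from typing import Callable, Optional
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

def _register_state_hook(
    module: Module,   # parent attention module the user hook receives
    kv_cache: Module,
    user_hook: Callable[[Module, Tensor], Optional[Tensor]],
    kwarg_name: str,  # "key_states" or "value_states"
) -> RemovableHandle:
    def _hook(cache: Module, args, kwargs):
        # assumes the state arrives by keyword; a real implementation would
        # also resolve positional arguments from the child's signature
        if kwarg_name in kwargs:
            new_state = user_hook(module, kwargs[kwarg_name])
            if new_state is not None:
                kwargs[kwarg_name] = new_state
        return args, kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)

def register_key_hook(module: Module, kv_cache: Module, hook) -> RemovableHandle:
    return _register_state_hook(module, kv_cache, hook, "key_states")

def register_value_hook(module: Module, kv_cache: Module, hook) -> RemovableHandle:
    return _register_state_hook(module, kv_cache, hook, "value_states")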
If the goal is to use this generally for kv_cache and attention quantization, can we move initialize_hooked_attention and initialize_hooked_kv_cache to initialize.py? I understand we haven't hooked them in yet for those workflows, but I think they belong there.
Do a pass through for any missing docstrings; otherwise LGTM. Nice work!
The base branch was changed.
Purpose

Prerequisites

Changes
- QuantizedAttentionImpl injects itself into the model by registering a new attention implementation called ct_hooked_attention and overriding model.config._attn_implementation to be the new implementation name (see the sketch below)
- QuantizedKVCache injects itself into the model by overriding the past_key_values input kwarg to attention and wrapping the functionality of the original cache
- register_query_hook, register_key_hook, and register_value_hook
- Q_ATTN and K_CACHE locations

Testing
- test_correctness_attention_heads test, which simulates R3
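A rough sketch of the registration mechanism described in the first Changes bullet, assuming the dict-style attention registry in recent transformers releases ("sdpa" stands in for the saved original implementation, and the hook body is a placeholder):

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def ct_hooked_attention(module, query, key, value, *args, **kwargs):
    # placeholder body: quantize/transform the query here, then dispatch to
    # the original implementation remembered at injection time
    original = ALL_ATTENTION_FUNCTIONS["sdpa"]  # stand-in for _original_impl
    return original(module, query, key, value, *args, **kwargs)

# register the new implementation and point the model at it
ALL_ATTENTION_FUNCTIONS["ct_hooked_attention"] = ct_hooked_attention
# model.config._attn_implementation = "ct_hooked_attention"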