Skip to content
185 changes: 185 additions & 0 deletions src/compressed_tensors/modeling/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional
from weakref import ref

from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationScheme,
QuantizationStrategy,
forward_quantize,
)
from compressed_tensors.quantization.lifecycle.initialize import (
_initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


__all__ = [
"QuantizedAttentionImpl",
"initialize_hooked_attention",
"register_query_hook",
]


IMPL_ATTR = "impl"
HOOKED_ATTENTION_NAME = "ct_hooked_attention"


class QuantizedAttentionImpl(InternalModule):
"""
QuantizedAttentionImpl module which wraps the functionality of the original
attention implementation. Unlike the original attention function, this
implementation is a `torch.nn.Module` which can be hooked to trigger
transforms and calibration hooks.

This module works by being registered as a submodule to attention modules via
`initialize_hooked_attention`, registering a new attention implementation function
which calls this module, then setting the model attention implementation to the new
function. After triggering hooks and quantization, this module calls the original
attention implementation function.

:param attn_module: parent attention module
"""

def __init__(self, attn_module: Module):
super().__init__()
self.attn_module = ref(attn_module) # avoid circular references
self._qparams_initialized = False

def forward(
self,
module: Module,
query: Tensor,
key: Tensor,
value: Tensor,
*args,
**kwargs,
):
# quantization
quant_args_attr = "quantization_scheme.input_activations"
quant_args = getattr_chain(module, quant_args_attr, None)
quant_enabled = getattr(module, "quantization_enabled", True)
if quant_args is not None and quant_enabled and self._qparams_initialized:
query = forward_quantize(module, query, "q", quant_args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is only query quantized and not key & value?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nm, i see below the kv quantization implementation. So QuantizedAttention only refers to quantized query, and QuantizedKVCache always refers to the key/value states?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! I can make a note, but this guards against misuse. I don't think there's any reason to add a quantization hook to attention but not to kv cache


# original attention
return ALL_ATTENTION_FUNCTIONS[_original_impl](
module,
query,
key,
value,
*args,
**kwargs,
)

def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
"""
Initialize attention quantization parameters if they have not already been
initialized. KV cache quantization parameters are initialized by the
`QuantizedKVCache`

:param model: parent model of attention module
:param module: attention module to initialize with
"""
# TODO: move to initialize.py
assert module is self.attn_module()
scheme: Optional[QuantizationScheme] = getattr(
module, "quantization_scheme", None
)
quant_args: Optional[QuantizationArgs] = getattr(
scheme, "input_activations", None
)

if (
not self._qparams_initialized
and quant_args is not None
and not scheme.kv_cache_only
):
assert quant_args.strategy == QuantizationStrategy.TENSOR
_initialize_scale_zero_point(module, "q", quant_args)
self._qparams_initialized = True


# ----- initialize ----- #


def _ct_hooked_attention(module: Module, *args, **kwargs):
if hasattr(module, IMPL_ATTR):
return module.impl(module, *args, **kwargs)
else:
return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)


def initialize_hooked_attention(
model: PreTrainedModel, module: Module, quantize: bool = True
):
"""
Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
attached to attention

:param model: parent model of attention module
:param module: attention module to initialize with
:param quantize: initialize attention quantization parameters
"""
if not hasattr(module, IMPL_ATTR):
module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module))
if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
# assumes only one model at a time
global _original_impl
Comment on lines +148 to +149
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😬 i don't want to delay things, but we should briefly consider if there are alternative solutions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent 20 minutes exploring this, it requires creating specialized _ct_hooked_attention functions and specialized QuantizedAttentionImpl, which is more complexity than value added imho

_original_impl = model.config._attn_implementation

AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
model.config._attn_implementation = HOOKED_ATTENTION_NAME

impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
if quantize:
impl.initialize_qparams_once(model, module)

initialize_hooked_kv_cache(model, module, quantize=quantize)


# ----- hooks ----- #


def register_query_hook(
module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
"""
Register a hook which takes post-rope query states as an argument and
returns the modified query states or `None`

:param module: attention module to add hook to
:param hook: query hook function
"""
impl = getattr(module, IMPL_ATTR)

def _hook(impl: QuantizedAttentionImpl, args, kwargs):
bound = inspect.signature(impl.forward).bind(*args, **kwargs)
value = hook(module, bound.arguments["query"])
if value is not None:
bound.arguments["query"] = value

return bound.args, bound.kwargs

return impl.register_forward_pre_hook(_hook, with_kwargs=True)
191 changes: 191 additions & 0 deletions src/compressed_tensors/modeling/kvcache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional, Tuple
from weakref import ref

from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
from compressed_tensors.quantization.lifecycle.initialize import (
_initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PreTrainedModel


__all__ = [
"QuantizedKVCache",
"initialize_hooked_kv_cache",
"register_key_hook",
"register_value_hook",
]


KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
"""
QuantizedKVCache module which wraps the functionality of any existing kvcache args.
Unlike transform Cache instances, this cache is a `torch.nn.Module` which can be
hooked to trigger transforms and calibration hooks.

This module works by being registered as a submodule to attention modules via
`initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values`
kwargs with this module. This module adopts the functionality of the replaced cache,
preserving caching functionality such as sliding window attention, ect.

:param attn_module: parent attention module
"""

def __init__(self, attn_module: Module):
super().__init__()
self.attn_module = ref(attn_module) # avoid circular reference
self.past_key_values: Optional[Cache] = None
self._qparams_initialized = False

def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
return self(*args, **kwargs)

def forward(
self,
key_states: Tensor,
value_states: Tensor,
*args,
**kwargs,
) -> Tuple[Tensor, Tensor]:
# quantization
module = self.attn_module()
quant_args_attr = "quantization_scheme.input_activations"
quant_args = getattr_chain(module, quant_args_attr, None)
quant_enabled = getattr(module, "quantization_enabled", True)
if quant_args is not None and quant_enabled and self._qparams_initialized:
key_states = forward_quantize(module, key_states, "k", quant_args)
value_states = forward_quantize(module, value_states, "v", quant_args)

# original cache
if self.past_key_values is not None:
ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
else:
ret = (key_states, value_states)

self.past_key_values = None
return ret

def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
"""
Initialize kv cache quantization parameters if they have not already been
initialized

:param model: parent model of attention module
:param module: attention module to initialize with
"""
# TODO: move to initialize.py
assert module is self.attn_module()
scheme = getattr(module, "quantization_scheme", None)
quant_args = getattr(scheme, "input_activations", None)

if not self._qparams_initialized and quant_args is not None:
assert quant_args.strategy == QuantizationStrategy.TENSOR
_initialize_scale_zero_point(module, "k", quant_args)
_initialize_scale_zero_point(module, "v", quant_args)
self._qparams_initialized = True


# ----- initialize ----- #


def _kv_cache_attention_hook(module: Module, args, kwargs):
kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
_past_kv_name = (
"past_key_values" # transformers#39956
if "past_key_values" in inspect.signature(module.forward).parameters
else "past_key_value"
)
kv_cache.past_key_values = kwargs.get(_past_kv_name, None)
kwargs[_past_kv_name] = kv_cache

return args, kwargs


def initialize_hooked_kv_cache(
model: PreTrainedModel, module: Module, quantize: bool = False
):
"""
Initialize a `QuantizedKVCache` instance attached to attention

:param model: parent model of attention module
:param module: attention module to initialize with
:param quantize: initialize kv cache quantization parameters
"""
if not hasattr(module, KV_CACHE_ATTR):
module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)

kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
if quantize:
kv_cache.initialize_qparams_once(model, module)


# ----- hooks ----- #


def register_key_hook(
module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
"""
Register a hook which takes post-rope key states as an argument and
returns the modified key states or `None`

:param module: attention module to add hook to
:param hook: key hook function
"""
kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

def _hook(cache: QuantizedKVCache, args, kwargs):
bound = inspect.signature(cache.forward).bind(*args, **kwargs)
value = hook(module, bound.arguments["key_states"])
if value is not None:
bound.arguments["key_states"] = value

return bound.args, bound.kwargs

return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a lot of equivalent code in register_key_hook and register_value_hook, can they call into the same logic with a id string that is either value_states or key_states?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic of "create a kwarg hook which uses the signature of the module child, but calls the hook with the parent module" is pretty specific to these use cases, and is a hard function to name 🙃.

I think it's better to be explicit in this case.

module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
"""
Register a hook which takes value states as an argument and
returns the modified value states or `None`

:param module: attention module to add hook to
:param hook: value hook function
"""
kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

def _hook(cache: QuantizedKVCache, args, kwargs):
bound = inspect.signature(cache.forward).bind(*args, **kwargs)
value = hook(module, bound.arguments["value_states"])
if value is not None:
bound.arguments["value_states"] = value

return bound.args, bound.kwargs

return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
Loading