[quantization] feat: support aobaseconfig classes in TorchAOConfig
#12275
base: main
Changes from 3 commits
@@ -21,12 +21,13 @@
"""

import copy
import dataclasses
import importlib.metadata
import inspect
import json
import os
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
@@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.

Args:
quant_type (`str`):
quant_type (Union[`str`, AOBaseConfig]):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
- An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig

# AOBaseConfig-based configuration
from torchao.quantization import Int8WeightOnlyConfig

quantization_config = TorchAoConfig(Int8WeightOnlyConfig())

# String-based config
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
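For readers following along, here is a fuller, runnable version of the docstring example above. It is a sketch only: the `subfolder="transformer"` and `torch_dtype` arguments are the usual Flux loading settings assumed here, not something this diff introduces.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

# Either configuration style should be accepted after this PR.
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())  # AOBaseConfig-based
# quantization_config = TorchAoConfig("int8wo")              # string-based shorthand

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```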
@@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):

def __init__(
self,
quant_type: str,
quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
) -> None:
@@ -504,8 +512,13 @@ def __init__(
else:
self.quant_type_kwargs = kwargs

self.post_init()

def post_init(self):
Review comment: Think it would be cleaner to run the checks in the beginning and exit/fail early:

```python
def post_init(self):
    if not isinstance(self.quant_type, str):
        if not is_torchao_version(">=", "0.9.0"):
            raise ValueError(
                f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
                f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
            )
        from torchao.quantization.quant_api import AOBaseConfig

        if not isinstance(self.quant_type, AOBaseConfig):
            raise TypeError(
                f"quant_type must be a string or AOBaseConfig instance, got {type(self.quant_type).__name__}"
            )
        return

    TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
    if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS:
        # remaining str type validation
        ...
```
Review comment: This will be a breaking change. I think we should introduce a deprecation cycle before enforcing this.

Collaborator reply: Why breaking? The error is only raised if

Reply: Fixed.

TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
AO_VERSION = self._get_ao_version()

if isinstance(self.quant_type, str) and self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
@@ -517,22 +530,95 @@ def __init__(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
elif AO_VERSION > version.parse("0.9.0"):
from torchao.quantization.quant_api import AOBaseConfig

method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)

if len(unsupported_kwargs) > 0:
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(
f"`quant_type` must be either a string or an `AOBaseConfig` instance, got {type(self.quant_type)}."
)
else:
raise ValueError(
f'The quantization method "{quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. "
f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances."
)

if isinstance(self.quant_type, str):

method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)

if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)
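A hedged sketch of how the validation in `post_init` behaves from the caller's side; the shorthand names come from the existing table of supported types, while the failing calls are hypothetical and only illustrate the error paths, so they are left commented out.

```python
from diffusers import TorchAoConfig

# Valid: a known shorthand resolves against _get_torchao_quant_type_to_method().
TorchAoConfig("int8wo")

# Unknown quant_type string: would raise ValueError pointing at the supported list.
# TorchAoConfig("not_a_real_quant_type")

# Keyword argument the underlying torchao method does not accept: would raise
# ValueError listing the supported kwargs collected via inspect.signature.
# TorchAoConfig("int8wo", nonexistent_kwarg=123)
```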

def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()

if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
if is_dataclass(d["quant_type_kwargs"]["layout"]):
d["quant_type_kwargs"]["layout"] = [
d["quant_type_kwargs"]["layout"].__class__.__name__,
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
]
if isinstance(d["quant_type_kwargs"]["layout"], list):
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
else:
raise ValueError("layout must be a list")
else:
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict

# For now we assume there is 1 config per Transformer, however in the future
# We may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}

return d
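As a quick illustration of the new serialization branch: an `AOBaseConfig` instance is wrapped under a single "default" key via torchao's `config_to_dict`. A minimal sketch follows; the inner fields of the serialized config depend on the installed torchao version and are not spelled out here.

```python
from diffusers import TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

config = TorchAoConfig(Int8WeightOnlyConfig())
serialized = config.to_dict()

# The AOBaseConfig branch stores the config under a single "default" key.
assert "default" in serialized["quant_type"]
```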

@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
ao_version = cls._get_ao_version()
assert ao_version > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"

config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")

if isinstance(quant_type, str):
return cls(quant_type=quant_type, **config_dict)
# Check if we only have one key which is "default"
# In the future we may update this
assert len(quant_type) == 1 and "default" in quant_type, (
"Expected only one key 'default' in quant_type dictionary"
)
quant_type = quant_type["default"]

# Deserialize quant_type if needed
from torchao.core.config import config_from_dict

quant_type = config_from_dict(quant_type)

return cls(quant_type=quant_type, **config_dict)
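The `from_dict` path leans on the same torchao helpers used in the diff above; here is a minimal round-trip sketch of those helpers on their own. The behaviour is assumed from their names and the usage in this PR, not re-verified here.

```python
import json

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Int8WeightOnlyConfig

original = Int8WeightOnlyConfig()
payload = json.loads(json.dumps(config_to_dict(original)))  # JSON-serializable on the way out
restored = config_from_dict(payload)                        # back to an AOBaseConfig instance

assert type(restored) is type(original)
```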
@staticmethod
def _get_ao_version() -> version.Version:
"""Centralized check for TorchAO availability and version requirements."""
if not is_torchao_available():
raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")

return version.parse(importlib.metadata.version("torchao"))

@classmethod
def _get_torchao_quant_type_to_method(cls):
r"""
@@ -681,8 +767,38 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")

def get_apply_tensor_subclass(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
"""Create the appropriate quantization method based on configuration."""
if isinstance(self.quant_type, str):
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (
Review comment: we should probably clean this up in the future..

Reply: Maybe not even needed since we will deprecate this codepath after this PR is merged.
not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
and quant_type_kwargs.get("layout", None) is None
):
if torch.xpu.is_available():
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
"0.11.0"
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain

quant_type_kwargs["layout"] = Int4XPULayout()
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
else:
raise ValueError(
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
)
else:
from torchao.dtypes import Int4CPULayout

quant_type_kwargs["layout"] = Int4CPULayout()

return methods[self.quant_type](**quant_type_kwargs)
else:
return self.quant_type

def __repr__(self):
r"""
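To make the two return paths of `get_apply_tensor_subclass` concrete, here is a small sketch. Handing the returned `AOBaseConfig` to `torchao.quantization.quantize_` is an assumption about how downstream code consumes it, not something shown in this diff.

```python
from diffusers import TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

string_config = TorchAoConfig("int8wo")
ao_config = TorchAoConfig(Int8WeightOnlyConfig())

# String path: the registered torchao method is looked up and called with quant_type_kwargs.
apply_fn = string_config.get_apply_tensor_subclass()

# AOBaseConfig path: the instance itself is returned unchanged, e.g. for use with
# torchao.quantization.quantize_(model, ao_config.get_apply_tensor_subclass()).
assert ao_config.get_apply_tensor_subclass() is ao_config.quant_type
```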
Review comment: I think we can deprecate this one since this is less scalable than AOBaseConfig.

Reply: Yes, will do so after this PR. Meanwhile, if you could review the PR, it'd be helpful.
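For context on the deprecation discussion above, a hedged sketch of how today's string shorthands line up with AOBaseConfig classes. The class names exist in recent torchao releases, but treating each pairing as exactly equivalent (including defaults) is an assumption, not a guarantee.

```python
from diffusers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig

# String shorthand and its (approximate) AOBaseConfig counterpart.
TorchAoConfig("int8wo")
TorchAoConfig(Int8WeightOnlyConfig())

# Keyword arguments move onto the config object instead of **kwargs.
TorchAoConfig("int4wo", group_size=64)
TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
```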