Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 134 additions & 18 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
"""

import copy
import dataclasses
import importlib.metadata
import inspect
import json
import os
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
Expand Down Expand Up @@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.

Args:
quant_type (`str`):
quant_type (Union[`str`, AOBaseConfig]):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
Expand All @@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
- An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
Expand All @@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig

# AOBaseConfig-based configuration
from torchao.quantization import Int8WeightOnlyConfig

quantization_config = TorchAoConfig(Int8WeightOnlyConfig())

# String-based config

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can deprecate this one since this is less scalable than AOBaseConfig

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, will do so after this PR. Meanwhile, if you could review the PR, it'd be helpful.

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
Expand All @@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):

def __init__(
self,
quant_type: str,
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
) -> None:
Expand All @@ -504,8 +512,13 @@ def __init__(
else:
self.quant_type_kwargs = kwargs

self.post_init()

def post_init(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think it would be cleaner to run the checks in the beginning and exit/fail early

def post_init(self):
    if not isinstance(self.quant_type, str):
        if not is_torchao_version(">=", "0.9.0"):
            raise ValueError(
                f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
                f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
            )
        
        from torchao.quantization.quant_api import AOBaseConfig
        if not isinstance(self.quant_type, AOBaseConfig):
            raise TypeError(
                f"quant_type must be a string or AOBaseConfig instance, got {type(self.quant_type).__name__}"
            )
        return
    
    TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
    
    if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS:
        # remaining str type validation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not isinstance(self.quant_type, str):
        if not is_torchao_version(">=", "0.9.0"):
            raise ValueError(
                f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
                f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
            )

This will be a breaking change. I think we should introduce a deprecation cycle before enforcing this.

Copy link
Collaborator

@DN6 DN6 Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why breaking? The error is only raised if quant_type is not a string and torchao<=0.9.0? Can change the second check to if is_torchao_version("<=", "0.9.0"): if that is more clear?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
AO_VERSION = self._get_ao_version()

if isinstance(self.quant_type, str) and self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
Expand All @@ -517,22 +530,95 @@ def __init__(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
elif AO_VERSION > version.parse("0.9.0"):
from torchao.quantization.quant_api import AOBaseConfig

method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)

if len(unsupported_kwargs) > 0:
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(
f"`quant_type` must be either a string or an `AOBaseConfig` instance, got {type(self.quant_type)}."
)
else:
raise ValueError(
f'The quantization method "{quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. "
f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances."
)

if isinstance(self.quant_type, str):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: merge this with the branch in L521 to keep related things together?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)

if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)

def to_dict(self):
    """Serialize this configuration into a dictionary.

    Starts from the parent class's ``to_dict()`` output and then adjusts it
    depending on how ``quant_type`` was given:

    - String ``quant_type``: if ``quant_type_kwargs`` carries a ``layout``
      entry that is a dataclass, it is replaced in place by the two-element
      list ``[class name, dataclasses.asdict(layout)]``. The resulting
      ``layout`` must be a ``[str, dict]`` list.
    - Any other ``quant_type`` (an ``AOBaseConfig`` instance): it is
      serialized with torchao's ``config_to_dict`` and stored under the
      single key ``"default"``.

    Returns:
        The dictionary produced by the parent class, with the adjustments
        described above applied.

    Raises:
        ValueError: If ``layout`` is present but is neither a dataclass nor
            a list.
        AssertionError: If a list ``layout`` is not of the form
            ``[str, dict]``.
    """
    serialized = super().to_dict()

    if not isinstance(self.quant_type, str):
        # Handle AOBaseConfig serialization
        from torchao.core.config import config_to_dict

        # For now we assume there is 1 config per Transformer, however in the future
        # we may want to support a config per fqn.
        serialized["quant_type"] = {"default": config_to_dict(self.quant_type)}
        return serialized

    # String quant_type: only the optional layout entry needs special handling.
    if "quant_type_kwargs" not in serialized or "layout" not in serialized["quant_type_kwargs"]:
        return serialized

    type_kwargs = serialized["quant_type_kwargs"]
    if is_dataclass(type_kwargs["layout"]):
        # Store the layout as [class name, field dict].
        type_kwargs["layout"] = [
            type_kwargs["layout"].__class__.__name__,
            dataclasses.asdict(type_kwargs["layout"]),
        ]

    layout = type_kwargs["layout"]
    if not isinstance(layout, list):
        raise ValueError("layout must be a list")

    assert len(layout) == 2, "layout saves layout name and layout kwargs"
    assert isinstance(layout[0], str), "layout name must be a string"
    assert isinstance(layout[1], dict), "layout kwargs must be a dict"

    return serialized

@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
    """Build a configuration object from its dictionary form.

    Args:
        config_dict: Mapping that must contain a ``"quant_type"`` entry. It
            is copied before being modified. Every remaining entry is
            forwarded to the constructor as a keyword argument.
        return_unused_kwargs: Accepted but not read by this implementation.
        **kwargs: Accepted but not read by this implementation.

    Returns:
        A new instance of ``cls``.

    Raises:
        AssertionError: If the installed torchao is not newer than 0.9.0, or
            if a non-string ``quant_type`` is not a mapping with exactly the
            single key ``"default"``.
        KeyError: If ``config_dict`` has no ``"quant_type"`` entry.
    """
    installed_version = cls._get_ao_version()
    assert installed_version > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"

    remaining = config_dict.copy()
    quant_type = remaining.pop("quant_type")

    if not isinstance(quant_type, str):
        # Check if we only have one key which is "default"
        # In the future we may update this
        assert len(quant_type) == 1 and "default" in quant_type, (
            "Expected only one key 'default' in quant_type dictionary"
        )
        default_entry = quant_type["default"]

        # Deserialize the stored config back into a torchao config object
        from torchao.core.config import config_from_dict

        quant_type = config_from_dict(default_entry)

    return cls(quant_type=quant_type, **remaining)

@staticmethod
def _get_ao_version() -> version.Version:
    """Return the installed torchao version as a parsed ``Version``.

    Raises:
        ValueError: If ``is_torchao_available()`` reports that torchao is
            not installed.
    """
    if is_torchao_available():
        installed = importlib.metadata.version("torchao")
        return version.parse(installed)

    raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")

@classmethod
def _get_torchao_quant_type_to_method(cls):
r"""
Expand Down Expand Up @@ -681,8 +767,38 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")

def get_apply_tensor_subclass(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
"""Create the appropriate quantization method based on configuration."""
if isinstance(self.quant_type, str):
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably clean this up in the future..

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not even needed since we will deprecate this codepath after this PR is merged.

not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
and quant_type_kwargs.get("layout", None) is None
):
if torch.xpu.is_available():
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
"0.11.0"
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain

quant_type_kwargs["layout"] = Int4XPULayout()
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
else:
raise ValueError(
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
)
else:
from torchao.dtypes import Int4CPULayout

quant_type_kwargs["layout"] = Int4CPULayout()

return methods[self.quant_type](**quant_type_kwargs)
else:
return self.quant_type

def __repr__(self):
r"""
Expand Down
92 changes: 71 additions & 21 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
"""

import importlib
import re
import types
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from packaging import version

Expand Down Expand Up @@ -107,6 +108,21 @@ def _update_torch_safe_globals():
_update_torch_safe_globals()


def fuzzy_match_size(config_name: str) -> Optional[str]:
    """
    Extract the size digit that directly precedes "weight" in `config_name`, e.g. "4" from "Int4WeightOnlyConfig".

    The match is case-insensitive and only a single digit is captured. When several "<digit>weight" occurrences exist,
    the first one wins.

    Args:
        config_name (`str`):
            Name to inspect, typically the class name of a quantization config.

    Returns:
        `Optional[str]`: The digit as a one-character string (not an `int`), or `None` if no "<digit>weight" pattern
        is present.
    """
    size_match = re.search(r"(\d)weight", config_name.lower())
    return size_match.group(1) if size_match else None


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -176,8 +192,7 @@ def validate_environment(self, *args, **kwargs):

def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type

if quant_type.startswith("int") or quant_type.startswith("uint"):
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
Expand All @@ -197,24 +212,44 @@ def update_torch_dtype(self, torch_dtype):

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type

if quant_type.startswith("int8") or quant_type.startswith("int4"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16
from accelerate.utils import CustomDtype

if isinstance(quant_type, str):
if quant_type.startswith("int8"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type.startswith("int4"):
return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16

elif self.quantization_config._get_ao_version() > version.Version("0.9.0"):
from torchao.core.config import AOBaseConfig

quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)

# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8

if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
Expand Down Expand Up @@ -297,6 +332,21 @@ def get_cuda_warm_up_factor(self):
# Original mapping for non-AOBaseConfig types
# For the uint types, this is a best guess. Once these types become more used
# we can look into their nuances.
if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
from torchao.core.config import AOBaseConfig

quant_type = self.quantization_config.quant_type
# For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)

if size_digit == "4":
return 8
else:
return 4

map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
quant_type = self.quantization_config.quant_type
for pattern, target_dtype in map_to_target_dtype.items():
Expand Down
Loading