Skip to content

Commit 9866325

Browse files
Merge branch 'main' into qwen-image-edit-controlnet
2 parents 5c8fac1 + edd614e commit 9866325

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@
4848
if is_transformers_available():
4949
import transformers
5050
from transformers import PreTrainedModel, PreTrainedTokenizerBase
51-
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
5251
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
5352
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
5453

54+
if is_transformers_version("<=", "4.56.2"):
55+
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
56+
5557
if is_accelerate_available():
5658
import accelerate
5759
from accelerate import dispatch_model
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
112114
]
113115

114116
if is_transformers_available():
115-
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
117+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
118+
if is_transformers_version("<=", "4.56.2"):
119+
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
116120

117121
# model_pytorch, diffusion_model_pytorch, ...
118122
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
191195
]
192196

193197
if is_transformers_available():
194-
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
198+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
199+
if is_transformers_version("<=", "4.56.2"):
200+
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
195201

196202
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
197203

@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
212218
]
213219

214220
if is_transformers_available():
215-
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
221+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
222+
if is_transformers_version("<=", "4.56.2"):
223+
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
216224

217225
# model_pytorch, diffusion_model_pytorch, ...
218226
weight_prefixes = [w.split(".")[0] for w in weight_names]

0 commit comments

Comments
 (0)