48
48
if is_transformers_available ():
49
49
import transformers
50
50
from transformers import PreTrainedModel , PreTrainedTokenizerBase
51
- from transformers .utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
52
51
from transformers .utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
53
52
from transformers .utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
54
53
54
+ if is_transformers_version ("<=" , "4.56.2" ):
55
+ from transformers .utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
56
+
55
57
if is_accelerate_available ():
56
58
import accelerate
57
59
from accelerate import dispatch_model
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
112
114
]
113
115
114
116
if is_transformers_available ():
115
- weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
117
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME ]
118
+ if is_transformers_version ("<=" , "4.56.2" ):
119
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME ]
116
120
117
121
# model_pytorch, diffusion_model_pytorch, ...
118
122
weight_prefixes = [w .split ("." )[0 ] for w in weight_names ]
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
191
195
]
192
196
193
197
if is_transformers_available ():
194
- weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
198
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME ]
199
+ if is_transformers_version ("<=" , "4.56.2" ):
200
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME ]
195
201
196
202
allowed_extensions = [wn .split ("." )[- 1 ] for wn in weight_names ]
197
203
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
212
218
]
213
219
214
220
if is_transformers_available ():
215
- weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
221
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME ]
222
+ if is_transformers_version ("<=" , "4.56.2" ):
223
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME ]
216
224
217
225
# model_pytorch, diffusion_model_pytorch, ...
218
226
weight_prefixes = [w .split ("." )[0 ] for w in weight_names ]
0 commit comments