Spaces:
Sleeping
Sleeping
import os | |
from unittest.mock import patch | |
from transformers.dynamic_module_utils import get_imports | |
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return a module's imports, dropping ``flash_attn`` for the Florence-2 files.

    Drop-in replacement for ``transformers.dynamic_module_utils.get_imports``
    (installed temporarily by ``load_model_without_flash_attn``). For the
    Florence-2 remote-code files ``modeling_florence2.py`` and
    ``configuration_florence2.py``, the ``"flash_attn"`` entry is removed from
    the detected imports, so the import check does not demand that package.
    Every other file's imports are returned unchanged.

    Args:
        filename: Path of the module file whose imports are being scanned.

    Returns:
        The import names reported by the original ``get_imports``. For the two
        Florence-2 files, any ``"flash_attn"`` entries are omitted.
    """
    # ``get_imports`` here is the original function, bound when this module
    # was imported. So calling it while the module attribute is patched to
    # point at this function does not recurse.
    imports = get_imports(filename)

    # Compare the bare file name rather than a "/..." suffix. This works with
    # either path separator (a hard-coded "/" never matches Windows paths),
    # and it does not accidentally match names that merely end the same way.
    base_name = os.path.basename(os.fspath(filename))
    if base_name not in ("modeling_florence2.py", "configuration_florence2.py"):
        return imports

    return [name for name in imports if name != "flash_attn"]
def load_model_without_flash_attn(model_loader):
    """Run ``model_loader`` with the ``flash_attn`` import check neutralised.

    While ``model_loader`` runs, ``transformers.dynamic_module_utils.get_imports``
    is replaced by ``fixed_get_imports``. The ``unittest.mock.patch`` context
    manager restores the original attribute on exit, including when
    ``model_loader`` raises.

    Args:
        model_loader: Zero-argument callable that performs the actual load.
            Presumably something like
            ``lambda: AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)``;
            confirm against the callers.

    Returns:
        Whatever ``model_loader`` returns.
    """
    import_patcher = patch(
        "transformers.dynamic_module_utils.get_imports",
        fixed_get_imports,
    )
    with import_patcher:
        loaded = model_loader()
    return loaded