import subprocess
import os
import torch
from transformers import (
    BitsAndBytesConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlavaNextForConditionalGeneration,
    LlavaForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
    Idefics2ForConditionalGeneration,
)
from functools import lru_cache
import spaces
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Install flash-attn at runtime (skipping the CUDA build, as is common on Spaces)
def install_flash_attn():
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        # Merge with the current environment so PATH and other variables are preserved
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )
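
# A minimal usage sketch (not part of the original file): in a Space this helper
# would typically be invoked once at import time, before the first model load.
# install_flash_attn()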
# Architecture to model class mapping
ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "AutoModelForCausalLM": AutoModelForCausalLM,
}
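
# Example: a config whose architectures list starts with "LlavaForConditionalGeneration"
# resolves to that class below; any architecture not in the map falls back to AutoModelForCausalLM.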
# Function to get the model summary with caching and GPU support
@spaces.GPU(duration=120)
@lru_cache(maxsize=10)
def get_model_summary(model_name):
    """
    Retrieve the model summary for the given model name.

    Args:
        model_name (str): The name of the model to retrieve the summary for.

    Returns:
        tuple: A tuple containing the model summary (str) and an error message (str), if any.
    """
    try:
        # Fetch the model configuration
        config = AutoConfig.from_pretrained(model_name)
        architecture = config.architectures[0]
        quantization_config = getattr(config, 'quantization_config', None)

        # Set up BitsAndBytesConfig if the model is quantized
        if quantization_config:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=quantization_config.get('load_in_4bit', False),
                load_in_8bit=quantization_config.get('load_in_8bit', False),
                bnb_4bit_compute_dtype=quantization_config.get('bnb_4bit_compute_dtype', torch.float16),
                bnb_4bit_quant_type=quantization_config.get('bnb_4bit_quant_type', 'nf4'),
                bnb_4bit_use_double_quant=quantization_config.get('bnb_4bit_use_double_quant', False),
                llm_int8_enable_fp32_cpu_offload=quantization_config.get('llm_int8_enable_fp32_cpu_offload', False),
                llm_int8_has_fp16_weight=quantization_config.get('llm_int8_has_fp16_weight', False),
                llm_int8_skip_modules=quantization_config.get('llm_int8_skip_modules', None),
                llm_int8_threshold=quantization_config.get('llm_int8_threshold', 6.0),
            )
        else:
            bnb_config = None

        # Get the appropriate model class from the architecture map
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)

        # Load the model; BitsAndBytes settings are passed as quantization_config,
        # not as the model config
        model = model_class.from_pretrained(
            model_name, quantization_config=bnb_config, trust_remote_code=True
        )

        # Move to device only if the model is not quantized
        # (bitsandbytes-quantized models are already placed on the GPU at load time)
        if model and not quantization_config:
            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        model_summary = str(model) if model else "Model architecture not found."
        return model_summary, ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        return "", str(e)