import os
import subprocess

# Set before importing huggingface_hub/transformers, which read this flag at import time
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import spaces  # import before torch so ZeroGPU can patch CUDA initialization
import torch
from functools import lru_cache
from transformers import (
    AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig,
    Idefics2ForConditionalGeneration, LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration, PaliGemmaForConditionalGeneration,
)

# Install flash-attn at runtime (the CUDA build is skipped; prebuilt kernels are used)
def install_flash_attn():
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )
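
# Note: whether flash-attn must be installed at import time is an assumption; if so, it would
# be invoked once here, before any model is loaded, e.g.:
# install_flash_attn()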

# Architecture to model class mapping
ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "AutoModelForCausalLM": AutoModelForCausalLM
}

# Get the model summary, cached per model name; the GPU is only requested on cache misses
@lru_cache(maxsize=10)
@spaces.GPU(duration=120)
def get_model_summary(model_name):
    """
    Retrieve the model summary for the given model name.
    
    Args:
    model_name (str): The name of the model to retrieve the summary for.

    Returns:
    tuple: A tuple containing the model summary (str) and an error message (str), if any.
    """
    try:
        # Fetch the model configuration
        config = AutoConfig.from_pretrained(model_name)
        architecture = config.architectures[0]
        quantization_config = getattr(config, 'quantization_config', None)

        # Set up BitsAndBytesConfig if the model is quantized
        if quantization_config:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=quantization_config.get('load_in_4bit', False),
                load_in_8bit=quantization_config.get('load_in_8bit', False),
                bnb_4bit_compute_dtype=quantization_config.get('bnb_4bit_compute_dtype', torch.float16),
                bnb_4bit_quant_type=quantization_config.get('bnb_4bit_quant_type', 'nf4'),
                bnb_4bit_use_double_quant=quantization_config.get('bnb_4bit_use_double_quant', False),
                llm_int8_enable_fp32_cpu_offload=quantization_config.get('llm_int8_enable_fp32_cpu_offload', False),
                llm_int8_has_fp16_weight=quantization_config.get('llm_int8_has_fp16_weight', False),
                llm_int8_skip_modules=quantization_config.get('llm_int8_skip_modules', None),
                llm_int8_threshold=quantization_config.get('llm_int8_threshold', 6.0),
            )
        else:
            bnb_config = None

        # Get the appropriate model class from the architecture map
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)
        
        # Load the model, passing the BitsAndBytes settings when the checkpoint is quantized
        model = model_class.from_pretrained(
            model_name, quantization_config=bnb_config, trust_remote_code=True
        )

        # Move to the GPU only if the model is not quantized; bitsandbytes-quantized models
        # are already placed on the GPU at load time and do not support .to()
        if model and not quantization_config:
            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        model_summary = str(model) if model else "Model architecture not found."
        return model_summary, ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        return "", str(e)