import importlib.util
import os
import subprocess
import sys
from functools import lru_cache

# Enable the Rust-based hf_transfer download backend. huggingface_hub reads this
# variable once, when it is first imported, and transformers imports it eagerly,
# so the assignment must happen BEFORE the third-party imports below to have any
# effect. It is only enabled when `hf_transfer` is installed, because
# huggingface_hub refuses to download when the flag is set but the package is
# missing.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch  # noqa: E402  (must follow the env-var setup above)
from transformers import (  # noqa: E402
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
)


def install_flash_attn():
    """Install the ``flash-attn`` package into the running Python interpreter.

    ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` is set for the pip process;
    presumably this makes flash-attn's setup skip compiling its CUDA extension
    (TODO confirm against the flash-attn build docs).

    The install is best-effort: the exit status is not checked and nothing is
    raised on failure.
    """
    subprocess.run(
        # Use this interpreter's pip (not whichever `pip` is first on PATH) and
        # pass an argument list so no shell is involved.
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Extend rather than replace the current environment. Passing only the
        # flag would strip PATH, HOME, CUDA_HOME, ... from the child process.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        check=False,
    )


# Maps the first entry of a config's `architectures` list to the class used to
# load it. Unlisted architectures fall back to AutoModelForCausalLM.
ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "AutoModelForCausalLM": AutoModelForCausalLM,
}


def _quantization_dict(quantization_config):
    """Return a config's ``quantization_config`` as a plain dict (None if absent)."""
    if quantization_config is None or isinstance(quantization_config, dict):
        return quantization_config
    # NOTE(review): AutoConfig normally yields a dict here; the object form is
    # handled defensively in case a QuantizationConfigMixin instance is attached.
    if hasattr(quantization_config, "to_dict"):
        return quantization_config.to_dict()
    return quantization_config


def _build_bnb_config(quantization_config):
    """Build a BitsAndBytesConfig from a bitsandbytes quantization dict, else None.

    Non-bitsandbytes methods (e.g. GPTQ/AWQ) return None. For those, the
    settings embedded in the checkpoint's own config are left to
    ``from_pretrained``. Forcing a BitsAndBytesConfig onto them would conflict
    with the checkpoint's quantizer.
    """
    if not quantization_config:
        return None
    quant_method = quantization_config.get("quant_method")
    # Older bitsandbytes configs may lack `quant_method`; treat a missing value
    # as bitsandbytes, which is what the original unconditional code assumed.
    if quant_method is not None and str(quant_method) != "bitsandbytes":
        return None
    return BitsAndBytesConfig(
        load_in_4bit=quantization_config.get('load_in_4bit', False),
        load_in_8bit=quantization_config.get('load_in_8bit', False),
        # Serialized configs store this as a dtype name string (e.g. "float16");
        # BitsAndBytesConfig accepts either a string or a torch.dtype.
        bnb_4bit_compute_dtype=quantization_config.get('bnb_4bit_compute_dtype', torch.float16),
        bnb_4bit_quant_type=quantization_config.get('bnb_4bit_quant_type', 'nf4'),
        bnb_4bit_use_double_quant=quantization_config.get('bnb_4bit_use_double_quant', False),
        llm_int8_enable_fp32_cpu_offload=quantization_config.get('llm_int8_enable_fp32_cpu_offload', False),
        llm_int8_has_fp16_weight=quantization_config.get('llm_int8_has_fp16_weight', False),
        llm_int8_skip_modules=quantization_config.get('llm_int8_skip_modules', None),
        llm_int8_threshold=quantization_config.get('llm_int8_threshold', 6.0),
    )


# NOTE(review): lru_cache also memoises the ("", error) tuples returned on
# failure, so a transient error (e.g. a network hiccup) for a model name is
# replayed until that entry is evicted or get_model_summary.cache_clear() is
# called.
@lru_cache(maxsize=10)
def get_model_summary(model_name):
    """
    Retrieve the model summary for the given model name.

    Loads the model's config, selects a model class from ARCHITECTURE_MAP
    (default: AutoModelForCausalLM), loads the model, and returns ``str(model)``.
    Unquantized models are first moved to CUDA when available, else CPU.

    SECURITY: both the config and the model are loaded with
    ``trust_remote_code=True``, which executes Python code from the model repo.
    Only pass model names from trusted sources.

    Args:
        model_name (str): The name of the model to retrieve the summary for.

    Returns:
        tuple: A tuple containing the model summary (str) and an error message (str), if any.
            On success the error message is ""; on failure the summary is "".
    """
    try:
        # Fetch the model configuration. trust_remote_code matches the model load
        # below, so repos that ship a custom config class also work.
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

        # `architectures` may be None or empty; fall back to the default class
        # instead of failing on `[0]`.
        architectures = getattr(config, "architectures", None) or []
        architecture = architectures[0] if architectures else None

        quantization_config = _quantization_dict(getattr(config, 'quantization_config', None))
        bnb_config = _build_bnb_config(quantization_config)

        # Get the appropriate model class from the architecture map
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)

        # The bitsandbytes settings go in `quantization_config=`. Passing them as
        # `config=` would replace the model's own PretrainedConfig.
        model = model_class.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            trust_remote_code=True,
        )

        # Move to device only if the model is not quantized: bitsandbytes models
        # are placed at load time and do not support a later `.to()` call.
        if model is not None and not quantization_config:
            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        model_summary = str(model) if model is not None else "Model architecture not found."
        return model_summary, ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        return "", str(e)