"""Gradio app that shows the architecture of a Hugging Face causal LM.

The user types (or picks) a model repository id; the app loads the model with
``AutoModelForCausalLM`` and displays ``str(model)``, or the error message if
loading fails.
"""

import os
import subprocess
import sys

# NOTE(review): einops and torchvision are not referenced directly in this
# file; presumably they are imported so that models loaded with
# trust_remote_code=True can rely on them - confirm before removing.
import einops  # noqa: F401
import gradio as gr
import spaces
import torch
import torchvision  # noqa: F401
from huggingface_hub import login
from transformers import AutoModelForCausalLM

# Message returned when the user submits an empty model name.
EMPTY_NAME_ERROR = "Please enter a model name."

# Install flash-attn at startup.
# The extra variable is merged into the current environment: passing a dict
# with only that key would drop PATH, HOME, virtualenv settings, etc. for the
# child process. Running pip through sys.executable with an argument list
# avoids the shell and guarantees the package lands in this interpreter.
_install_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best effort: the app still starts if the install fails
)
if _install_result.returncode != 0:
    print(
        f"warning: flash-attn installation exited with code "
        f"{_install_result.returncode}",
        file=sys.stderr,
    )

# Authenticate against the Hub only when a token is configured; calling
# login() without a token falls back to an interactive prompt, which cannot
# work on a headless server.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)

# Cache for storing loaded models' summaries, keyed by model name.
model_cache = {}


@spaces.GPU
def get_model_summary(model_name):
    """Load a model and return its printed architecture.

    Args:
        model_name: Hub repository id of the model to load.

    Returns:
        A ``(summary, error)`` tuple of strings. On success ``summary`` is
        ``str(model)`` and ``error`` is empty; on failure ``summary`` is empty
        and ``error`` describes the problem.
    """
    model_name = (model_name or "").strip()
    if not model_name:
        return "", EMPTY_NAME_ERROR

    if model_name in model_cache:
        return model_cache[model_name], ""

    model = None
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY: trust_remote_code=True executes Python code shipped in the
        # model repository, and model_name comes straight from the user.
        # Kept as-is because several listed examples require it; review
        # whether this is acceptable for the deployment.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device)
        model_summary = str(model)
    except Exception as e:  # UI boundary: report any failure to the user
        # Some exceptions carry no message; fall back to repr so the error
        # box is never blank on failure.
        return "", str(e) or repr(e)
    finally:
        # Only the textual summary is kept; release the weights so repeated
        # requests do not accumulate models in memory.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_cache[model_name] = model_summary
    return model_summary, ""


# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(
                label="Model Name",
                placeholder="Enter the model name here OR select example below...",
                lines=1,
            )

            gr.Markdown("### Vision Models")
            vision_examples = gr.Examples(
                examples=[
                    ["liuhaotian/llava-v1.6-mistral-7b"],
                    ["microsoft/llava-med-v1.5-mistral-7b"],
                    ["llava-hf/llava-v1.6-mistral-7b-hf"],
                    ["xtuner/llava-phi-3-mini-hf"],
                    ["xtuner/llava-llama-3-8b-v1_1-transformers"],
                    ["vikhyatk/moondream2"],
                    ["openbmb/MiniCPM-Llama3-V-2_5"],
                    ["microsoft/Phi-3-vision-128k-instruct"],
                    ["google/paligemma-3b-mix-224"],
                    ["HuggingFaceM4/idefics2-8b-chatty"],
                ],
                inputs=textbox,
            )

            gr.Markdown("### Other Models")
            other_examples = gr.Examples(
                examples=[
                    ["google/gemma-7b"],
                    ["microsoft/Phi-3-mini-4k-instruct"],
                    ["meta-llama/Meta-Llama-3-8B"],
                    ["mistralai/Mistral-7B-Instruct-v0.3"],
                    ["mistralai/Codestral-22B-v0.1"],
                ],
                inputs=textbox,
            )

            submit_button = gr.Button("Submit")

        with gr.Column():
            output = gr.Textbox(
                label="Model Architecture",
                lines=20,
                placeholder="Model architecture will appear here...",
                show_copy_button=True,
            )
            error_output = gr.Textbox(
                label="Error",
                lines=10,
                placeholder="Exceptions will appear here...",
                show_copy_button=True,
            )

    def handle_click(model_name):
        """Return ``(summary, error)`` for the submitted model name.

        Cached summaries are served directly so a cache hit does not go
        through the GPU-decorated function.
        """
        name = (model_name or "").strip()
        if not name:
            return "", EMPTY_NAME_ERROR
        if name in model_cache:
            return model_cache[name], ""

        model_summary, error_message = get_model_summary(name)

        # NOTE(review): on ZeroGPU hardware the @spaces.GPU function may run
        # in a separate worker process, in which case its writes to
        # model_cache would not be visible here - confirm. Storing the result
        # at this level works in either case.
        if model_summary and not error_message:
            model_cache[name] = model_summary
        return model_summary, error_message

    submit_button.click(
        fn=handle_click, inputs=textbox, outputs=[output, error_output]
    )

# Launch the interface
demo.launch()