# model_explorer2 / app.py
# donb-hf's picture
# improve examples and error handling
# 4b29566
# raw
# history blame
# 3.14 kB
import gradio as gr
import os
import torch, torchvision, einops
import spaces
import subprocess
from transformers import AutoModelForCausalLM
from huggingface_hub import login
# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is passed
# to the pip process (presumably so the package skips compiling its CUDA
# kernels during install -- confirm against the flash-attn build docs).
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Extend the inherited environment rather than replacing it: passing a
    # bare dict as `env` would drop PATH, HOME, proxy settings, etc. for pip.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() without a token falls back to an interactive prompt, which
# cannot be answered in a headless app; without a token, gated models will
# simply fail to load and the error is reported in the UI.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)

# Cache of architecture summaries keyed by model name (only the summary
# strings are stored, not the loaded models themselves).
model_cache = {}
# Function to get the model summary
@spaces.GPU
def get_model_summary(model_name):
    """Return the printed architecture of a causal-LM checkpoint.

    Args:
        model_name: Model identifier passed to
            ``AutoModelForCausalLM.from_pretrained``. Leading and trailing
            whitespace is ignored.

    Returns:
        A ``(summary, error)`` tuple of strings. On success ``summary`` is
        ``str(model)`` and ``error`` is empty; on failure ``summary`` is
        empty and ``error`` describes the problem.
    """
    # Normalise the input so pasted names with stray whitespace still match
    # the cache and the Hub, and reject empty input with a clear message.
    model_name = (model_name or "").strip()
    if not model_name:
        return "", "Please enter a model name."

    if model_name in model_cache:
        return model_cache[model_name], ""

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(security): trust_remote_code=True allows code shipped in the
        # model repository to run during loading; the name comes from user
        # input, so only expose this app where that is acceptable.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device)
        model_summary = str(model)
    except Exception as e:
        # Top-level UI boundary: report any failure instead of crashing.
        # Include the exception type because str(e) can be empty.
        return "", f"{type(e).__name__}: {e}"

    model_cache[model_name] = model_summary
    return model_summary, ""
# Example model ids offered beneath the input box, grouped by heading.
_VISION_MODEL_IDS = (
    "liuhaotian/llava-v1.6-mistral-7b",
    "microsoft/llava-med-v1.5-mistral-7b",
    "llava-hf/llava-v1.6-mistral-7b-hf",
    "xtuner/llava-phi-3-mini-hf",
    "xtuner/llava-llama-3-8b-v1_1-transformers",
    "vikhyatk/moondream2",
    "openbmb/MiniCPM-Llama3-V-2_5",
    "microsoft/Phi-3-vision-128k-instruct",
    "google/paligemma-3b-mix-224",
    "HuggingFaceM4/idefics2-8b-chatty",
)
_OTHER_MODEL_IDS = (
    "google/gemma-7b",
    "microsoft/Phi-3-mini-4k-instruct",
    "meta-llama/Meta-Llama-3-8B",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "mistralai/Codestral-22B-v0.1",
)


def handle_click(model_name):
    """Pass the textbox value to get_model_summary and return its (summary, error) pair."""
    return get_model_summary(model_name)


# Create the Gradio Blocks interface: inputs and examples in one column,
# architecture and error outputs in the other.
# NOTE(review): nesting below is reconstructed from a flattened source --
# confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(
                label="Model Name",
                placeholder="Enter the model name here OR select example below...",
                lines=1,
            )

            gr.Markdown("### Vision Models")
            vision_examples = gr.Examples(
                # gr.Examples takes one row per example; each row holds a single id.
                examples=[[model_id] for model_id in _VISION_MODEL_IDS],
                inputs=textbox,
            )

            gr.Markdown("### Other Models")
            other_examples = gr.Examples(
                examples=[[model_id] for model_id in _OTHER_MODEL_IDS],
                inputs=textbox,
            )

            submit_button = gr.Button("Submit")

        with gr.Column():
            output = gr.Textbox(
                label="Model Architecture",
                lines=20,
                placeholder="Model architecture will appear here...",
                show_copy_button=True,
            )
            error_output = gr.Textbox(
                label="Error",
                lines=10,
                placeholder="Exceptions will appear here...",
                show_copy_button=True,
            )

    # Clicking Submit fills both output boxes from the (summary, error) pair.
    submit_button.click(
        fn=handle_click,
        inputs=textbox,
        outputs=[output, error_output],
    )

# Launch the interface
demo.launch()