comparator / src /model_tree.py
albertvillanova's picture
Make the code robust against HTTP errors
719c272 verified
import asyncio
import gradio as gr
import src.constants as constants
from src.hub import list_models, load_model_card
async def load_model_tree(result_paths_per_model, model_ids):
    """Build one Gradio dropdown per model-tree relation of the selected model.

    The first dropdown lists the base models of ``model_ids[0]``. Each following
    dropdown lists the models derived from it for one entry of
    ``constants.DERIVED_MODEL_TYPES``. Only models present in
    ``result_paths_per_model`` are offered as choices.

    Args:
        result_paths_per_model: Container of model IDs that have results; only
            membership (``in``) is used here.
        model_ids: Selected model IDs; only the first one is used.

    Returns:
        A list of ``gr.Dropdown`` components: base models first, then one per
        derived model type in ``constants.DERIVED_MODEL_TYPES`` order. Each label
        shows the number of choices, and a dropdown is interactive only when it
        has at least one choice.
    """
    model_tree_labels = [constants.BASE_MODEL_TYPE[0]] + [
        derived_model_type[0] for derived_model_type in constants.DERIVED_MODEL_TYPES
    ]
    if not model_ids:
        # Nothing selected: produce empty, non-interactive dropdowns instead of
        # raising IndexError on model_ids[0].
        model_tree = [[] for _ in model_tree_labels]
    else:
        # TODO: Multiple models?
        model_id = model_ids[0]
        # Fetch the base models and every derived-model type concurrently;
        # results come back in the same order as model_tree_labels.
        model_tree = await asyncio.gather(
            load_base_models(model_id),
            *[
                load_derived_models_by_type(model_id, derived_model_type[1])
                for derived_model_type in constants.DERIVED_MODEL_TYPES
            ],
        )
    # Keep only models for which results are available.
    model_tree_choices = [
        [tree_model_id for tree_model_id in tree_model_ids if tree_model_id in result_paths_per_model]
        for tree_model_ids in model_tree
    ]
    return [
        gr.Dropdown(choices=choices, label=f"{label} ({len(choices)})", interactive=bool(choices))
        for choices, label in zip(model_tree_choices, model_tree_labels)
    ]
async def load_base_models(model_id) -> list[str]:
    """Return the base models declared in the model card of ``model_id``.

    Args:
        model_id: Hub ID of the model whose card is read.

    Returns:
        The value of the card metadata field named by
        ``constants.BASE_MODEL_TYPE[1]``, normalized to a list. The list is
        empty if the card could not be loaded or the field is missing or empty.
    """
    card = await load_model_card(model_id)
    if not card:
        return []
    # Default to None so card metadata lacking the field does not raise.
    base_models = getattr(card.data, constants.BASE_MODEL_TYPE[1], None)
    if not base_models:
        # No declared base model: return [] rather than wrapping None as [None].
        return []
    if not isinstance(base_models, list):
        # Presumably a single base-model ID string; normalize to a list.
        base_models = [base_models]
    return base_models
async def load_derived_models_by_type(model_id, derived_model_type) -> list[str]:
    """List the IDs of Hub models derived from ``model_id`` through one relation.

    Args:
        model_id: ID of the parent model.
        derived_model_type: Relation name inserted into the
            ``base_model:<type>:<model_id>`` search filter.

    Returns:
        The ``"id"`` of every entry returned by ``list_models``, or an empty
        list when ``list_models`` returns a falsy result.
    """
    search_filter = f"base_model:{derived_model_type}:{model_id}"
    found = await list_models(filtering=search_filter)
    return [entry["id"] for entry in found] if found else []