|
import asyncio |
|
|
|
import gradio as gr |
|
|
|
import src.constants as constants |
|
from src.hub import list_models, load_model_card |
|
|
|
|
|
async def load_model_tree(result_paths_per_model, model_ids):
    """Build one Gradio dropdown per model-tree relation of the first selected model.

    The first dropdown lists the base model(s) of ``model_ids[0]``; each following
    dropdown lists the models derived from it for one entry of
    ``constants.DERIVED_MODEL_TYPES``. Only model ids that are found in
    ``result_paths_per_model`` are offered as choices.

    Args:
        result_paths_per_model: Container supporting ``in``; related model ids
            not contained in it are dropped (presumably a mapping of model id to
            result paths -- confirm against caller).
        model_ids: Sequence of selected model ids; only the first one is used.
            When it is empty or None, every dropdown is returned empty instead of
            raising.

    Returns:
        A list of ``gr.Dropdown`` components: the base-model dropdown first, then
        one per derived-model type in ``constants.DERIVED_MODEL_TYPES`` order.
        Each label carries its number of choices, and a dropdown is only
        interactive when it has at least one choice.
    """
    labels = [constants.BASE_MODEL_TYPE[0]] + [
        derived_model_type[0] for derived_model_type in constants.DERIVED_MODEL_TYPES
    ]

    if not model_ids:
        # Nothing selected: keep the output shape (one dropdown per relation) so
        # the UI still receives the number of components it expects.
        related_per_relation = [[] for _ in labels]
    else:
        model_id = model_ids[0]
        # Fetch base models and every kind of derived model concurrently; the
        # result order matches ``labels``.
        related_per_relation = await asyncio.gather(
            load_base_models(model_id),
            *[
                load_derived_models_by_type(model_id, derived_model_type[1])
                for derived_model_type in constants.DERIVED_MODEL_TYPES
            ],
        )

    dropdowns = []
    for related_ids, label in zip(related_per_relation, labels):
        # Only offer related models for which results are available.
        choices = [
            related_id for related_id in related_ids if related_id in result_paths_per_model
        ]
        dropdowns.append(
            gr.Dropdown(
                choices=choices,
                label=f"{label} ({len(choices)})",
                interactive=bool(choices),
            )
        )
    return dropdowns
|
|
|
|
|
async def load_base_models(model_id) -> list[str]:
    """Return the base model id(s) declared in the model card of ``model_id``.

    The field read from ``card.data`` is named by ``constants.BASE_MODEL_TYPE[1]``.

    Args:
        model_id: Id of the model whose card is loaded via ``load_model_card``.

    Returns:
        A list of base model ids. It is empty when no card is available, or when
        the card metadata omits the field or sets it to a falsy value such as
        None. A single non-list value is wrapped in a one-element list.
    """
    card = await load_model_card(model_id)
    if not card:
        return []
    # Use a default so that card metadata lacking the field does not raise
    # AttributeError.
    base_models = getattr(card.data, constants.BASE_MODEL_TYPE[1], None)
    if not base_models:
        # Avoid returning [None] / [""] for an absent or empty field.
        return []
    if not isinstance(base_models, list):
        # A single base model is not wrapped in a list; normalise to one.
        base_models = [base_models]
    return base_models
|
|
|
|
|
async def load_derived_models_by_type(model_id, derived_model_type) -> list[str]:
    """List the ids of Hub models linked to ``model_id`` by one model-tree relation.

    Args:
        model_id: Id of the base model to search from.
        derived_model_type: Relation tag placed into the
            ``base_model:<derived_model_type>:<model_id>`` filter passed to
            ``list_models``.

    Returns:
        The ``"id"`` entry of every model returned by the search, or an empty
        list when the search result is falsy (e.g. None or empty).
    """
    search_filter = f"base_model:{derived_model_type}:{model_id}"
    found = await list_models(filtering=search_filter)
    # A falsy search result means "no matches"; collapse it to an empty list.
    return [entry["id"] for entry in (found or [])]
|
|