import asyncio

import gradio as gr

import src.constants as constants
from src.hub import list_models, load_model_card


async def load_model_tree(result_paths_per_model, model_ids):
    """Build one Gradio dropdown per model-tree relation of the selected model.

    The first dropdown lists the base models of the selected model. Each
    following dropdown lists the models derived from it for one entry of
    ``constants.DERIVED_MODEL_TYPES``, in that order. Only models contained in
    ``result_paths_per_model`` are offered as choices.

    Args:
        result_paths_per_model: Collection of model IDs that have results; only
            membership (``in``) is used here.
        model_ids: Selected model IDs. Only the first one is used. ``None`` or an
            empty sequence yields empty, non-interactive dropdowns.

    Returns:
        A list of ``gr.Dropdown`` components: base models first, then one per
        derived model type. Each label carries the number of choices.
    """
    labels = [constants.BASE_MODEL_TYPE[0]] + [
        derived_model_type[0] for derived_model_type in constants.DERIVED_MODEL_TYPES
    ]

    # TODO: Multiple models? Only the first selected model is used for now.
    if not model_ids:
        # Nothing selected: keep the number of outputs stable (one per label)
        # instead of failing with IndexError on model_ids[0].
        model_tree = [[] for _ in labels]
    else:
        model_id = model_ids[0]
        # Fetch base models and every derived-model relation concurrently.
        model_tree = await asyncio.gather(
            load_base_models(model_id),
            *[
                load_derived_models_by_type(model_id, derived_model_type[1])
                for derived_model_type in constants.DERIVED_MODEL_TYPES
            ],
        )

    # Keep only related models for which results are available.
    model_tree_choices = [
        [
            related_id
            for related_id in related_ids
            if related_id in result_paths_per_model
        ]
        for related_ids in model_tree
    ]
    return [
        gr.Dropdown(
            choices=choices,
            label=f"{label} ({len(choices)})",
            interactive=bool(choices),
        )
        for choices, label in zip(model_tree_choices, labels)
    ]


async def load_base_models(model_id) -> list[str]:
    """Return the base models declared in the model card of ``model_id``.

    Reads the card-data attribute named by ``constants.BASE_MODEL_TYPE[1]``.
    A single (non-list) value is wrapped in a list.

    Returns:
        The declared base models. Returns an empty list when no card is
        available, or when the attribute is missing or empty.
    """
    card = await load_model_card(model_id)
    if not card:
        return []
    # Default to None: a card without the attribute must not raise, and an
    # unset value must not turn into [None].
    base_models = getattr(card.data, constants.BASE_MODEL_TYPE[1], None)
    if not base_models:
        return []
    if not isinstance(base_models, list):
        base_models = [base_models]
    return base_models


async def load_derived_models_by_type(model_id, derived_model_type) -> list[str]:
    """Return IDs of models derived from ``model_id`` by ``derived_model_type``.

    Queries ``list_models`` with the filter
    ``"base_model:<derived_model_type>:<model_id>"``. Each returned entry is
    expected to be indexable by ``"id"``.

    Returns:
        The matching model IDs, or an empty list if the query returns nothing.
    """
    models = await list_models(filtering=f"base_model:{derived_model_type}:{model_id}")
    if not models:
        return []
    return [model["id"] for model in models]