import gc
import logging
from typing import List

import torch
from h2o_wave import Q

from llm_studio.app_utils.sections.chat import chat_tab
from llm_studio.app_utils.sections.chat_update import chat_update
from llm_studio.app_utils.sections.common import delete_dialog
from llm_studio.app_utils.sections.dataset import (
    dataset_delete_current_datasets,
    dataset_delete_single,
    dataset_display,
    dataset_edit,
    dataset_import,
    dataset_import_uploaded_file,
    dataset_list,
    dataset_list_delete,
    dataset_merge,
    dataset_newexperiment,
)
from llm_studio.app_utils.sections.experiment import (
    experiment_delete,
    experiment_display,
    experiment_download_logs,
    experiment_download_model,
    experiment_download_predictions,
    experiment_list,
    experiment_push_to_huggingface_dialog,
    experiment_rename_ui_workflow,
    experiment_run,
    experiment_start,
    experiment_stop,
)
from llm_studio.app_utils.sections.home import home
from llm_studio.app_utils.sections.project import (
    current_experiment_compare,
    current_experiment_list_compare,
    current_experiment_list_delete,
    current_experiment_list_stop,
    experiment_rename_action_workflow,
    list_current_experiments,
)
from llm_studio.app_utils.sections.settings import settings
from llm_studio.app_utils.setting_utils import (
    load_default_user_settings,
    load_user_settings_and_secrets,
    save_user_settings_and_secrets,
)
from llm_studio.app_utils.utils import add_model_type
from llm_studio.app_utils.wave_utils import report_error, wave_utils_handle_error

logger = logging.getLogger(__name__)


def _release_chat_resources(q: Q) -> None:
    """Drop the chat cfg/model/tokenizer stored on the client and free memory."""
    for key in (
        "experiment/display/chat/cfg",
        "experiment/display/chat/model",
        "experiment/display/chat/tokenizer",
    ):
        if key in q.client:
            del q.client[key]

    # Collect first: objects that are only reclaimed by the garbage collector
    # must be gone before the CUDA caching allocator is asked to release its
    # unused blocks, otherwise their memory stays cached until the next call.
    gc.collect()
    torch.cuda.empty_cache()


async def handle(q: Q) -> None:
    """Route an incoming request to the matching section handler.

    Routing is driven by ``q.args.__wave_submission_name__``; a few branches
    instead test a truthy ``q.args`` entry or ``q.client["nav/active"]``.
    Branches are evaluated in order and at most one of them runs.

    Unless the request belongs to the chat widget (chatbot message or
    clear-history), any chat cfg/model/tokenizer held on the client is
    released before routing.

    Any exception raised while routing is logged and forwarded to
    ``wave_utils_handle_error``; it is not re-raised.

    Args:
        q: The request/query object passed in by the application.
    """
    # Uncomment to trace incoming requests:
    # logger.info(f"args: {q.args}")
    # logger.info(f"events: {q.events}")

    submission = q.args.__wave_submission_name__

    if submission not in (
        "experiment/display/chat/chatbot",
        "experiment/display/chat/clear_history",
    ):
        _release_chat_resources(q)

    try:
        if submission == "home":
            await home(q)
        elif submission == "settings":
            await settings(q)
        elif submission == "save_settings":
            logger.info("Saving user settings")
            await save_user_settings_and_secrets(q)
            await settings(q)
        elif submission == "load_settings":
            load_user_settings_and_secrets(q)
            await settings(q)
        elif submission == "restore_default_settings":
            load_default_user_settings(q)
            await settings(q)

        elif submission == "report_error":
            await report_error(q)

        elif submission == "dataset/import":
            await dataset_import(q, step=1)
        elif submission == "dataset/list":
            await dataset_list(q)
        elif submission in ("dataset/list/delete/abort", "dataset/list/abort"):
            q.page["dataset/list"].items[0].table.multiple = False
            await dataset_list(q, reset=True)
        elif submission == "dataset/list/delete":
            await dataset_list_delete(q)
        elif submission == "dataset/delete/single":
            # The client stores the table row index; map it to the dataset id.
            row_idx = q.client["dataset/delete/single/id"]
            dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[row_idx]
            await dataset_delete_single(q, int(dataset_id))
        elif submission == "dataset/delete/dialog/single":
            row_idx = int(q.args["dataset/delete/dialog/single"])
            q.client["dataset/delete/single/id"] = row_idx
            dataset_name = q.client["dataset/list/df_datasets"]["name"].iloc[row_idx]

            if q.client["delete_dialogs"]:
                await delete_dialog(
                    q, [dataset_name], "dataset/delete/single", "dataset"
                )
            else:
                dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[row_idx]
                await dataset_delete_single(q, int(dataset_id))
        elif q.args["dataset/delete/dialog"]:
            names = list(
                q.client["dataset/list/df_datasets"]["name"].iloc[
                    list(map(int, q.client["dataset/list/table"]))
                ]
            )

            # Nothing selected, nothing to delete.
            if not names:
                return

            if q.client["delete_dialogs"]:
                await delete_dialog(q, names, "dataset/delete", "dataset")
            else:
                await dataset_delete_current_datasets(q)
        elif submission == "dataset/delete":
            await dataset_delete_current_datasets(q)
        elif submission == "dataset/edit":
            if q.client["dataset/list/df_datasets"] is not None:
                row_idx = int(q.args["dataset/edit"])
                dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[row_idx]
                await dataset_edit(q, int(dataset_id))
        elif submission == "dataset/newexperiment":
            if q.client["dataset/list/df_datasets"] is not None:
                row_idx = int(q.args["dataset/newexperiment"])
                dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[row_idx]
                await dataset_newexperiment(q, int(dataset_id))
        elif submission == "dataset/newexperiment/from_current":
            row_idx = q.client["dataset/display/id"]
            dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[row_idx]
            # Cast to a plain int, as every other dataset call site here does.
            await dataset_newexperiment(q, int(dataset_id))

        elif submission == "dataset/list/table":
            q.client["dataset/display/id"] = int(q.args["dataset/list/table"][0])
            await dataset_display(q)
        elif (
            submission
            in (
                "dataset/display/visualization",
                "dataset/display/data",
                "dataset/display/statistics",
            )
            or q.args["dataset/display/summary"]
        ):
            await dataset_display(q)

        elif submission in (
            "experiment/start/run",
            "experiment/start/error/proceed",
        ):
            # add model type to cfg file name here
            q.client["experiment/start/cfg_file"] = add_model_type(
                q.client["experiment/start/cfg_file"],
                q.client["experiment/start/cfg_sub"],
            )
            q.client.delete_cards.add("experiment/start")
            await experiment_run(q, pre="experiment/start")
            q.client["experiment/list/mode"] = "train"

        elif submission in (
            "experiment/start_experiment",
            "experiment/list/new",
        ):
            if q.client["experiment/list/df_experiments"] is not None:
                selected_idx = int(q.args["experiment/list/new"])
                experiment_id = q.client["experiment/list/df_experiments"]["id"].iloc[
                    selected_idx
                ]

                q.client["experiment/start/cfg_category"] = "experiment"
                q.client["experiment/start/cfg_file"] = "experiment"
                q.client["experiment/start/cfg_experiment"] = str(experiment_id)

                await experiment_start(q)
        elif submission == "experiment/start":
            q.client["experiment/start/cfg_category"] = None
            q.client["experiment/start/cfg_file"] = None
            datasets_df = q.client.app_db.get_datasets_df()
            if datasets_df.shape[0] == 0:
                # Without any dataset, send the user to the import wizard.
                info = "Import dataset before you create an experiment. "
                await dataset_import(q, step=1, info=info)
            else:
                await experiment_start(q)

        elif submission == "experiment/display/download_logs":
            await experiment_download_logs(q)
        elif submission == "experiment/display/download_predictions":
            await experiment_download_predictions(q)

        elif submission == "experiment/list":
            q.client["experiment/list/mode"] = "train"
            await experiment_list(q)
        elif submission == "experiment/list/current":
            await list_current_experiments(q)
        elif submission == "experiment/list/current/noreset":
            await list_current_experiments(q, reset=False)
        elif submission == "experiment/list/refresh":
            await experiment_list(q)
        elif submission == "experiment/list/abort":
            await list_current_experiments(q)
        elif submission == "experiment/list/stop":
            await current_experiment_list_stop(q)
        elif submission == "experiment/list/delete":
            await current_experiment_list_delete(q)
        elif submission == "experiment/list/rename":
            await experiment_rename_ui_workflow(q)
        elif submission == "experiment/list/compare":
            await current_experiment_list_compare(q)
        elif submission in ("experiment/stop", "experiment/list/stop/table"):
            if q.args["experiment/list/stop/table"]:
                # Single experiment chosen via its table row.
                idx = int(q.args["experiment/list/stop/table"])
                selected_id = q.client["experiment/list/df_experiments"]["id"].iloc[idx]
                experiment_ids = [selected_id]
            else:
                # Multi-selection stored on the client.
                selected_idxs = q.client["experiment/list/table"]
                experiment_ids = list(
                    q.client["experiment/list/df_experiments"]["id"].iloc[
                        list(map(int, selected_idxs))
                    ]
                )

            await experiment_stop(q, experiment_ids)
            await list_current_experiments(q)
        elif submission == "experiment/list/delete/table/dialog":
            idx = int(q.args["experiment/list/delete/table/dialog"])
            names = [q.client["experiment/list/df_experiments"]["name"].iloc[idx]]
            selected_id = q.client["experiment/list/df_experiments"]["id"].iloc[idx]
            # Remembered so the confirmation request can find the experiment.
            q.client["experiment/delete/single/id"] = selected_id
            if q.client["delete_dialogs"]:
                await delete_dialog(
                    q, names, "experiment/list/delete/table", "experiment"
                )
            else:
                await experiment_delete_all_artifacts(q, [selected_id])

        elif submission == "experiment/delete/dialog":
            selected_idxs = q.client["experiment/list/table"]
            exp_df = q.client["experiment/list/df_experiments"]
            names = list(exp_df["name"].iloc[list(map(int, selected_idxs))])

            # Nothing selected, nothing to delete.
            if not names:
                return

            if q.client["delete_dialogs"]:
                await delete_dialog(q, names, "experiment/delete", "experiment")
            else:
                experiment_ids = list(exp_df["id"].iloc[list(map(int, selected_idxs))])
                await experiment_delete_all_artifacts(q, experiment_ids)

        elif submission in ("experiment/delete", "experiment/list/delete/table"):
            if q.args["experiment/list/delete/table"]:
                selected_id = q.client["experiment/delete/single/id"]
                experiment_ids = [selected_id]
            else:
                selected_idxs = q.client["experiment/list/table"]
                exp_df = q.client["experiment/list/df_experiments"]
                experiment_ids = list(exp_df["id"].iloc[list(map(int, selected_idxs))])

            await experiment_delete_all_artifacts(q, experiment_ids)

        elif submission == "experiment/rename/action":
            await experiment_rename_action_workflow(q)

        elif submission in (
            "experiment/compare",
            "experiment/compare/charts",
            "experiment/compare/config",
        ):
            await current_experiment_compare(q)
        elif submission == "experiment/compare/diff_toggle":
            q.client["experiment/compare/diff_toggle"] = q.args[
                "experiment/compare/diff_toggle"
            ]
            await current_experiment_compare(q)

        elif submission == "experiment/list/table":
            q.client["experiment/display/id"] = int(q.args["experiment/list/table"][0])
            q.client["experiment/display/logs_path"] = None
            q.client["experiment/display/preds_path"] = None
            q.client["experiment/display/tab"] = None
            await experiment_display(q)

        elif submission in (
            "experiment/display/refresh",
            "experiment/display/charts",
            "experiment/display/summary",
            "experiment/display/train_data_insights",
            "experiment/display/validation_prediction_insights",
        ):
            await experiment_display(q)
        elif submission == "experiment/display/push_to_huggingface":
            await experiment_push_to_huggingface_dialog(q)
        elif submission == "experiment/display/download_model":
            await experiment_download_model(q)
        elif submission == "experiment/display/push_to_huggingface_submit":
            await experiment_push_to_huggingface_dialog(q)
        elif submission in (
            "experiment/display/config",
            "experiment/display/logs",
            "experiment/display/chat",
        ):
            await experiment_display(q)

        elif submission == "experiment/display/chat/chatbot":
            await chat_update(q)
        elif submission == "experiment/display/chat/clear_history":
            await chat_tab(q, load_model=False)

        elif submission == "dataset/import/local_upload":
            await dataset_import_uploaded_file(q)
        elif submission == "dataset/import/local_path_list":
            await dataset_import(q, step=1)
        elif submission == "dataset/import/2":
            await dataset_import(q, step=2)
        elif submission == "dataset/import/3":
            await dataset_import(q, step=3)
        elif submission == "dataset/import/3/edit":
            await dataset_import(q, step=3, edit=True)
        elif submission == "dataset/import/4":
            await dataset_import(q, step=4)
        elif submission == "dataset/import/4/edit":
            await dataset_import(q, step=4, edit=True)
        elif submission == "dataset/import/6":
            await dataset_import(q, step=6)
        elif submission == "dataset/import/source" and not q.args["dataset/list"]:
            await dataset_import(q, step=1)
        elif submission == "dataset/merge":
            await dataset_merge(q, step=1)
        elif submission == "dataset/merge/action":
            await dataset_merge(q, step=2)

        elif submission == "dataset/import/cfg_file":
            await dataset_import(q, step=3)

        # leave at the end of dataset import routing,
        # would also be triggered if user clicks on
        # a continue button in the dataset import wizard
        elif submission == "dataset/import/cfg/train_dataframe":
            await dataset_import(q, step=3)

        elif submission == "experiment/start/cfg_file":
            q.client["experiment/start/cfg_file"] = q.args["experiment/start/cfg_file"]
            await experiment_start(q)
        elif submission == "experiment/start/dataset":
            await experiment_start(q)

        # Fallback: no submission name matched, but the experiment start page
        # is the active navigation entry.
        elif q.client["nav/active"] == "experiment/start":
            await experiment_start(q)

    except Exception as unknown_exception:
        # Top-level boundary: log with traceback and show the error to the user.
        logger.exception("Unknown exception")
        await wave_utils_handle_error(
            q,
            error=unknown_exception,
        )


async def experiment_delete_all_artifacts(q: Q, experiment_ids: List[int]) -> None:
    """Stop the given experiments, delete them, and refresh the experiment list.

    Args:
        q: The request/query object passed in by the application.
        experiment_ids: Ids of the experiments to stop and delete.
    """
    await experiment_stop(q, experiment_ids)
    await experiment_delete(q, experiment_ids)
    await list_current_experiments(q)