import logging
import os
import shutil
from tempfile import NamedTemporaryFile

from bokeh.resources import Resources as BokehResources
from h2o_wave import Q, ui

from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.db import Database, Dataset
from llm_studio.app_utils.default_datasets import (
    prepare_default_dataset_causal_language_modeling,
    prepare_default_dataset_dpo_modeling,
)
from llm_studio.app_utils.sections.common import interface
from llm_studio.app_utils.setting_utils import load_user_settings_and_secrets
from llm_studio.app_utils.utils import (
    get_data_dir,
    get_database_dir,
    get_download_dir,
    get_output_dir,
    get_user_db_path,
    get_user_name,
)
from llm_studio.src.utils.config_utils import load_config_py, save_config_yaml

logger = logging.getLogger(__name__)


async def import_default_data(q: Q):
    """Imports the default datasets if they are not in the database yet."""

    try:
        if q.client.app_db.get_dataset(1) is None:
            logger.info("Downloading default dataset...")
            q.page["meta"].dialog = ui.dialog(
                title="Creating default datasets",
                blocking=True,
                items=[ui.progress(label="Please be patient...")],
            )
            await q.page.save()

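            # Prepare the default datasets and add them to the user's database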
            dataset = prepare_oasst(q)
            q.client.app_db.add_dataset(dataset)
            dataset = prepare_dpo(q)
            q.client.app_db.add_dataset(dataset)

    except Exception as e:
        q.client.app_db._session.rollback()
        logger.warning(f"Could not download default dataset: {e}")


def prepare_oasst(q: Q) -> Dataset:
    path = f"{get_data_dir(q)}/oasst"
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
    df = prepare_default_dataset_causal_language_modeling(path)
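    # Load the default experiment config, point it at the freshly prepared
    # parquet file, and map the OASST columns onto the config fields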
    cfg = load_config_py(
        config_path=os.path.join("llm_studio/python_configs", default_cfg.cfg_file),
        config_name="ConfigProblemBase",
    )
    cfg.dataset.train_dataframe = os.path.join(path, "train_full.pq")
    cfg.dataset.prompt_column = ("instruction",)
    cfg.dataset.answer_column = "output"
    cfg.dataset.parent_id_column = "None"
    cfg_path = os.path.join(path, f"{default_cfg.cfg_file}.yaml")
    save_config_yaml(cfg_path, cfg)
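    # Register the dataset under the fixed id 1 that import_default_data checks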
    dataset = Dataset(
        id=1,
        name="oasst",
        path=path,
        config_file=cfg_path,
        train_rows=df.shape[0],
    )
    return dataset


def prepare_dpo(q: Q) -> Dataset:
    path = f"{get_data_dir(q)}/dpo"
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
    train_df = prepare_default_dataset_dpo_modeling()
    train_df.to_parquet(os.path.join(path, "train.pq"), index=False)

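    # The DPO config classes are only needed here, so they are imported locally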
    from llm_studio.python_configs.text_dpo_modeling_config import ConfigDPODataset
    from llm_studio.python_configs.text_dpo_modeling_config import (
        ConfigProblemBase as ConfigProblemBaseDPO,
    )

    cfg: ConfigProblemBaseDPO = ConfigProblemBaseDPO(
        dataset=ConfigDPODataset(
            train_dataframe=os.path.join(path, "train.pq"),
            system_column="system",
            prompt_column=("question",),
            answer_column="chosen",
            rejected_answer_column="rejected",
        ),
    )

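    # Persist the config next to the data and register the dataset under fixed id 2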
    cfg_path = os.path.join(path, "text_dpo_modeling_config.yaml")
    save_config_yaml(cfg_path, cfg)
    dataset = Dataset(
        id=2,
        name="dpo",
        path=path,
        config_file=cfg_path,
        train_rows=train_df.shape[0],
    )
    return dataset


async def initialize_client(q: Q) -> None:
    """Initialize the client."""

    logger.info(f"Initializing client {q.client.client_initialized}")

    if not q.client.client_initialized:
        q.client.delete_cards = set()
        q.client.delete_cards.add("init_app")

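        # Create the per-user working directories and open the user's database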
        os.makedirs(get_data_dir(q), exist_ok=True)
        os.makedirs(get_database_dir(q), exist_ok=True)
        os.makedirs(get_output_dir(q), exist_ok=True)
        os.makedirs(get_download_dir(q), exist_ok=True)

        db_path = get_user_db_path(q)

        q.client.app_db = Database(db_path)

        logger.info(f"User name: {get_user_name(q)}")

        q.client.client_initialized = True

        q.client["mode_curr"] = "full"
        load_user_settings_and_secrets(q)
        await interface(q)

        await import_default_data(q)
        q.args.__wave_submission_name__ = default_cfg.start_page


async def initialize_app(q: Q) -> None:
    """
    Initialize the app.

    This function is called once when the app is started and stores values in q.app.
    """

    logger.info("Initializing app ...")

    icons_pth = "llm_studio/app_utils/static"
    (q.app["icon_path"],) = await q.site.upload([f"{icons_pth}/icon.png"])

    script_sources = []

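    # Bundle all Bokeh javascript into a single temporary file and upload it to
    # the Wave server, so the app can serve Bokeh itself instead of using a CDN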
    with NamedTemporaryFile(mode="w", suffix=".min.js") as f:
        for js_raw in BokehResources(mode="inline").js_raw:
            f.write(js_raw)
            f.write("\n")
        # flush buffered writes so the upload below sees the complete file
        f.flush()

        (url,) = await q.site.upload([f.name])
        script_sources.append(url)

q.app["script_sources"] = script_sources |
|
q.app["initialized"] = True |
|
q.app.version = default_cfg.version |
|
q.app.name = default_cfg.name |
|
q.app.heap_mode = default_cfg.heap_mode |
|
|
|
logger.info("Initializing app ... done") |
|
|