"""Gradio demo for fine-tuning `owsm_v3.1_ebf_base` on a user-uploaded dataset.

The page lets the user upload a zip archive, choose language/task and training
settings, start fine-tuning, and watch the log file and loss/accuracy plots.
"""

import glob
import os
import re
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt

from finetune import finetune_model
from language import languages
from task import tasks

# Working directory for this process; also exported through the environment.
os.environ['TEMP_DIR'] = tempfile.mkdtemp()

# File that `read_logs` and `plot_loss_acc` poll.
LOG_PATH = "output.log"

# A metrics line looks like "[12] - loss: 0.123 - acc: 0.456".
_METRIC_LINE = re.compile(r"^\[\d+\] - loss: (\d+\.\d+) - acc: (\d+\.\d+)$")

# Entries that must exist at the top level of an unpacked dataset archive.
_REQUIRED_ENTRIES = ("text", "text_ctc", "audio")


def load_markdown():
    """Return the contents of `intro.md` as a string."""
    with open("intro.md", "r") as f:
        return f.read()


def read_logs():
    """Return the current contents of the log file, or None if unreadable.

    This is polled periodically by the UI, so a missing or half-written log
    is expected and must not raise.
    """
    try:
        with open(LOG_PATH, "r") as f:
            return f.read()
    except (OSError, UnicodeDecodeError):
        return None


def _save_curve(x, y, label, log_every, path):
    """Plot `y` against `x` on a fresh figure and write it to `path`."""
    # An explicit figure avoids sharing pyplot's global "current figure"
    # between callbacks that may run concurrently.
    fig, ax = plt.subplots()
    try:
        ax.plot(x, y, label=label)
        ax.set_xlim(log_every // 2, x[-1] + log_every // 2)
        fig.savefig(path)
    finally:
        plt.close(fig)


def plot_loss_acc(temp_dir, log_every):
    """Render loss and accuracy curves from the metric lines in the log file.

    Args:
        temp_dir: Directory in which `acc.png` and `loss.png` are written.
        log_every: Step interval between two logged metric lines; used to
            scale the x axis.

    Returns:
        A tuple `(accuracy_png_path, loss_png_path)`, or `(None, None)` when
        there is nothing to plot yet.
    """
    # NOTE(review): presumably stdout is what ends up in the log file, so it
    # is flushed before reading -- confirm against the training code.
    sys.stdout.flush()

    # Without a positive interval every point would land on x == 0.
    if not log_every:
        return None, None

    losses = []
    acces = []
    try:
        with open(LOG_PATH, "r") as f:
            for line in f:
                matched = _METRIC_LINE.match(line)
                if matched:
                    losses.append(float(matched.group(1)))
                    acces.append(float(matched.group(2)))
    except (OSError, UnicodeDecodeError):
        # The log does not exist (or is not readable) before training starts.
        return None, None

    if not losses:
        return None, None

    x = [i * log_every for i in range(1, len(losses) + 1)]
    loss_path = f"{temp_dir}/loss.png"
    acc_path = f"{temp_dir}/acc.png"
    _save_curve(x, losses, "loss", log_every, loss_path)
    _save_curve(x, acces, "acc", log_every, acc_path)
    return acc_path, loss_path


def _read_ids(path):
    """Return the first whitespace-separated token of each non-blank line."""
    with open(path, "r") as f:
        return [line.split(maxsplit=1)[0] for line in f if line.strip()]


def upload_file(fileobj, temp_dir):
    """Unpack an uploaded zip archive into `temp_dir` and validate its layout.

    The archive must contain `text`, `text_ctc` and `audio` at its top level.
    `text` and `text_ctc` must list the same ids (first token of each line),
    and every file under `audio` must be named after one of those ids.

    Args:
        fileobj: Uploaded file object; its `name` attribute is the file path.
        temp_dir: Directory the archive is extracted into.

    Returns:
        A one-element list containing `fileobj`.

    Raises:
        gr.Error: If the upload is not a zip file or fails validation.
    """
    # First check if a file is a zip file.
    if not zipfile.is_zipfile(fileobj.name):
        raise gr.Error("Please upload a zip file.")

    # Then unzip file. The format is given explicitly because the default is
    # guessed from the file extension, which an upload need not carry.
    shutil.unpack_archive(fileobj.name, temp_dir, format="zip")

    # check zip file
    for entry in _REQUIRED_ENTRIES:
        if not os.path.exists(os.path.join(temp_dir, entry)):
            raise gr.Error("Please upload a valid zip file.")

    # check if all texts and audio matches
    audio_ids = _read_ids(os.path.join(temp_dir, "text"))
    ctc_audio_ids = _read_ids(os.path.join(temp_dir, "text_ctc"))

    if len(audio_ids) != len(ctc_audio_ids):
        raise gr.Error(
            f"Length of `text` ({len(audio_ids)}) and `text_ctc` ({len(ctc_audio_ids)}) is different."
        )

    known_ids = set(audio_ids)
    if known_ids != set(ctc_audio_ids):
        raise gr.Error("`text` and `text_ctc` have different audio ids.")

    for audio_id in glob.glob(os.path.join(temp_dir, "audio", "*")):
        if Path(audio_id).stem not in known_ids:
            raise gr.Error(f"Audio id {audio_id} is not in `text` or `text_ctc`.")

    gr.Info("Successfully uploaded and validated zip file.")

    return [fileobj]


with gr.Blocks(title="OWSM-finetune") as demo:
    tempdir_path = gr.State(os.environ['TEMP_DIR'])
    gr.Markdown(
        """# OWSM finetune demo!
Finetune `owsm_v3.1_ebf_base` with your own dataset!
Due to resource limitation, you can only train 10 epochs on maximum.
## Upload dataset and define settings
"""
    )

    # main contents
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            upload_button = gr.UploadButton("Click to Upload a File", file_count="single")
            upload_button.upload(
                upload_file, [upload_button, tempdir_path], [file_output]
            )

        with gr.Column():
            lang = gr.Dropdown(
                languages["espnet/owsm_v3.1_ebf_base"],
                label="Language",
                info="Choose language!",
                value="jpn",
                interactive=True,
            )
            task = gr.Dropdown(
                tasks["espnet/owsm_v3.1_ebf_base"],
                label="Task",
                info="Choose task!",
                value="asr",
                interactive=True,
            )

    gr.Markdown("## Set training settings")

    with gr.Row():
        with gr.Column():
            log_every = gr.Number(value=10, label="log_every", interactive=True)
            max_epoch = gr.Slider(1, 10, step=1, label="max_epoch", interactive=True)
            scheduler = gr.Dropdown(
                ["warmuplr"], label="warmup", value="warmuplr", interactive=True
            )
            warmup_steps = gr.Number(
                value=100, label="warmup_steps", interactive=True
            )

        with gr.Column():
            optimizer = gr.Dropdown(
                ["adam", "adamw", "sgd", "adadelta", "adagrad", "adamax", "asgd", "rmsprop"],
                label="optimizer",
                value="adam",
                interactive=True
            )
            learning_rate = gr.Number(
                value=1e-4, label="learning_rate", interactive=True
            )
            weight_decay = gr.Number(
                value=0.000001, label="weight_decay", interactive=True
            )

    gr.Markdown("## Logs and plots")

    with gr.Row():
        with gr.Column():
            log_output = gr.Textbox(
                show_label=False,
                interactive=False,
                max_lines=23,
                lines=23,
            )
            # Refresh the log view every 2 seconds.
            demo.load(read_logs, None, log_output, every=2)

        with gr.Column():
            log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
            log_loss = gr.Image(label="Loss", show_label=True, interactive=False)
            # Redraw the curves every 10 seconds.
            demo.load(plot_loss_acc, [tempdir_path, log_every], [log_acc, log_loss], every=10)

    with gr.Row():
        with gr.Column():
            ref_text = gr.Textbox(
                label="Reference text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            base_text = gr.Textbox(
                label="Baseline text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )

    with gr.Row():
        with gr.Column():
            hyp_text = gr.Textbox(
                label="Hypothesis text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            trained_model = gr.File(
                label="Trained model",
                interactive=False,
            )

    with gr.Row():
        finetune_btn = gr.Button("Finetune Model", variant="primary")
        finetune_btn.click(
            finetune_model,
            [
                lang,
                task,
                tempdir_path,
                log_every,
                max_epoch,
                scheduler,
                warmup_steps,
                optimizer,
                learning_rate,
                weight_decay,
            ],
            [trained_model, hyp_text]
        )

    gr.Markdown(load_markdown())


if __name__ == "__main__":
    try:
        demo.queue().launch()
    except Exception:
        print("Unexpected error:", sys.exc_info()[0])
        raise
    finally:
        # Best-effort cleanup: a failure here must not hide the real error.
        shutil.rmtree(os.environ['TEMP_DIR'], ignore_errors=True)