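# NOTE: the module docstring below is a summary inferred from the UI text and
# imports in this file, not taken from upstream documentation.
"""Gradio demo for finetuning `espnet/owsm_v3.1_ebf_base` (OWSM) on a
user-uploaded dataset, with live training logs and loss/accuracy plots."""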
import glob
import os
import re
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt

from finetune import finetune_model, log
from language import languages
from task import tasks


def load_markdown():
    with open("intro.md", "r") as f:
        return f.read()


def read_logs(temp_dir):
    # Stream the training log written via finetune.log into the UI textbox.
    if not os.path.exists(f"{temp_dir}/output.log"):
        return "Log file not found."
    try:
        with open(f"{temp_dir}/output.log", "r") as f:
            return f.read()
    except Exception:
        return None


def plot_loss_acc(temp_dir, log_every):
    """Parse loss/accuracy entries from the training log and plot them."""
    sys.stdout.flush()
    lines = []
    if not os.path.exists(f"{temp_dir}/output.log"):
        return None, None
    with open(f"{temp_dir}/output.log", "r") as f:
        for line in f:
            # Only keep lines of the form "[step] - loss: x.xx - acc: y.yy".
            if re.match(r"^\[\d+\] - loss: \d+\.\d+ - acc: \d+\.\d+$", line):
                lines.append(line)
    if len(lines) == 0:
        return None, None
    losses = []
    accs = []
    for line in lines:
        _, loss, acc = line.split(" - ")
        losses.append(float(loss.split(":")[1].strip()))
        accs.append(float(acc.split(":")[1].strip()))
    x = [i * log_every for i in range(1, len(losses) + 1)]
    plt.plot(x, losses, label="loss")
    plt.xlim(log_every // 2, x[-1] + log_every // 2)
    plt.savefig(f"{temp_dir}/loss.png")
    plt.clf()
    plt.plot(x, accs, label="acc")
    plt.xlim(log_every // 2, x[-1] + log_every // 2)
    plt.savefig(f"{temp_dir}/acc.png")
    plt.clf()
    return f"{temp_dir}/acc.png", f"{temp_dir}/loss.png"


def upload_file(fileobj, temp_dir):
    """Upload a file and validate the uploaded zip archive."""
    # First check that the uploaded file is a zip archive.
    if not zipfile.is_zipfile(fileobj.name):
        log(temp_dir, "Please upload a zip file.")
        raise gr.Error("Please upload a zip file.")
    # Then unzip it into the session's temporary directory.
    log(temp_dir, "Unzipping file...")
    shutil.unpack_archive(fileobj.name, temp_dir)
    # The archive must contain `text`, `text_ctc`, and `audio`.
    for required in ("text", "text_ctc", "audio"):
        if not os.path.exists(os.path.join(temp_dir, required)):
            log(temp_dir, "Please upload a valid zip file.")
            raise gr.Error("Please upload a valid zip file.")
    # Check that all texts and audio files match.
    log(temp_dir, "Checking if all texts and audio match...")
    audio_ids = []
    with open(os.path.join(temp_dir, "text"), "r") as f:
        for line in f:
            audio_ids.append(line.split(maxsplit=1)[0])
    ctc_audio_ids = []
    with open(os.path.join(temp_dir, "text_ctc"), "r") as f:
        for line in f:
            ctc_audio_ids.append(line.split(maxsplit=1)[0])
    if len(audio_ids) != len(ctc_audio_ids):
        raise gr.Error(
            f"Length of `text` ({len(audio_ids)}) and `text_ctc` ({len(ctc_audio_ids)}) is different."
        )
    if set(audio_ids) != set(ctc_audio_ids):
        log(temp_dir, "`text` and `text_ctc` have different audio ids.")
        raise gr.Error("`text` and `text_ctc` have different audio ids.")
    for audio_id in glob.glob(os.path.join(temp_dir, "audio", "*")):
        if Path(audio_id).stem not in audio_ids:
            raise gr.Error(f"Audio id {audio_id} is not in `text` or `text_ctc`.")
    log(temp_dir, "Successfully uploaded and validated zip file.")
    gr.Info("Successfully uploaded and validated zip file.")
    return [fileobj]
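

# Expected archive layout, inferred from the checks in upload_file (the exact
# semantics of each file are defined by the finetune module, not here):
#   text      one "<utt_id> <transcript>" entry per line
#   text_ctc  one "<utt_id> <ctc transcript>" entry per line
#   audio/    one audio file per utt_id; filename stems must appear in `text`
# A minimal way to build such an archive from a prepared directory might be:
#   cd my_dataset && zip -r ../dataset.zip text text_ctc audio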


def delete_tmp_dir(tmp_dir):
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
        print(f"Deleted temporary directory: {tmp_dir}")
    else:
        print("Temporary directory already deleted")


def create_tmp_dir():
    tmp_dir = tempfile.mkdtemp()
    print(f"Created temporary directory: {tmp_dir}")
    return tmp_dir
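

# The app below keeps one temporary working directory per browser session in
# gr.State; delete_callback=delete_tmp_dir and time_to_live=600 let Gradio
# remove it once the session state expires.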
with gr.Blocks(title="OWSM-finetune") as demo:
    tempdir_path = gr.State(
        create_tmp_dir, delete_callback=delete_tmp_dir, time_to_live=600
    )
    gr.Markdown(
        """# OWSM finetune demo!
Finetune `owsm_v3.1_ebf_base` with your own dataset!
Due to resource limitations, you can train at most 5 epochs.

## Upload dataset and define settings
"""
    )
    # main contents
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            upload_button = gr.UploadButton("Click to Upload a File", file_count="single")
            upload_button.upload(
                upload_file, [upload_button, tempdir_path], [file_output]
            )
        with gr.Column():
            lang = gr.Dropdown(
                languages["espnet/owsm_v3.1_ebf_base"],
                label="Language",
                info="Choose language!",
                value="jpn",
                interactive=True,
            )
            task = gr.Dropdown(
                tasks["espnet/owsm_v3.1_ebf_base"],
                label="Task",
                info="Choose task!",
                value="asr",
                interactive=True,
            )
gr.Markdown("## Set training settings") | |
with gr.Row(): | |
with gr.Column(): | |
log_every = gr.Number(value=10, label="log_every", interactive=True) | |
max_epoch = gr.Slider(1, 5, step=1, label="max_epoch", interactive=True) | |
scheduler = gr.Dropdown( | |
["warmuplr"], label="warmup", value="warmuplr", interactive=True | |
) | |
warmup_steps = gr.Number( | |
value=100, label="warmup_steps", interactive=True | |
) | |
with gr.Column(): | |
optimizer = gr.Dropdown( | |
["adam", "adamw", "sgd", "adadelta", "adagrad", "adamax", "asgd", "rmsprop"], | |
label="optimizer", | |
value="adam", | |
interactive=True | |
) | |
learning_rate = gr.Number( | |
value=1e-4, label="learning_rate", interactive=True | |
) | |
weight_decay = gr.Number( | |
value=0.000001, label="weight_decay", interactive=True | |
) | |
gr.Markdown("## Logs and plots") | |
with gr.Row(): | |
with gr.Column(): | |
log_output = gr.Textbox( | |
show_label=False, | |
interactive=False, | |
max_lines=23, | |
lines=23, | |
) | |
demo.load(read_logs, [tempdir_path], log_output, every=2) | |
with gr.Column(): | |
log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False) | |
log_loss = gr.Image(label="Loss", show_label=True, interactive=False) | |
demo.load(plot_loss_acc, [tempdir_path, log_every], [log_acc, log_loss], every=10) | |
    with gr.Row():
        with gr.Column():
            ref_text = gr.Textbox(
                label="Reference text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            base_text = gr.Textbox(
                label="Baseline text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
    with gr.Row():
        with gr.Column():
            hyp_text = gr.Textbox(
                label="Hypothesis text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            trained_model = gr.File(
                label="Trained model",
                interactive=False,
            )
    with gr.Row():
        finetune_btn = gr.Button("Finetune Model", variant="primary")
        # finetune_model is expected to return values matching the outputs list:
        # (trained_model, ref_text, base_text, hyp_text).
        finetune_btn.click(
            finetune_model,
            [
                lang,
                task,
                tempdir_path,
                log_every,
                max_epoch,
                scheduler,
                warmup_steps,
                optimizer,
                learning_rate,
                weight_decay,
            ],
            [trained_model, ref_text, base_text, hyp_text],
        )
    gr.Markdown(load_markdown())


if __name__ == "__main__":
    try:
        demo.queue().launch()
    except:
        print("Unexpected error:", sys.exc_info()[0])
        raise
    finally:
        shutil.rmtree(os.environ["TEMP_DIR"])
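
# To run this Space locally (assuming its requirements, e.g. gradio, espnet and
# matplotlib, are installed, and that the TEMP_DIR environment variable used in
# the cleanup above is set), the conventional entry point would be:
#   python app.py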