ms180 commited on
Commit
498c6f5
1 Parent(s): e1c0f34
Files changed (1) hide show
  1. app.py +35 -12
app.py CHANGED
@@ -9,23 +9,23 @@ from pathlib import Path
9
 
10
  import gradio as gr
11
 
12
- from finetune import finetune_model
13
 
14
  from language import languages
15
  from task import tasks
16
  import matplotlib.pyplot as plt
17
 
18
 
19
- os.environ['TEMP_DIR'] = tempfile.mkdtemp()
20
-
21
  def load_markdown():
22
  with open("intro.md", "r") as f:
23
  return f.read()
24
 
25
 
26
- def read_logs():
 
 
27
  try:
28
- with open(f"output.log", "r") as f:
29
  return f.read()
30
  except:
31
  return None
@@ -34,7 +34,10 @@ def read_logs():
34
  def plot_loss_acc(temp_dir, log_every):
35
  sys.stdout.flush()
36
  lines = []
37
- with open("output.log", "r") as f:
 
 
 
38
  for line in f.readlines():
39
  if re.match(r"^\[\d+\] - loss: \d+\.\d+ - acc: \d+\.\d+$", line):
40
  lines.append(line)
@@ -68,22 +71,28 @@ def upload_file(fileobj, temp_dir):
68
  """
69
  # First check if a file is a zip file.
70
  if not zipfile.is_zipfile(fileobj.name):
 
71
  raise gr.Error("Please upload a zip file.")
72
 
73
  # Then unzip file
 
74
  shutil.unpack_archive(fileobj.name, temp_dir)
75
 
76
  # check zip file
77
  if not os.path.exists(os.path.join(temp_dir, "text")):
 
78
  raise gr.Error("Please upload a valid zip file.")
79
 
80
  if not os.path.exists(os.path.join(temp_dir, "text_ctc")):
 
81
  raise gr.Error("Please upload a valid zip file.")
82
 
83
  if not os.path.exists(os.path.join(temp_dir, "audio")):
 
84
  raise gr.Error("Please upload a valid zip file.")
85
 
86
  # check if all texts and audio matches
 
87
  audio_ids = []
88
  with open(os.path.join(temp_dir, "text"), "r") as f:
89
  for line in f.readlines():
@@ -100,25 +109,39 @@ def upload_file(fileobj, temp_dir):
100
  )
101
 
102
  if set(audio_ids) != set(ctc_audio_ids):
 
103
  raise gr.Error(f"`text` and `text_ctc` have different audio ids.")
104
 
105
  for audio_id in glob.glob(os.path.join(temp_dir, "audio", "*")):
106
  if not Path(audio_id).stem in audio_ids:
107
  raise gr.Error(f"Audio id {audio_id} is not in `text` or `text_ctc`.")
108
 
 
109
  gr.Info("Successfully uploaded and validated zip file.")
110
 
111
  return [fileobj]
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  with gr.Blocks(title="OWSM-finetune") as demo:
115
- tempdir_path = gr.State(os.environ['TEMP_DIR'])
116
  gr.Markdown(
117
  """# OWSM finetune demo!
118
-
119
  Finetune `owsm_v3.1_ebf_base` with your own dataset!
120
  Due to resource limitation, you can only train 10 epochs on maximum.
121
-
122
  ## Upload dataset and define settings
123
  """
124
  )
@@ -153,7 +176,7 @@ Due to resource limitation, you can only train 10 epochs on maximum.
153
  with gr.Row():
154
  with gr.Column():
155
  log_every = gr.Number(value=10, label="log_every", interactive=True)
156
- max_epoch = gr.Slider(1, 10, step=1, label="max_epoch", interactive=True)
157
  scheduler = gr.Dropdown(
158
  ["warmuplr"], label="warmup", value="warmuplr", interactive=True
159
  )
@@ -185,7 +208,7 @@ Due to resource limitation, you can only train 10 epochs on maximum.
185
  max_lines=23,
186
  lines=23,
187
  )
188
- demo.load(read_logs, None, log_output, every=2)
189
 
190
  with gr.Column():
191
  log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
@@ -241,7 +264,7 @@ Due to resource limitation, you can only train 10 epochs on maximum.
241
  learning_rate,
242
  weight_decay,
243
  ],
244
- [trained_model, hyp_text]
245
  )
246
 
247
  gr.Markdown(load_markdown())
 
9
 
10
  import gradio as gr
11
 
12
+ from finetune import finetune_model, log
13
 
14
  from language import languages
15
  from task import tasks
16
  import matplotlib.pyplot as plt
17
 
18
 
 
 
19
  def load_markdown():
20
  with open("intro.md", "r") as f:
21
  return f.read()
22
 
23
 
24
+ def read_logs(temp_dir):
25
+ if not os.path.exists(f"{temp_dir}/output.log"):
26
+ return "Log file not found."
27
  try:
28
+ with open(f"{temp_dir}/output.log", "r") as f:
29
  return f.read()
30
  except:
31
  return None
 
34
  def plot_loss_acc(temp_dir, log_every):
35
  sys.stdout.flush()
36
  lines = []
37
+ if not os.path.exists(f"{temp_dir}/output.log"):
38
+ return None, None
39
+
40
+ with open(f"{temp_dir}/output.log", "r") as f:
41
  for line in f.readlines():
42
  if re.match(r"^\[\d+\] - loss: \d+\.\d+ - acc: \d+\.\d+$", line):
43
  lines.append(line)
 
71
  """
72
  # First check if a file is a zip file.
73
  if not zipfile.is_zipfile(fileobj.name):
74
+ log(temp_dir, "Please upload a zip file.")
75
  raise gr.Error("Please upload a zip file.")
76
 
77
  # Then unzip file
78
+ log(temp_dir, "Unzipping file...")
79
  shutil.unpack_archive(fileobj.name, temp_dir)
80
 
81
  # check zip file
82
  if not os.path.exists(os.path.join(temp_dir, "text")):
83
+ log(temp_dir, "Please upload a valid zip file.")
84
  raise gr.Error("Please upload a valid zip file.")
85
 
86
  if not os.path.exists(os.path.join(temp_dir, "text_ctc")):
87
+ log(temp_dir, "Please upload a valid zip file.")
88
  raise gr.Error("Please upload a valid zip file.")
89
 
90
  if not os.path.exists(os.path.join(temp_dir, "audio")):
91
+ log(temp_dir, "Please upload a valid zip file.")
92
  raise gr.Error("Please upload a valid zip file.")
93
 
94
  # check if all texts and audio matches
95
+ log(temp_dir, "Checking if all texts and audio matches...")
96
  audio_ids = []
97
  with open(os.path.join(temp_dir, "text"), "r") as f:
98
  for line in f.readlines():
 
109
  )
110
 
111
  if set(audio_ids) != set(ctc_audio_ids):
112
+ log(temp_dir, f"`text` and `text_ctc` have different audio ids.")
113
  raise gr.Error(f"`text` and `text_ctc` have different audio ids.")
114
 
115
  for audio_id in glob.glob(os.path.join(temp_dir, "audio", "*")):
116
  if not Path(audio_id).stem in audio_ids:
117
  raise gr.Error(f"Audio id {audio_id} is not in `text` or `text_ctc`.")
118
 
119
+ log(temp_dir, "Successfully uploaded and validated zip file.")
120
  gr.Info("Successfully uploaded and validated zip file.")
121
 
122
  return [fileobj]
123
 
124
 
125
+ def delete_tmp_dir(tmp_dir):
126
+ if os.path.exists(tmp_dir):
127
+ shutil.rmtree(tmp_dir)
128
+ print(f"Deleted temporary directory: {tmp_dir}")
129
+ else:
130
+ print("Temporary directory already deleted")
131
+
132
+
133
+ def create_tmp_dir():
134
+ tmp_dir = tempfile.mkdtemp()
135
+ print(f"Created temporary directory: {tmp_dir}")
136
+ return tmp_dir
137
+
138
+
139
  with gr.Blocks(title="OWSM-finetune") as demo:
140
+ tempdir_path=gr.State(create_tmp_dir, delete_callback=delete_tmp_dir, time_to_live=600)
141
  gr.Markdown(
142
  """# OWSM finetune demo!
 
143
  Finetune `owsm_v3.1_ebf_base` with your own dataset!
144
  Due to resource limitation, you can only train 10 epochs on maximum.
 
145
  ## Upload dataset and define settings
146
  """
147
  )
 
176
  with gr.Row():
177
  with gr.Column():
178
  log_every = gr.Number(value=10, label="log_every", interactive=True)
179
+ max_epoch = gr.Slider(1, 30, step=1, label="max_epoch", interactive=True)
180
  scheduler = gr.Dropdown(
181
  ["warmuplr"], label="warmup", value="warmuplr", interactive=True
182
  )
 
208
  max_lines=23,
209
  lines=23,
210
  )
211
+ demo.load(read_logs, [tempdir_path], log_output, every=2)
212
 
213
  with gr.Column():
214
  log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
 
264
  learning_rate,
265
  weight_decay,
266
  ],
267
+ [trained_model, ref_text, base_text, hyp_text]
268
  )
269
 
270
  gr.Markdown(load_markdown())