ms180 committed
Commit cb0fcd5
Parent: ec295c4

Upload 14 files

Dockerfile ADDED
@@ -0,0 +1,27 @@
+ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
+
+ ENV NUMBA_CACHE_DIR=/tmp
+
+ RUN apt update && apt install -y ffmpeg
+ RUN apt-get install -y curl git gcc libxml2-dev libxslt1-dev zlib1g-dev g++
+
+ RUN useradd -m -u 1000 user
+ USER user
+
+ WORKDIR /code
+
+ RUN chmod 777 /code
+
+ COPY . /code
+
+ RUN pip install -U pip;
+ RUN pip install wheel setuptools;
+ RUN pip install -r /code/requirements.txt
+
+ RUN git clone https://github.com/espnet/espnet.git
+ RUN pip install -U -e /code/espnet
+
+ EXPOSE 7860
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
+
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,268 @@
+ import glob
+ import os
+ import shutil
+ import sys
+ import re
+ import tempfile
+ import zipfile
+ from pathlib import Path
+
+ import gradio as gr
+
+ from finetune import finetune_model, baseline_model
+
+ from language import languages
+ from task import tasks
+ import matplotlib.pyplot as plt
+
+
+ os.environ['TEMP_DIR'] = tempfile.mkdtemp()
+
+ def load_markdown():
+     with open("intro.md", "r") as f:
+         return f.read()
+
+
+ def read_logs():
+     try:
+         with open("output.log", "r") as f:
+             return f.read()
+     except FileNotFoundError:
+         return None
+
+
+ def plot_loss_acc(temp_dir, log_every):
+     sys.stdout.flush()
+     lines = []
+     with open("output.log", "r") as f:
+         for line in f.readlines():
+             if re.match(r"^\[\d+\] - loss: \d+\.\d+ - acc: \d+\.\d+$", line):
+                 lines.append(line)
+
+     losses = []
+     acces = []
+     if len(lines) == 0:
+         return None, None
+
+     for line in lines:
+         _, loss, acc = line.split(" - ")
+         losses.append(float(loss.split(":")[1].strip()))
+         acces.append(float(acc.split(":")[1].strip()))
+
+     x = [i * log_every for i in range(1, len(losses) + 1)]
+
+     plt.plot(x, losses, label="loss")
+     plt.xlim(log_every // 2, x[-1] + log_every // 2)
+     plt.savefig(f"{temp_dir}/loss.png")
+     plt.clf()
+     plt.plot(x, acces, label="acc")
+     plt.xlim(log_every // 2, x[-1] + log_every // 2)
+     plt.savefig(f"{temp_dir}/acc.png")
+     plt.clf()
+     return f"{temp_dir}/acc.png", f"{temp_dir}/loss.png"
+
+
+ def upload_file(fileobj, temp_dir):
+     """
+     Upload a file and check the uploaded zip file.
+     """
+     # First check if the file is a zip file.
+     if not zipfile.is_zipfile(fileobj.name):
+         raise gr.Error("Please upload a zip file.")
+
+     # Then unzip the file.
+     shutil.unpack_archive(fileobj.name, temp_dir)
+
+     # Check the zip file contents.
+     if not os.path.exists(os.path.join(temp_dir, "text")):
+         raise gr.Error("Please upload a valid zip file.")
+
+     if not os.path.exists(os.path.join(temp_dir, "text_ctc")):
+         raise gr.Error("Please upload a valid zip file.")
+
+     if not os.path.exists(os.path.join(temp_dir, "audio")):
+         raise gr.Error("Please upload a valid zip file.")
+
+     # Check that all texts and audio files match.
+     audio_ids = []
+     with open(os.path.join(temp_dir, "text"), "r") as f:
+         for line in f.readlines():
+             audio_ids.append(line.split(maxsplit=1)[0])
+
+     with open(os.path.join(temp_dir, "text_ctc"), "r") as f:
+         ctc_audio_ids = []
+         for line in f.readlines():
+             ctc_audio_ids.append(line.split(maxsplit=1)[0])
+
+     if len(audio_ids) != len(ctc_audio_ids):
+         raise gr.Error(
+             f"Length of `text` ({len(audio_ids)}) and `text_ctc` ({len(ctc_audio_ids)}) is different."
+         )
+
+     if set(audio_ids) != set(ctc_audio_ids):
+         raise gr.Error("`text` and `text_ctc` have different audio ids.")
+
+     for audio_id in glob.glob(os.path.join(temp_dir, "audio", "*")):
+         if Path(audio_id).stem not in audio_ids:
+             raise gr.Error(f"Audio id {audio_id} is not in `text` or `text_ctc`.")
+
+     gr.Info("Successfully uploaded and validated the zip file.")
+
+     return [fileobj]
+
+
+ with gr.Blocks(title="OWSM-finetune") as demo:
+     tempdir_path = gr.State(os.environ['TEMP_DIR'])
+     gr.Markdown(
+         """# OWSM finetune demo!
+
+ Finetune `owsm_v3.1_ebf_base` with your own dataset!
+ Due to resource limitations, you can train for at most 50 epochs.
+
+ ## Upload dataset and define settings
+ """
+     )
+
+     # main contents
+     with gr.Row():
+         with gr.Column():
+             file_output = gr.File()
+             upload_button = gr.UploadButton("Click to Upload a File", file_count="single")
+             upload_button.upload(
+                 upload_file, [upload_button, tempdir_path], [file_output]
+             )
+
+         with gr.Column():
+             lang = gr.Dropdown(
+                 languages["espnet/owsm_v3.1_ebf_base"],
+                 label="Language",
+                 info="Choose language!",
+                 value="jpn",
+                 interactive=True,
+             )
+             task = gr.Dropdown(
+                 tasks["espnet/owsm_v3.1_ebf_base"],
+                 label="Task",
+                 info="Choose task!",
+                 value="asr",
+                 interactive=True,
+             )
+
+     gr.Markdown("## Set training settings")
+
+     with gr.Row():
+         with gr.Column():
+             log_every = gr.Number(value=10, label="log_every", interactive=True)
+             max_epoch = gr.Slider(1, 10, step=1, label="max_epoch", interactive=True)
+             scheduler = gr.Dropdown(
+                 ["warmuplr"], label="warmup", value="warmuplr", interactive=True
+             )
+             warmup_steps = gr.Number(
+                 value=100, label="warmup_steps", interactive=True
+             )
+
+         with gr.Column():
+             optimizer = gr.Dropdown(
+                 ["adam", "adamw", "sgd", "adadelta", "adagrad", "adamax", "asgd", "rmsprop"],
+                 label="optimizer",
+                 value="adam",
+                 interactive=True
+             )
+             learning_rate = gr.Number(
+                 value=1e-4, label="learning_rate", interactive=True
+             )
+             weight_decay = gr.Number(
+                 value=0.000001, label="weight_decay", interactive=True
+             )
+
+     gr.Markdown("## Logs and plots")
+
+     with gr.Row():
+         with gr.Column():
+             log_output = gr.Textbox(
+                 show_label=False,
+                 interactive=False,
+                 max_lines=23,
+                 lines=23,
+             )
+             demo.load(read_logs, None, log_output, every=2)
+
+         with gr.Column():
+             log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
+             log_loss = gr.Image(label="Loss", show_label=True, interactive=False)
+             demo.load(plot_loss_acc, [tempdir_path, log_every], [log_acc, log_loss], every=10)
+
+     with gr.Row():
+         with gr.Column():
+             ref_text = gr.Textbox(
+                 label="Reference text",
+                 show_label=True,
+                 interactive=False,
+                 max_lines=10,
+                 lines=10,
+             )
+         with gr.Column():
+             base_text = gr.Textbox(
+                 label="Baseline text",
+                 show_label=True,
+                 interactive=False,
+                 max_lines=10,
+                 lines=10,
+             )
+
+     with gr.Row():
+         with gr.Column():
+             hyp_text = gr.Textbox(
+                 label="Hypothesis text",
+                 show_label=True,
+                 interactive=False,
+                 max_lines=10,
+                 lines=10,
+             )
+         with gr.Column():
+             trained_model = gr.File(
+                 label="Trained model",
+                 interactive=False,
+             )
+
+     with gr.Row():
+         with gr.Column():
+             baseline_btn = gr.Button("Run Baseline", variant="secondary")
+             baseline_btn.click(
+                 baseline_model,
+                 [
+                     lang,
+                     task,
+                     tempdir_path,
+                 ],
+                 [ref_text, base_text]
+             )
+         with gr.Column():
+             finetune_btn = gr.Button("Finetune Model", variant="primary")
+             finetune_btn.click(
+                 finetune_model,
+                 [
+                     lang,
+                     task,
+                     tempdir_path,
+                     log_every,
+                     max_epoch,
+                     scheduler,
+                     warmup_steps,
+                     optimizer,
+                     learning_rate,
+                     weight_decay,
+                 ],
+                 [trained_model, hyp_text]
+             )
+
+     gr.Markdown(load_markdown())
+
+ if __name__ == "__main__":
+     try:
+         demo.queue().launch()
+     except Exception:
+         print("Unexpected error:", sys.exc_info()[0])
+         raise
+     finally:
+         shutil.rmtree(os.environ['TEMP_DIR'])
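
A note on how the plots above get their data: `CustomFinetuneModel` in `finetune.py` (below) prints a line such as `[50] - loss: 2.134 - acc: 0.512` every `log_every` steps, and `plot_loss_acc` picks those lines out of `output.log` with the regex shown above. A small self-contained sketch of that round trip, using made-up log lines:

```python
import re

# Made-up log lines in the format emitted by CustomFinetuneModel.forward().
sample_log = """[10] - loss: 3.412 - acc: 0.301
[20] - loss: 2.875 - acc: 0.356
some unrelated stdout line
[30] - loss: 2.541 - acc: 0.402
"""

pattern = re.compile(r"^\[\d+\] - loss: \d+\.\d+ - acc: \d+\.\d+$")
losses, accs = [], []
for line in sample_log.splitlines():
    if pattern.match(line):
        _, loss, acc = line.split(" - ")
        losses.append(float(loss.split(":")[1].strip()))
        accs.append(float(acc.split(":")[1].strip()))

print(losses)  # [3.412, 2.875, 2.541]
print(accs)    # [0.301, 0.356, 0.402]
```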
assets/owsm_ebf_v3.1_base/bpe.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d6327da127e870bcb8c737dceb3bd47ccbce63da74ddb094f64afe313d68c8c
+ size 1041297
assets/owsm_ebf_v3.1_base/config.yaml ADDED
The diff for this file is too large to render. See raw diff
 
assets/owsm_ebf_v3.1_base/owsm_finetune_base.yaml ADDED
@@ -0,0 +1,40 @@
+
+ seed: 2022
+ num_workers: 4
+ batch_type: numel
+ batch_bins: 1600000
+ accum_grad: 2
+ max_epoch: 10
+ patience: none
+ init: none
+ best_model_criterion:
+ - - valid
+   - acc
+   - max
+ keep_nbest_models: 3
+ use_amp: true
+
+ optim: adam
+ optim_conf:
+     lr: 0.0001
+     weight_decay: 0.000001
+ scheduler: warmuplr
+ scheduler_conf:
+     warmup_steps: 100
+
+ specaug: specaug
+ specaug_conf:
+     apply_time_warp: true
+     time_warp_window: 5
+     time_warp_mode: bicubic
+     apply_freq_mask: true
+     freq_mask_width_range:
+     - 0
+     - 27
+     num_freq_mask: 2
+     apply_time_mask: true
+     time_mask_width_ratio_range:
+     - 0.
+     - 0.05
+     num_time_mask: 5
+
assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99e5de1865e2c98308b41ce6f28b7f658bec7b274da60f37b219a99279d43f3a
+ size 404971245
assets/owsm_ebf_v3.1_base/tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
docker-compose.yaml ADDED
@@ -0,0 +1,5 @@
+ services:
+   python_310:
+     build: .
+     ports:
+       - "7860:7860"
exp/s2t_stats_raw_bpe50000/train/feats_stats.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ef4b5e465110edf32eec024cf2427eedd677f5733bb87d6b2131e6984a6e13f
+ size 1402
finetune.py ADDED
@@ -0,0 +1,290 @@
+ import glob
+ import sys
+ from pathlib import Path
+ import shutil
+
+ from espnet2.tasks.s2t import S2TTask
+ from espnet2.text.sentencepiece_tokenizer import SentencepiecesTokenizer
+ from espnet2.text.token_id_converter import TokenIDConverter
+ from espnet2.s2t.espnet_model import ESPnetS2TModel
+ from espnet2.bin.s2t_inference import Speech2Text
+ import espnetez as ez
+
+ import torch
+ import numpy as np
+ import logging
+ import gradio as gr
+ import librosa
+
+
+ class Logger:
+     # Tee stdout into output.log so the web UI can display training progress.
+     def __init__(self, filename):
+         self.terminal = sys.stdout
+         self.log = open(filename, "w")
+
+     def write(self, message):
+         self.terminal.write(message)
+         self.log.write(message)
+
+     def flush(self):
+         self.terminal.flush()
+         self.log.flush()
+
+     def isatty(self):
+         return False
+
+
+ sys.stdout = Logger("output.log")
+
+
+ def count_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ def get_dataset(data_path, data_info, test_count=10):
+     # load text data
+     data = {}
+     keys = []
+     with open(f"{data_path}/text", "r", encoding="utf-8") as f:
+         for line in f.readlines():
+             audio_id, text = line.split(maxsplit=1)
+             data[audio_id.strip()] = {"text": text.strip()}
+             keys.append(audio_id.strip())
+
+     # load text_ctc data
+     with open(f"{data_path}/text_ctc", "r", encoding="utf-8") as f:
+         for line in f.readlines():
+             audio_id, text = line.split(maxsplit=1)
+             data[audio_id.strip()]["text_ctc"] = text.strip()
+
+     # load audio paths
+     for audio_path in glob.glob(f"{data_path}/audio/*"):
+         audio_id = Path(audio_path).stem
+         data[audio_id]["audio_path"] = audio_path
+
+     # Convert to a list
+     data = [{
+         'id': audio_id,
+         'text': data[audio_id]['text'],
+         'text_ctc': data[audio_id]['text_ctc'],
+         'audio_path': data[audio_id]['audio_path'],
+     } for audio_id in keys]
+
+     return ez.dataset.ESPnetEZDataset(data[test_count:], data_info), ez.dataset.ESPnetEZDataset(data[:test_count], data_info), data[:test_count]
+
+
+ class CustomFinetuneModel(ESPnetS2TModel):
+     def __init__(self, model, log_every=500):
+         super().__init__(
+             vocab_size=model.vocab_size,
+             token_list=model.token_list,
+             frontend=model.frontend,
+             specaug=model.specaug,
+             normalize=model.normalize,
+             preencoder=model.preencoder,
+             encoder=model.encoder,
+             postencoder=model.postencoder,
+             decoder=model.decoder,
+             ctc=model.ctc,
+             ctc_weight=model.ctc_weight,
+             interctc_weight=model.interctc_weight,
+             ignore_id=model.ignore_id,
+             lsm_weight=0.0,
+             length_normalized_loss=False,
+             report_cer=False,
+             report_wer=False,
+             sym_space="<space>",
+             sym_blank="<blank>",
+             sym_sos="<sos>",
+             sym_eos="<eos>",
+             sym_sop="<sop>",  # start of prev
+             sym_na="<na>",  # not available
+             extract_feats_in_collect_stats=model.extract_feats_in_collect_stats,
+         )
+         self.iter_count = 0
+         self.log_every = log_every
+         self.log_stats = {
+             'loss': 0.0,
+             'acc': 0.0
+         }
+
+     def forward(self, *args, **kwargs):
+         out = super().forward(*args, **kwargs)
+         self.log_stats['loss'] += out[1]['loss'].item()
+         self.log_stats['acc'] += out[1]['acc'].item()
+
+         self.iter_count += 1
+         if self.iter_count % self.log_every == 0:
+             loss = self.log_stats['loss'] / self.log_every
+             acc = self.log_stats['acc'] / self.log_every
+             print(f"[{self.iter_count}] - loss: {loss:.3f} - acc: {acc:.3f}")
+             self.log_stats['loss'] = 0.0
+             self.log_stats['acc'] = 0.0
+
+         return out
+
+
+ def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, warmup_steps, optimizer, learning_rate, weight_decay):
+     """Main function for finetuning the model."""
+     print("Start loading dataset...")
+     if len(tempdir_path) == 0:
+         raise gr.Error("Please upload a zip file first.")
+
+     # define tokenizer
+     tokenizer = SentencepiecesTokenizer("assets/owsm_ebf_v3.1_base/bpe.model")
+     converter = TokenIDConverter("assets/owsm_ebf_v3.1_base/tokens.txt")
+
+     def tokenize(text):
+         return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))
+
+     data_info = {
+         "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
+         "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
+         "text_ctc": lambda d: tokenize(d["text_ctc"]),
+         "text_prev": lambda d: tokenize("<na>"),
+     }
+
+     # load dataset and define data_info
+     train_dataset, test_dataset, test_list = get_dataset(tempdir_path, data_info)
+     print("Loaded dataset.")
+     gr.Info("Loaded dataset.")
+
+     # load and update configuration
+     print("Setting up the training configuration...")
+     pretrain_config = ez.config.from_yaml(
+         "s2t",
+         "assets/owsm_ebf_v3.1_base/config.yaml",
+     )
+     finetune_config = ez.config.update_finetune_config(
+         "s2t", pretrain_config, "assets/owsm_ebf_v3.1_base/owsm_finetune_base.yaml"
+     )
+     finetune_config['max_epoch'] = max_epoch
+     finetune_config['optim'] = optimizer
+     finetune_config['optim_conf']['lr'] = learning_rate
+     finetune_config['optim_conf']['weight_decay'] = weight_decay
+     finetune_config['scheduler'] = scheduler
+     finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
+     finetune_config['multiple_iterator'] = False
+     finetune_config['num_iters_per_epoch'] = None
+
+     def build_model_fn(args):
+         model, _ = S2TTask.build_model_from_file(
+             "assets/owsm_ebf_v3.1_base/config.yaml",
+             "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
+             device="cuda" if torch.cuda.is_available() else "cpu",
+         )
+         model.train()
+         print(f'Trainable parameters: {count_parameters(model)}')
+         model = CustomFinetuneModel(model, log_every=log_every)
+         return model
+
+     trainer = ez.Trainer(
+         task='s2t',
+         train_config=finetune_config,
+         train_dataset=train_dataset,
+         valid_dataset=test_dataset,
+         build_model_fn=build_model_fn,  # provide the pre-trained model
+         data_info=data_info,
+         output_dir=f"{tempdir_path}/exp/finetune",
+         stats_dir=f"{tempdir_path}/exp/stats",
+         ngpu=1
+     )
+     gr.Info("Starting the collect-stats step.")
+     print("Start collect stats process...")
+     trainer.collect_stats()
+     gr.Info("Finished collect stats, starting training.")
+     print("Finished collect stats process. Start training.")
+     trainer.train()
+     gr.Info("Finished fine-tuning! Archiving experiment files...")
+     print("Finished fine-tuning.")
+     print("Start archiving experiment files...")
+     print("Creating `finetune.zip` from the following files:")
+     for f in glob.glob(f"{tempdir_path}/exp/finetune/*"):
+         print(f.replace(tempdir_path, ""))
+
+     shutil.make_archive(f"{tempdir_path}/finetune", 'zip', f"{tempdir_path}/exp/finetune")
+     gr.Info("Finished generating the result zip file!")
+     print("Finished archiving experiment files.")
+
+     print("Start generating test results...")
+     gr.Info("Generating output for the test set!")
+
+     del trainer
+     model = Speech2Text(
+         "assets/owsm_ebf_v3.1_base/config.yaml",
+         "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         token_type="bpe",
+         bpemodel="assets/owsm_ebf_v3.1_base/bpe.model",
+         beam_size=5,
+         ctc_weight=0.3,
+         lang_sym=f"<{lang}>",
+         task_sym=f"<{task}>",
+     )
+     model.s2t_model.eval()
+     d = torch.load(f"{tempdir_path}/exp/finetune/valid.acc.ave.pth")
+     model.s2t_model.load_state_dict(d)
+
+     hyp = ""
+     with open(f"{tempdir_path}/hyp.txt", "w") as f_hyp:
+         for i in range(len(test_list)):
+             data = test_list[i]
+             out = model(librosa.load(data['audio_path'], sr=16000)[0])[0][3]
+             f_hyp.write(out + '\n')
+             hyp += out + '\n'
+
+     return [f"{tempdir_path}/finetune.zip", f"{tempdir_path}/ref.txt", f"{tempdir_path}/base.txt", f"{tempdir_path}/hyp.txt"], hyp
+
+
+ def baseline_model(lang, task, tempdir_path):
+     print("Start loading dataset...")
+     if len(tempdir_path) == 0:
+         raise gr.Error("Please upload a zip file first.")
+
+     # define tokenizer
+     tokenizer = SentencepiecesTokenizer("assets/owsm_ebf_v3.1_base/bpe.model")
+     converter = TokenIDConverter("assets/owsm_ebf_v3.1_base/tokens.txt")
+
+     def tokenize(text):
+         return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))
+
+     data_info = {
+         "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
+         "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
+         "text_ctc": lambda d: tokenize(d["text_ctc"]),
+         "text_prev": lambda d: tokenize("<na>"),
+     }
+
+     # load dataset and define data_info
+     train_dataset, test_dataset, test_list = get_dataset(tempdir_path, data_info)
+     print("Loaded dataset.")
+     gr.Info("Loaded dataset.")
+
+     print("Loading pretrained model...")
+     gr.Info("Loading pretrained model...")
+
+     model = Speech2Text(
+         "assets/owsm_ebf_v3.1_base/config.yaml",
+         "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         token_type="bpe",
+         bpemodel="assets/owsm_ebf_v3.1_base/bpe.model",
+         beam_size=5,
+         ctc_weight=0.3,
+         lang_sym=f"<{lang}>",
+         task_sym=f"<{task}>",
+     )
+     model.s2t_model.eval()
+
+     base = ""
+     ref = ""
+     with open(f"{tempdir_path}/base.txt", "w") as f_base, open(f"{tempdir_path}/ref.txt", "w") as f_ref:
+         for i in range(len(test_list)):
+             data = test_list[i]
+             f_ref.write(data['text'] + '\n')
+             out = model(librosa.load(data['audio_path'], sr=16000)[0])[0][3]
+             f_base.write(out + '\n')
+             ref += data['text'] + '\n'
+             base += out + '\n'
+
+     return ref, base
intro.md ADDED
@@ -0,0 +1,29 @@
+
+ Please create the zip file with the following structure:
+
+ ```
+ train.zip
+ - audio
+     - audio_id_1.wav
+     - audio_id_2.wav
+     - ...
+
+ - text
+ - text_ctc
+ ```
+
+ `text_ctc` should contain the transcriptions in the following format:
+
+ ```
+ audio_id_1 transcription
+ audio_id_2 transcription
+ ...
+ ```
+
+ `text` should contain the target text output in the following format:
+
+ ```
+ audio_id_1 transcription or translated text
+ audio_id_2 transcription or translated text
+ ...
+ ```
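
Not part of the upload itself, but for reference: a minimal sketch of how a matching `train.zip` could be assembled with the Python standard library. The utterance ids and texts are made up, and the silent wav files are only stand-ins for real recordings.

```python
import shutil
import wave
from pathlib import Path

# Hypothetical utterance table: audio id -> (CTC transcription, target text).
utterances = {
    "audio_id_1": ("hello world", "hello world"),
    "audio_id_2": ("good morning", "good morning"),
}

root = Path("train")  # staging directory that will be zipped
(root / "audio").mkdir(parents=True, exist_ok=True)

with open(root / "text", "w", encoding="utf-8") as f_text, \
     open(root / "text_ctc", "w", encoding="utf-8") as f_ctc:
    for audio_id, (ctc_text, text) in utterances.items():
        # One wav per audio id; here one second of 16 kHz silence as a placeholder.
        with wave.open(str(root / "audio" / f"{audio_id}.wav"), "wb") as w:
            w.setnchannels(1)
            w.setsampwidth(2)
            w.setframerate(16000)
            w.writeframes(b"\x00\x00" * 16000)
        f_text.write(f"{audio_id} {text}\n")
        f_ctc.write(f"{audio_id} {ctc_text}\n")

# Produces train.zip with `audio/`, `text`, and `text_ctc` at the archive root,
# which is the layout upload_file() in app.py expects after unpacking.
shutil.make_archive("train", "zip", root)
```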
language.py ADDED
@@ -0,0 +1,155 @@
+ languages = {
+     "espnet/owsm_v3.1_ebf_base": [
+         "abk",
+         "afr",
+         "amh",
+         "ara",
+         "asm",
+         "ast",
+         "aze",
+         "bak",
+         "bas",
+         "bel",
+         "ben",
+         "bos",
+         "bre",
+         "bul",
+         "cat",
+         "ceb",
+         "ces",
+         "chv",
+         "ckb",
+         "cmn",
+         "cnh",
+         "cym",
+         "dan",
+         "deu",
+         "dgd",
+         "div",
+         "ell",
+         "eng",
+         "epo",
+         "est",
+         "eus",
+         "fas",
+         "fil",
+         "fin",
+         "fra",
+         "frr",
+         "ful",
+         "gle",
+         "glg",
+         "grn",
+         "guj",
+         "hat",
+         "hau",
+         "heb",
+         "hin",
+         "hrv",
+         "hsb",
+         "hun",
+         "hye",
+         "ibo",
+         "ina",
+         "ind",
+         "isl",
+         "ita",
+         "jav",
+         "jpn",
+         "kab",
+         "kam",
+         "kan",
+         "kat",
+         "kaz",
+         "kea",
+         "khm",
+         "kin",
+         "kir",
+         "kmr",
+         "kor",
+         "lao",
+         "lav",
+         "lga",
+         "lin",
+         "lit",
+         "ltz",
+         "lug",
+         "luo",
+         "mal",
+         "mar",
+         "mas",
+         "mdf",
+         "mhr",
+         "mkd",
+         "mlt",
+         "mon",
+         "mri",
+         "mrj",
+         "mya",
+         "myv",
+         "nan",
+         "nep",
+         "nld",
+         "nno",
+         "nob",
+         "npi",
+         "nso",
+         "nya",
+         "oci",
+         "ori",
+         "orm",
+         "ory",
+         "pan",
+         "pol",
+         "por",
+         "pus",
+         "quy",
+         "roh",
+         "ron",
+         "rus",
+         "sah",
+         "sat",
+         "sin",
+         "skr",
+         "slk",
+         "slv",
+         "sna",
+         "snd",
+         "som",
+         "sot",
+         "spa",
+         "srd",
+         "srp",
+         "sun",
+         "swa",
+         "swe",
+         "swh",
+         "tam",
+         "tat",
+         "tel",
+         "tgk",
+         "tgl",
+         "tha",
+         "tig",
+         "tir",
+         "tok",
+         "tpi",
+         "tsn",
+         "tuk",
+         "tur",
+         "twi",
+         "uig",
+         "ukr",
+         "umb",
+         "urd",
+         "uzb",
+         "vie",
+         "vot",
+         "wol",
+         "xho",
+         "yor",
+         "yue",
+         "zho",
+         "zul",
+     ]
+ }
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ gradio
+ torchaudio
task.py ADDED
@@ -0,0 +1,30 @@
+ tasks = {
+     "espnet/owsm_v3.1_ebf_base": [
+         "asr",
+         "st_ara",
+         "st_cat",
+         "st_ces",
+         "st_cym",
+         "st_deu",
+         "st_eng",
+         "st_est",
+         "st_fas",
+         "st_fra",
+         "st_ind",
+         "st_ita",
+         "st_jpn",
+         "st_lav",
+         "st_mon",
+         "st_nld",
+         "st_por",
+         "st_ron",
+         "st_rus",
+         "st_slv",
+         "st_spa",
+         "st_swe",
+         "st_tam",
+         "st_tur",
+         "st_vie",
+         "st_zho",
+     ]
+ }