ms180 committed
Commit 1efbfe3
Parent: 9520224

Update finetune.py

Files changed (1): finetune.py (+290 -290)
finetune.py CHANGED
@@ -1,290 +1,290 @@
import glob
import sys
from pathlib import Path
import shutil

from espnet2.tasks.s2t import S2TTask
from espnet2.text.sentencepiece_tokenizer import SentencepiecesTokenizer
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.s2t.espnet_model import ESPnetS2TModel
from espnet2.bin.s2t_inference import Speech2Text
import espnetez as ez

import torch
import numpy as np
import logging
import gradio as gr
import librosa


class Logger:
    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False


sys.stdout = Logger("output.log")


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_dataset(data_path, data_info, test_count=10):
    # load data
    data = {}
    keys = []
    with open(f"{data_path}/text", "r", encoding="utf-8") as f:
        for line in f.readlines():
            audio_id, text = line.split(maxsplit=1)
            data[audio_id.strip()] = {"text": text.strip()}
            keys.append(audio_id.strip())

    # load text_ctc data
    with open(f"{data_path}/text_ctc", "r", encoding="utf-8") as f:
        for line in f.readlines():
            audio_id, text = line.split(maxsplit=1)
            data[audio_id.strip()]["text_ctc"] = text.strip()

    # load audio path
    for audio_path in glob.glob(f"{data_path}/audio/*"):
        audio_id = Path(audio_path).stem
        data[audio_id]["audio_path"] = audio_path

    # Convert to list
    data = [{
        'id': audio_id,
        'text': data[audio_id]['text'],
        'text_ctc': data[audio_id]['text_ctc'],
        'audio_path': data[audio_id]['audio_path'],
    } for audio_id in keys]

    return ez.dataset.ESPnetEZDataset(data[test_count:], data_info), ez.dataset.ESPnetEZDataset(data[:test_count], data_info), data[:test_count]


class CustomFinetuneModel(ESPnetS2TModel):
    def __init__(self, model, log_every=500):
        super().__init__(
            vocab_size=model.vocab_size,
            token_list=model.token_list,
            frontend=model.frontend,
            specaug=model.specaug,
            normalize=model.normalize,
            preencoder=model.preencoder,
            encoder=model.encoder,
            postencoder=model.postencoder,
            decoder=model.decoder,
            ctc=model.ctc,
            ctc_weight=model.ctc_weight,
            interctc_weight=model.interctc_weight,
            ignore_id=model.ignore_id,
            lsm_weight=0.0,
            length_normalized_loss=False,
            report_cer=False,
            report_wer=False,
            sym_space="<space>",
            sym_blank="<blank>",
            sym_sos="<sos>",
            sym_eos="<eos>",
            sym_sop="<sop>",  # start of prev
            sym_na="<na>",  # not available
            extract_feats_in_collect_stats=model.extract_feats_in_collect_stats,
        )
        self.iter_count = 0
        self.log_every = log_every
        self.log_stats = {
            'loss': 0.0,
            'acc': 0.0
        }

    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs)
        self.log_stats['loss'] += out[1]['loss'].item()
        self.log_stats['acc'] += out[1]['acc'].item()

        self.iter_count += 1
        if self.iter_count % self.log_every == 0:
            loss = self.log_stats['loss'] / self.log_every
            acc = self.log_stats['acc'] / self.log_every
            print(f"[{self.iter_count}] - loss: {loss:.3f} - acc: {acc:.3f}")
            self.log_stats['loss'] = 0.0
            self.log_stats['acc'] = 0.0

        return out


def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, warmup_steps, optimizer, learning_rate, weight_decay):
    """Main function for finetuning the model."""
    print("Start loading dataset...")
    if len(tempdir_path) == 0:
        raise gr.Error("Please upload a zip file first.")

    # define tokenizer
    tokenizer = SentencepiecesTokenizer("assets/owsm_ebf_v3.1_base/bpe.model")
    converter = TokenIDConverter("assets/owsm_ebf_v3.1_base/tokens.txt")

    def tokenize(text):
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        "text_prev": lambda d: tokenize("<na>"),
    }

    # load dataset and define data_info
    train_dataset, test_dataset, test_list = get_dataset(tempdir_path, data_info)
    print("Loaded dataset.")
    gr.Info("Loaded dataset.")

    # load and update configuration
    print("Setting up the training configuration...")
    pretrain_config = ez.config.from_yaml(
        "s2t",
        "assets/owsm_ebf_v3.1_base/config.yaml",
    )
    finetune_config = ez.config.update_finetune_config(
        "s2t", pretrain_config, "assets/owsm_ebf_v3.1_base/owsm_finetune_base.yaml"
    )
    finetune_config['max_epoch'] = max_epoch
    finetune_config['optim'] = optimizer
    finetune_config['optim_conf']['lr'] = learning_rate
    finetune_config['optim_conf']['weight_decay'] = weight_decay
    finetune_config['scheduler'] = scheduler
    finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
    finetune_config['multiple_iterator'] = False
    finetune_config['num_iters_per_epoch'] = None

    def build_model_fn(args):
        model, _ = S2TTask.build_model_from_file(
            "assets/owsm_ebf_v3.1_base/config.yaml",
            "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        model.train()
        print(f'Trainable parameters: {count_parameters(model)}')
        model = CustomFinetuneModel(model, log_every=log_every)
        return model

    trainer = ez.Trainer(
        task='s2t',
        train_config=finetune_config,
        train_dataset=train_dataset,
        valid_dataset=test_dataset,
        build_model_fn=build_model_fn,  # provide the pre-trained model
        data_info=data_info,
        output_dir=f"{tempdir_path}/exp/finetune",
        stats_dir=f"{tempdir_path}/exp/stats",
        ngpu=1
    )
    gr.Info("start collect stats")
    print("Start collect stats process...")
    trainer.collect_stats()
    gr.Info("Finished collect stats, starting training.")
    print("Finished collect stats process. Start training.")
    trainer.train()
    gr.Info("Finished Fine-tuning! Archiving experiment files...")
    print("Finished fine-tuning.")
    print("Start archiving experiment files...")
    print("Create zip file for the following files into `finetune.zip`:")
    for f in glob.glob(f"{tempdir_path}/exp/finetune/*"):
        print(f.replace(tempdir_path, ""))

-    shutil.make_archive(f"{tempdir_path}/finetune", 'zip', f"{tempdir_path}/exp/finetune")
+    shutil.make_archive(f"{tempdir_path}/finetune", 'zip', f"{tempdir_path}/exp")
    gr.Info("Finished generating result file in zip!")
    print("Finished archiving experiment files.")

    print("Start generating test result...")
    gr.Info("Start generating output for test set!")

    del trainer
    model = Speech2Text(
        "assets/owsm_ebf_v3.1_base/config.yaml",
        "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
        device="cuda" if torch.cuda.is_available() else "cpu",
        token_type="bpe",
        bpemodel="assets/owsm_ebf_v3.1_base/bpe.model",
        beam_size=5,
        ctc_weight=0.3,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()
    d = torch.load(f"{tempdir_path}/exp/finetune/valid.acc.ave.pth")
    model.s2t_model.load_state_dict(d)

    hyp = ""
    with open(f"{tempdir_path}/hyp.txt", "w") as f_hyp:
        for i in range(len(test_list)):
            data = test_list[i]
            out = model(librosa.load(data['audio_path'], sr=16000)[0])[0][3]
            f_hyp.write(out + '\n')
            hyp += out + '\n'

    return [f"{tempdir_path}/finetune.zip", f"{tempdir_path}/ref.txt", f"{tempdir_path}/base.txt", f"{tempdir_path}/hyp.txt"], hyp


def baseline_model(lang, task, tempdir_path):
    print("Start loading dataset...")
    if len(tempdir_path) == 0:
        raise gr.Error("Please upload a zip file first.")

    # define tokenizer
    tokenizer = SentencepiecesTokenizer("assets/owsm_ebf_v3.1_base/bpe.model")
    converter = TokenIDConverter("assets/owsm_ebf_v3.1_base/tokens.txt")

    def tokenize(text):
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        "text_prev": lambda d: tokenize("<na>"),
    }

    # load dataset and define data_info
    train_dataset, test_dataset, test_list = get_dataset(tempdir_path, data_info)
    print("Loaded dataset.")
    gr.Info("Loaded dataset.")

    print("Loading pretrained model...")
    gr.Info("Loading pretrained model...")

    model = Speech2Text(
        "assets/owsm_ebf_v3.1_base/config.yaml",
        "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
        device="cuda" if torch.cuda.is_available() else "cpu",
        token_type="bpe",
        bpemodel="assets/owsm_ebf_v3.1_base/bpe.model",
        beam_size=5,
        ctc_weight=0.3,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()

    base = ""
    ref = ""
    with open(f"{tempdir_path}/base.txt", "w") as f_base, open(f"{tempdir_path}/ref.txt", "w") as f_ref:
        for i in range(len(test_list)):
            data = test_list[i]
            f_ref.write(data['text'] + '\n')
            out = model(librosa.load(data['audio_path'], sr=16000)[0])[0][3]
            f_base.write(out + '\n')
            ref += data['text'] + '\n'
            base += out + '\n'

    return ref, base
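The one functional change in this commit is the root_dir argument of the shutil.make_archive call: the archive root moves from {tempdir_path}/exp/finetune up to {tempdir_path}/exp. Since make_archive(base_name, format, root_dir) zips the contents of root_dir, finetune.zip now carries a top-level finetune/ folder (and presumably stats/ as well, since trainer.collect_stats() writes to {tempdir_path}/exp/stats before archiving) instead of placing the checkpoint files at the zip root. Below is a minimal standard-library sketch of the difference; the paths are throwaway ones made up for illustration, not the app's real upload directory, and the printed listings are indicative.

# Sketch: how root_dir changes the layout of the resulting zip (hypothetical paths).
import os
import shutil
import tempfile
import zipfile

tmp = tempfile.mkdtemp()
os.makedirs(os.path.join(tmp, "exp", "finetune"))
open(os.path.join(tmp, "exp", "finetune", "valid.acc.ave.pth"), "w").close()

# Old behaviour: root_dir = exp/finetune, so files land at the zip root.
shutil.make_archive(os.path.join(tmp, "old"), "zip", os.path.join(tmp, "exp", "finetune"))
# New behaviour: root_dir = exp, so files land under finetune/.
shutil.make_archive(os.path.join(tmp, "new"), "zip", os.path.join(tmp, "exp"))

print(zipfile.ZipFile(os.path.join(tmp, "old.zip")).namelist())
# e.g. ['valid.acc.ave.pth']
print(zipfile.ZipFile(os.path.join(tmp, "new.zip")).namelist())
# e.g. ['finetune/', 'finetune/valid.acc.ave.pth']

Anything downstream that unpacked finetune.zip expecting valid.acc.ave.pth at the top level would now need to look under finetune/ instead.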