File size: 1,382 Bytes
5ba903e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#!/usr/bin/env python
from argparse import ArgumentDefaultsHelpFormatter
from collections import namedtuple
from functools import partialmethod
import json

from tqdm import tqdm
from eval import main


# Optional toggle: uncomment to globally disable tqdm progress bars (patches
# tqdm's constructor so every bar is created with disable=True).
# tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

# Grid search over the language-model decoding weights "alpha" and "beta"
# stored in ./language_model/attrs.json. For every (alpha, beta) pair the attrs
# file is rewritten, eval.main is run, and the WER/CER read back from its
# results file are appended to grid.csv. The original attrs.json contents are
# restored when the search finishes or fails.

Args = namedtuple("Args", "model_id dataset config split log_outputs chunk_length_s stride_length_s device")
args = Args("./", "NbAiLab/NPSC", "16K_mp3_bokmaal", "test", True, None, None, 0)

ATTRS_PATH = "./language_model/attrs.json"
GRID_PATH = "grid.csv"
# Candidate values tried for both alpha and beta (full Cartesian product).
GRID_VALUES = (0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 1, 1.5, 2, 3)


def results_path(eval_args):
    """Return the name of the results file for ``eval_args``.

    Built as "<dataset with '/' replaced by '_'>_<config>_<split>_eval_results.txt";
    for the default ``args`` this is
    "NbAiLab_NPSC_16K_mp3_bokmaal_test_eval_results.txt".
    NOTE(review): assumes eval.main names its output file this way -- confirm.
    """
    dataset_id = "_".join(eval_args.dataset.split("/") + [eval_args.config, eval_args.split])
    return f"{dataset_id}_eval_results.txt"


def parse_metrics(text):
    """Return ``(wer, cer)`` parsed from the results file text.

    The first line must be "<label>: <WER>" and the second "<label>: <CER>"
    (e.g. "WER: 0.1"). Only the text after the first ':' is converted, so the
    label length and surrounding whitespace do not matter.

    Raises:
        IndexError: if the text has fewer than two lines.
        ValueError: if a value is missing or not a number.
    """
    lines = text.strip().split("\n")
    wer = float(lines[0].partition(":")[2])
    cer = float(lines[1].partition(":")[2])
    return wer, cer


def write_lm_params(path, base_attrs, alpha, beta):
    """Write ``base_attrs`` as JSON to ``path`` with "alpha"/"beta" overridden.

    ``base_attrs`` itself is not modified.
    """
    attrs = dict(base_attrs, alpha=alpha, beta=beta)
    with open(path, "w") as attrs_file:
        json.dump(attrs, attrs_file)


def run_grid_search(eval_args, values=GRID_VALUES, attrs_path=ATTRS_PATH, grid_path=GRID_PATH):
    """Evaluate every (alpha, beta) in ``values`` x ``values``; log WER/CER to CSV.

    Args:
        eval_args: arguments passed unchanged to ``eval.main``.
        values: candidate values used for both alpha and beta.
        attrs_path: JSON file holding the decoding weights; it is rewritten
            before each evaluation and restored byte-for-byte afterwards.
        grid_path: CSV output with header "alpha,beta,wer,cer".
    """
    # Keep the exact original bytes so the restore preserves formatting too.
    with open(attrs_path, "r") as attrs_file:
        original_text = attrs_file.read()
    original_attrs = json.loads(original_text)
    try:
        with open(grid_path, "w") as grid:
            grid.write("alpha,beta,wer,cer\n")
            for alpha in values:
                for beta in values:
                    write_lm_params(attrs_path, original_attrs, alpha, beta)
                    print(f"alpha = {alpha}, beta = {beta}")
                    main(eval_args)
                    with open(results_path(eval_args)) as results_file:
                        wer, cer = parse_metrics(results_file.read())
                    grid.write(f"{alpha},{beta},{wer},{cer}\n")
                    # Flush so finished rows survive if a later evaluation
                    # crashes or the long-running search is killed.
                    grid.flush()
                    print("--------------")
    finally:
        # Without this, attrs.json would be left at the last grid point and
        # silently change every later decoding run.
        with open(attrs_path, "w") as attrs_file:
            attrs_file.write(original_text)


if __name__ == "__main__":
    run_grid_search(args)