pere's picture
first
5ba903e
raw
history blame
1.38 kB
#!/usr/bin/env python
from argparse import ArgumentDefaultsHelpFormatter
from collections import namedtuple
from functools import partialmethod
import json
from tqdm import tqdm
from eval import main
# Uncomment to disable every tqdm progress bar (e.g. any printed during
# evaluation); this is the only use of the tqdm/partialmethod imports.
# tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

# Stand-in for the argparse namespace passed to eval.main; the field names
# presumably mirror eval.py's command-line options -- confirm there.
Args = namedtuple("Args", "model_id dataset config split log_outputs chunk_length_s stride_length_s device")
args = Args("./", "NbAiLab/NPSC", "16K_mp3_bokmaal", "test", True, None, None, 0)

# Candidate values tried for both alpha and beta (full Cartesian grid).
GRID_VALUES = (0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 1, 1.5, 2, 3)
# Decoder config whose "alpha"/"beta" keys are overwritten for each run.
ATTRS_PATH = "./language_model/attrs.json"
# CSV output: a header plus one "alpha,beta,wer,cer" row per evaluated pair.
GRID_PATH = "grid.csv"


def results_path_for(run_args):
    """Return the name of the results file eval.main writes for *run_args*.

    The dataset id (with "/" replaced by "_"), config and split are joined
    with "_" and suffixed with "_eval_results.txt"; for the default ``args``
    this yields ``NbAiLab_NPSC_16K_mp3_bokmaal_test_eval_results.txt``.
    NOTE(review): assumes eval.main uses this naming scheme -- confirm.
    """
    dataset_id = "_".join(run_args.dataset.split("/") + [run_args.config, run_args.split])
    return f"{dataset_id}_eval_results.txt"


def read_metrics(path):
    """Return ``(wer, cer)`` parsed from an eval results file.

    The first line must look like ``WER: <number>`` and the second like
    ``CER: <number>``; the value is whatever follows the first colon, so the
    exact spacing after the label does not matter.
    """
    with open(path, encoding="utf-8") as results_file:
        lines = results_file.read().strip().split("\n")
    wer = float(lines[0].partition(":")[2])
    cer = float(lines[1].partition(":")[2])
    return wer, cer


def write_lm_weights(path, attrs, alpha, beta):
    """Write *attrs* to *path* as JSON with ``alpha``/``beta`` overridden.

    *attrs* itself is left unmodified; existing key order is preserved.
    """
    with open(path, "w", encoding="utf-8") as attrs_file:
        json.dump({**attrs, "alpha": alpha, "beta": beta}, attrs_file)


def run_grid_search(run_args=args, values=GRID_VALUES, *, attrs_path=ATTRS_PATH,
                    grid_path=GRID_PATH, evaluate=None):
    """Evaluate every (alpha, beta) pair from *values* and log WER/CER.

    For each pair this rewrites ``alpha``/``beta`` in *attrs_path*, calls
    ``evaluate(run_args)``, reads WER and CER from the file named by
    :func:`results_path_for`, and appends a row to *grid_path*.

    Args:
        run_args: namespace-like object handed to the evaluator.
        values: candidate values, used for both alpha and beta.
        attrs_path: JSON decoder config to patch.
        grid_path: CSV file to (over)write with the results.
        evaluate: callable taking *run_args*; defaults to ``eval.main``.

    The original contents of *attrs_path* are written back when the search
    finishes, including when an evaluation raises.
    """
    if evaluate is None:
        evaluate = main  # eval.main, imported at the top of the file
    with open(attrs_path, encoding="utf-8") as attrs_file:
        original_attrs = json.load(attrs_file)
    results_path = results_path_for(run_args)
    try:
        with open(grid_path, "w", encoding="utf-8") as grid:
            grid.write("alpha,beta,wer,cer\n")
            for alpha in values:
                for beta in values:
                    write_lm_weights(attrs_path, original_attrs, alpha, beta)
                    print(f"alpha = {alpha}, beta = {beta}")
                    evaluate(run_args)
                    wer, cer = read_metrics(results_path)
                    grid.write(f"{alpha},{beta},{wer},{cer}\n")
                    # Flush so completed rows survive a crash in a later run.
                    grid.flush()
                    print("--------------")
    finally:
        # Never leave the language model configured with a grid value.
        with open(attrs_path, "w", encoding="utf-8") as attrs_file:
            json.dump(original_attrs, attrs_file)


if __name__ == "__main__":
    run_grid_search()