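# Residue from the web file viewer this module was copied from:
# author "Zaid", commit 5112867 "add diacritizer", file size 2.2 kB.
# Kept as comments so the module parses.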
from .config_manager import ConfigManager
import os
from typing import Dict
from torch import nn
from tqdm import tqdm
from tqdm import trange
from dataset import load_iterators
from trainer import GeneralTrainer
class DiacritizationTester(GeneralTrainer):
    """Evaluate a trained diacritization model on the test split.

    Builds the model described by the config and loads the checkpoint at
    ``config["test_model_path"]`` (without optimizer state). :meth:`run`
    then reports loss, accuracy and the WER/DER error rates.
    """

    def __init__(self, config_path: str, model_kind: str) -> None:
        """Build the model from the config and load the test checkpoint.

        Args:
            config_path: Path of the configuration handed to ``ConfigManager``.
            model_kind: Model variant name handed to ``ConfigManager``.
        """
        # NOTE(review): GeneralTrainer.__init__ is not called; the attributes
        # the inherited helpers rely on are set up here instead. Confirm
        # nothing else from the base initializer is needed.
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        # Targets equal to pad_idx are excluded from the loss.
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        # Presumably sets self.device, which is used just below.
        self.set_device()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.model = self.config_manager.get_model()
        self.model = self.model.to(self.device)
        # Evaluation only: restore the weights but not the optimizer state.
        self.load_model(model_path=self.config["test_model_path"], load_optimizer=False)
        self.load_diacritizer()
        self.diacritizer.set_model(self.model)
        # NOTE(review): runs after load_model. Confirm it does not overwrite
        # the weights that were just loaded.
        self.initialize_model()
        self.print_config()

    def run(self) -> Dict[str, object]:
        """Evaluate the loaded model on the test split and print the results.

        Switches the shared config to load only test data. It then runs the
        loss/accuracy evaluation followed by the error-rate evaluation, each
        with its own progress bar, and prints a summary tagged with
        ``self.global_step``.

        Returns:
            ``{"loss": ..., "accuracy": ..., "error_rates": ...}``, where
            ``error_rates`` is the dict returned by
            ``evaluate_with_error_rates``. That dict must contain the keys
            "DER", "WER", "DER*" and "WER*". Callers that ignore the return
            value are unaffected.
        """
        config = self.config_manager.config
        config["load_training_data"] = False
        config["load_validation_data"] = False
        config["load_test_data"] = True
        # Only the second of the three returned iterators is used here.
        _, test_iterator, _ = load_iterators(self.config_manager)
        num_steps = len(test_iterator)

        # The bars are context managers, so they are closed even if an
        # evaluation raises. The second bar is only created once it is needed.
        with trange(0, num_steps, leave=True) as tqdm_eval:
            loss, acc = self.evaluate(test_iterator, tqdm_eval, log=False)
        with trange(0, num_steps, leave=True) as tqdm_error_rates:
            error_rates, _ = self.evaluate_with_error_rates(
                test_iterator, tqdm_error_rates, log=False
            )

        # Keep the metrics dict intact and build the printable line
        # separately, before printing anything, so that a missing key fails
        # early.
        error_rates_summary = (
            f"DER: {error_rates['DER']}, WER: {error_rates['WER']}, "
            f"DER*: {error_rates['DER*']}, WER*: {error_rates['WER*']}"
        )
        print(f"global step : {self.global_step}")
        print(f"Evaluate {self.global_step}: accuracy, {acc}, loss: {loss}")
        print(f"WER/DER {self.global_step}: {error_rates_summary}")
        return {"loss": loss, "accuracy": acc, "error_rates": error_rates}