from typing import Dict

import torch

from .config_manager import ConfigManager


class Diacritizer:
    def __init__(
        self, config_path: str, model_kind: str, load_model: bool = False
    ) -> None:
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.text_encoder = self.config_manager.text_encoder

        # Prefer an explicit device from the config; otherwise use CUDA when available.
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        if load_model:
            self.model, self.global_step = self.config_manager.load_model()
            self.model = self.model.to(self.device)

        self.start_symbol_id = self.text_encoder.start_symbol_id

    def set_model(self, model: torch.nn.Module):
        self.model = model

    def diacritize_text(self, text: str):
        # Encode the raw text and wrap it in the batch format that
        # diacritize_batch expects (a dict with "src" and "lengths" tensors),
        # then return the diacritized sentences.
        seq = self.text_encoder.input_to_sequence(text)
        batch = {
            "src": torch.LongTensor([seq]),
            "lengths": torch.LongTensor([len(seq)]),
        }
        return self.diacritize_batch(batch)

    def diacritize_batch(self, batch):
        raise NotImplementedError()

    def diacritize_iterators(self, iterator):
        pass


class CBHGDiacritizer(Diacritizer):
    def diacritize_batch(self, batch: Dict[str, torch.Tensor]):
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        # Greedy decoding: take the most likely diacritic class per character.
        predictions = torch.max(diacritics, 2).indices

        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)

        return sentences


class Seq2SeqDiacritizer(Diacritizer):
    def diacritize_batch(self, batch: Dict[str, torch.Tensor]):
        # Same greedy decoding procedure as CBHGDiacritizer.
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        predictions = torch.max(diacritics, 2).indices

        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)

        return sentences


class GPTDiacritizer(Diacritizer):
    def diacritize_batch(self, batch: Dict[str, torch.Tensor]):
        # Same greedy decoding procedure as CBHGDiacritizer.
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        predictions = torch.max(diacritics, 2).indices

        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)

        return sentences
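

# --- Hedged usage sketch (not part of the original module) ---
# A minimal example of how a diacritizer might be driven end to end. The
# config path and model kind below are assumptions for illustration only;
# the actual values depend on the project's ConfigManager and config files.
if __name__ == "__main__":
    diacritizer = CBHGDiacritizer(
        config_path="config/cbhg.yml",  # assumed path; replace with a real config
        model_kind="cbhg",              # assumed model-kind identifier
        load_model=True,                # load weights via ConfigManager.load_model()
    )
    # diacritize_text encodes the string, builds a single-example batch,
    # and returns a list with one diacritized sentence.
    print(diacritizer.diacritize_text("اللغة العربية"))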