|
from pathlib import Path |
|
from functools import partial |
|
|
|
from joeynmt.prediction import predict |
|
from joeynmt.helpers import ( |
|
check_version, |
|
load_checkpoint, |
|
load_config, |
|
parse_train_args, |
|
resolve_ckpt_path, |
|
|
|
) |
|
from joeynmt.model import build_model |
|
from joeynmt.tokenizers import build_tokenizer |
|
from joeynmt.vocabulary import build_vocab |
|
from joeynmt.datasets import build_dataset |
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# One-time start-up: load the config, build the model from the checkpoint and
# prepare a stream dataset that the Gradio callback feeds one input at a time.
# All of this runs at import time.
# ---------------------------------------------------------------------------

# Relative paths: resolved against the current working directory.
cfg_file = 'config.yaml'
ckpt = './models/Sorani-Arabic/best.ckpt'

cfg = load_config(Path(cfg_file))

# Runtime settings come from the "training" section; the sixth returned value
# is not needed here and is discarded.
(
    model_dir,
    load_model,
    device,
    n_gpu,
    num_workers,
    _,
    fp16,
) = parse_train_args(cfg["training"], mode="prediction")

test_cfg = cfg["testing"]
src_cfg, trg_cfg = cfg["data"]["src"], cfg["data"]["trg"]

# An explicit checkpoint path takes precedence over the config's load_model.
if ckpt is not None:
    load_model = Path(ckpt)
ckpt = resolve_ckpt_path(load_model, model_dir)

# Vocabularies -> model -> weights.
src_vocab, trg_vocab = build_vocab(cfg["data"], model_dir=model_dir)
model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
model_checkpoint = load_checkpoint(ckpt, device=device)
model.load_state_dict(model_checkpoint["model_state"])
if device.type == "cuda":
    model.to(device)

# Source sentences are mapped through the source vocab (no BOS, with EOS);
# no encoder is configured for the target side.
tokenizer = build_tokenizer(cfg["data"])
sequence_encoder = {
    src_cfg["lang"]: partial(src_vocab.sentences_to_ids, bos=False, eos=True),
    trg_cfg["lang"]: None,
}

# Interactive use: translate exactly one sentence per batch.
test_cfg.update(batch_size=1, batch_type="sentence")

test_data = build_dataset(
    dataset_type="stream",
    path=None,
    src_lang=src_cfg["lang"],
    trg_lang=trg_cfg["lang"],
    split="test",
    tokenizer=tokenizer,
    sequence_encoder=sequence_encoder,
)
|
|
|
|
|
|
|
def _translate_data(test_data, cfg=test_cfg):
    """Run JoeyNMT prediction over ``test_data`` and return the first hypothesis.

    Uses the module-level ``model``, ``device``, ``n_gpu``, ``num_workers`` and
    ``fp16`` prepared at start-up; loss computation is disabled.

    Args:
        test_data: dataset to translate (in this app, the module's stream
            dataset holding the current user input).
        cfg: testing configuration forwarded to ``predict``. Defaults to the
            module-level ``test_cfg`` object (bound once, when this function
            is defined).

    Returns:
        The first element of the hypotheses list returned by ``predict``.
    """
    # predict returns a 6-tuple; only the third element (hypotheses) is used.
    _, _, hypotheses, _, _, _ = predict(
        model=model,
        data=test_data,
        compute_loss=False,
        device=device,
        n_gpu=n_gpu,
        normalization="none",
        num_workers=num_workers,
        cfg=cfg,
        fp16=fp16,
    )
    return hypotheses[0]
|
|
|
|
|
|
|
def normalize(text):
    """Translate one piece of user input with the loaded JoeyNMT model.

    This is the Gradio callback: it places ``text`` into the module-level
    stream dataset, runs ``_translate_data`` on it and returns the result.

    Args:
        text: raw text from the input text box (may be empty).

    Returns:
        The hypothesis produced by ``_translate_data`` for ``text``, or ``""``
        when the input is blank.
    """
    # Blank input has nothing to translate; JoeyNMT's StreamDataset.set_item
    # rejects empty sentences, so answer with an empty result instead.
    if text is None or not text.strip():
        return ""

    # NOTE(review): test_data is shared module state; if Gradio serves
    # requests concurrently they could interleave here — confirm requests are
    # processed one at a time.
    try:
        test_data.set_item(text.strip())
        return _translate_data(test_data)
    finally:
        # StreamDataset keeps every item added via set_item(). Without clearing
        # it, the next request would re-translate all earlier inputs as well,
        # and _translate_data's hypotheses[0] would keep returning the very
        # first translation ever made.
        test_data.reset_cache()
|
|
|
# Sample inputs listed under the interface. Each example is a one-element
# list because the interface has a single input component.
examples = [
    [sentence]
    for sentence in (
        "ياخوا تةمةن دريژبيت بوئةم ميللةتة",
        "سلاو برا جونی؟",
    )
]
|
|
|
|
|
|
|
# Web UI: a five-line text box in, a text box out, backed by normalize().
# The legacy gr.inputs / gr.outputs namespaces were deprecated in Gradio 3.0
# and removed in 4.0; gr.Textbox serves as both input and output component.
demo = gr.Interface(
    fn=normalize,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Textbox(label="Output Text"),
    examples=examples,
)

demo.launch()
|
|