"""Gradio demo that serves a Joey NMT model through a single text box.

The config, vocabularies, model weights, tokenizer and a single-sentence
"stream" dataset are built once at import time; ``normalize`` then translates
one input string per request. The checkpoint path suggests a Sorani Kurdish
model (``models/Sorani-Arabic``) -- NOTE(review): confirm against config.yaml.
"""
from functools import partial
from pathlib import Path

import gradio as gr

from joeynmt.datasets import build_dataset
from joeynmt.helpers import (
    check_version,
    load_checkpoint,
    load_config,
    parse_train_args,
    resolve_ckpt_path,
)
from joeynmt.model import build_model
from joeynmt.prediction import predict
from joeynmt.tokenizers import build_tokenizer
from joeynmt.vocabulary import build_vocab

cfg_file = 'config.yaml'
ckpt = './models/Sorani-Arabic/best.ckpt'

cfg = load_config(Path(cfg_file))

# parse and validate cfg
model_dir, load_model, device, n_gpu, num_workers, _, fp16 = parse_train_args(
    cfg["training"], mode="prediction")
test_cfg = cfg["testing"]
src_cfg = cfg["data"]["src"]
trg_cfg = cfg["data"]["trg"]

# An explicitly given checkpoint path overrides the one from the config.
load_model = load_model if ckpt is None else Path(ckpt)
ckpt = resolve_ckpt_path(load_model, model_dir)

src_vocab, trg_vocab = build_vocab(cfg["data"], model_dir=model_dir)

model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)

# load model state from disk
model_checkpoint = load_checkpoint(ckpt, device=device)
model.load_state_dict(model_checkpoint["model_state"])

if device.type == "cuda":
    model.to(device)

tokenizer = build_tokenizer(cfg["data"])
# Source sentences become id sequences without BOS but with EOS; the target
# side has no encoder because there are no references at inference time.
sequence_encoder = {
    src_cfg["lang"]: partial(src_vocab.sentences_to_ids, bos=False, eos=True),
    trg_cfg["lang"]: None,
}

test_cfg["batch_size"] = 1  # CAUTION: this will raise an error if n_gpus > 1
test_cfg["batch_type"] = "sentence"

# A "stream" dataset has no backing file (path=None); sentences are pushed
# into it one at a time via ``set_item``.
test_data = build_dataset(
    dataset_type="stream",
    path=None,
    src_lang=src_cfg["lang"],
    trg_lang=trg_cfg["lang"],
    split="test",
    tokenizer=tokenizer,
    sequence_encoder=sequence_encoder,
)


def _translate_data(test_data, cfg=test_cfg):
    """Translate ``test_data`` and return the hypothesis for its first item.

    Relies on ``model``, ``device``, ``n_gpu``, ``num_workers`` and ``fp16``
    from module scope.

    :param test_data: dataset to translate (here: the module-level stream
        dataset holding the current request's sentence)
    :param cfg: testing section of the config, forwarded to ``predict``
    :return: the first decoded hypothesis
    """
    # predict returns a 6-tuple; only the decoded hypotheses are needed here.
    _, _, hypotheses, _, _, _ = predict(
        model=model,
        data=test_data,
        compute_loss=False,
        device=device,
        n_gpu=n_gpu,
        normalization="none",
        num_workers=num_workers,
        cfg=cfg,
        fp16=fp16,
    )
    return hypotheses[0]


def normalize(text):
    """Translate one input string with the loaded model (Gradio callback).

    :param text: raw text from the input textbox
    :return: the model's hypothesis, or ``""`` when the input is blank
    """
    # Blank input has nothing to translate; Joey NMT's stream dataset is
    # expected to reject empty sentences in ``set_item``, so short-circuit
    # instead of surfacing an error in the UI.
    if not text or not text.strip():
        return ""
    # ``set_item`` appends to the stream dataset's cache, while
    # ``_translate_data`` returns only ``hypotheses[0]``. Without clearing the
    # cache, every request after the first would get the first request's
    # translation (and re-translate the whole history). Joey NMT's own
    # interactive loop resets the cache after each sentence; do the same, in
    # ``finally`` so a failed request cannot leave a stale item behind.
    try:
        test_data.set_item(text.rstrip())
        return _translate_data(test_data)
    finally:
        test_data.reset_cache()


examples = [
    ["ياخوا تةمةن دريژبيت بوئةم ميللةتة"],
    ["سلاو برا جونی؟"],
]

# ``gr.inputs`` / ``gr.outputs`` were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level ``gr.Textbox`` component works in both.
demo = gr.Interface(
    fn=normalize,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Textbox(label="Output Text"),
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()