SinaAhmadi committed on
Commit
75b9522
1 Parent(s): 64500d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from functools import partial
3
+
4
+ from joeynmt.prediction import predict
5
+ from joeynmt.helpers import (
6
+ check_version,
7
+ load_checkpoint,
8
+ load_config,
9
+ parse_train_args,
10
+ resolve_ckpt_path,
11
+
12
+ )
13
+ from joeynmt.model import build_model
14
+ from joeynmt.tokenizers import build_tokenizer
15
+ from joeynmt.vocabulary import build_vocab
16
+ from joeynmt.datasets import build_dataset
17
+
18
+ import gradio as gr
19
+
20
+ # INPUT = "سلاو لە ناو گلی کرد"
21
+
22
+ cfg_file = 'config.yaml'
23
+ ckpt = './models/Sorani-Arabic/best.ckpt'
24
+
25
+ cfg = load_config(Path(cfg_file))
26
+ # parse and validate cfg
27
+ model_dir, load_model, device, n_gpu, num_workers, _, fp16 = parse_train_args(
28
+ cfg["training"], mode="prediction")
29
+ test_cfg = cfg["testing"]
30
+ src_cfg = cfg["data"]["src"]
31
+ trg_cfg = cfg["data"]["trg"]
32
+
33
+ load_model = load_model if ckpt is None else Path(ckpt)
34
+ ckpt = resolve_ckpt_path(load_model, model_dir)
35
+
36
+ src_vocab, trg_vocab = build_vocab(cfg["data"], model_dir=model_dir)
37
+
38
+ model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
39
+
40
+ # load model state from disk
41
+ model_checkpoint = load_checkpoint(ckpt, device=device)
42
+ model.load_state_dict(model_checkpoint["model_state"])
43
+
44
+ if device.type == "cuda":
45
+ model.to(device)
46
+
47
+ tokenizer = build_tokenizer(cfg["data"])
48
+ sequence_encoder = {
49
+ src_cfg["lang"]: partial(src_vocab.sentences_to_ids, bos=False, eos=True),
50
+ trg_cfg["lang"]: None,
51
+ }
52
+
53
+ test_cfg["batch_size"] = 1 # CAUTION: this will raise an error if n_gpus > 1
54
+ test_cfg["batch_type"] = "sentence"
55
+
56
+ test_data = build_dataset(
57
+ dataset_type="stream",
58
+ path=None,
59
+ src_lang=src_cfg["lang"],
60
+ trg_lang=trg_cfg["lang"],
61
+ split="test",
62
+ tokenizer=tokenizer,
63
+ sequence_encoder=sequence_encoder,
64
+ )
65
+ # test_data.set_item(INPUT.rstrip())
66
+
67
+
68
+ def _translate_data(test_data, cfg=test_cfg):
69
+ """Translates given dataset, using parameters from outer scope."""
70
+ _, _, hypotheses, trg_tokens, trg_scores, _ = predict(
71
+ model=model,
72
+ data=test_data,
73
+ compute_loss=False,
74
+ device=device,
75
+ n_gpu=n_gpu,
76
+ normalization="none",
77
+ num_workers=num_workers,
78
+ cfg=cfg,
79
+ fp16=fp16,
80
+ )
81
+ return hypotheses[0]
82
def normalize(text):
    """Normalize *text* by streaming it through the loaded JoeyNMT model."""
    # NOTE(review): this mutates the shared module-level dataset, so it is
    # not safe under concurrent requests — confirm Gradio serializes calls.
    test_data.set_item(text)
    return _translate_data(test_data)
# Example inputs: unnormalized Central Kurdish (Sorani) text.
examples = [
    ["ياخوا تةمةن دريژبيت بوئةم ميللةتة"],
    ["سلاو برا جونی؟"],
]

# gr.inputs.Textbox / gr.outputs.Textbox are the deprecated Gradio 2.x
# namespace (removed in Gradio 4); use the top-level component classes.
demo = gr.Interface(
    fn=normalize,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Textbox(label="Output Text"),
    examples=examples,
)

demo.launch()