#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://huggingface.co/spaces/sayakpaul/demo-docker-gradio
"""
import argparse
import json
import platform

from allennlp.models.archival import load_archive
from allennlp.predictors.text_classifier import TextClassifierPredictor
import fasttext
from fasttext.FastText import _FastText
import gradio as gr
from gradio import inputs, outputs
from langid.langid import LanguageIdentifier, model

from project_settings import project_path
from toolbox.os.command import Command


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--language_identification_md_file",
        default=(project_path / "language_identification.md").as_posix(),
        type=str
    )
    parser.add_argument(
        "--lang_id_examples_file",
        default=(project_path / "lang_id_examples.json").as_posix(),
        type=str
    )
    parser.add_argument(
        "--fasttext_model",
        default=(project_path / "pretrained_models/lid.176.bin").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


# Model handles, populated once in main() before the Gradio app starts.
lang_id_identifier: LanguageIdentifier = None
fasttext_model: _FastText = None
qgyd_lang_id_predictor: TextClassifierPredictor = None


# Local directory where models cloned from the Hugging Face Hub are cached.
trained_model_dir = project_path / "trained_models/huggingface"
trained_model_dir.mkdir(parents=True, exist_ok=True)


def init_qgyd_lang_id_predictor() -> TextClassifierPredictor:
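    """
    Clone the qgyd2021/language_identification repo from the Hugging Face Hub
    if it is not cached locally, then load its AllenNLP archive and wrap it
    in a TextClassifierPredictor.
    """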
    model_name = "qgyd2021/language_identification"
    model_path = trained_model_dir / model_name
    if not model_path.exists():
        model_path.parent.mkdir(exist_ok=True)
        Command.cd(model_path.parent.as_posix())
        Command.popen("git clone https://huggingface.co/{}".format(model_name))

    archive = load_archive(archive_file=model_path.as_posix())

    predictor = TextClassifierPredictor(
        model=archive.model,
        dataset_reader=archive.dataset_reader,
    )
    return predictor


def click_lang_id_button(text: str, ground_true: str, model_name: str):
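    """
    Classify `text` with the model selected in the dropdown and return (label, probability).
    `ground_true` is displayed in the UI for comparison only and is not used in the prediction.
    """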
    global lang_id_identifier
    global fasttext_model
    global qgyd_lang_id_predictor

    text = str(text).strip()

    if model_name == "langid":
        label, prob = lang_id_identifier.classify(text)
    elif model_name == "fasttext":
        label, prob = fasttext_model.predict(text, k=1)
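        # fastText returns labels like "__label__en"; strip the "__label__" prefix (9 chars).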
        label = label[0][9:]
        prob = prob[0]
    elif model_name == "qgyd_lang_id_1":
        json_dict = {
            "sentence": text
        }
        # The AllenNLP predictor returns a dict with a "label" and per-class "probs".
        result = qgyd_lang_id_predictor.predict_json(
            json_dict
        )
        label = result["label"]
        probs = result["probs"]
        prob = max(probs)
    else:
        label = "model_name not available."
        prob = -1
    return label, str(round(prob, 4))


def main():
    args = get_args()

    brief_description = """
    Language Identification
    """

    # description
    with open(args.language_identification_md_file, "r", encoding="utf-8") as f:
        description = f.read()

    # examples
    with open(args.lang_id_examples_file, "r", encoding="utf-8") as f:
        lang_id_examples = json.load(f)

    global lang_id_identifier
    global fasttext_model
    global qgyd_lang_id_predictor
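    # Load all three models once at startup; per-request calls only run inference.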
    lang_id_identifier = LanguageIdentifier.from_modelstring(model, norm_probs=True)
    fasttext_model = fasttext.load_model(args.fasttext_model)
    qgyd_lang_id_predictor = init_qgyd_lang_id_predictor()

    blocks = gr.Interface(
        click_lang_id_button,
        inputs=[
            inputs.Textbox(lines=3, label="text"),
            inputs.Textbox(label="ground_true"),
            inputs.Dropdown(choices=["langid", "fasttext", "qgyd_lang_id_1"], default="langid", label="model_name"),
        ],
        outputs=[
            outputs.Textbox(label="label"),
            outputs.Textbox(label="prob"),
        ],
        examples=lang_id_examples,
        description=brief_description,
        title="Language Identification",
    )

    blocks.launch(
        share=False,
        # Bind to localhost for local development on Windows; otherwise listen on all interfaces.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860,
    )
    return


if __name__ == "__main__":
    main()