Spaces:
Runtime error
Runtime error
juanpablo4l
committed on
Commit
•
d0fad25
1
Parent(s):
857aac4
Added NLG models
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
from models import
|
4 |
|
5 |
|
6 |
def predict(text: str, model_name: str) -> str:
|
@@ -9,16 +9,30 @@ def predict(text: str, model_name: str) -> str:
|
|
9 |
|
10 |
with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:
|
11 |
gr.Markdown("Dialogue State Tracking Modules")
|
12 |
-
for model_name in
|
13 |
with gr.Row():
|
14 |
gr.Markdown(f"## {model_name}")
|
15 |
model_name_component = gr.Textbox(value=model_name, visible=False)
|
16 |
with gr.Row():
|
17 |
-
text_input = gr.Textbox(label="Input Text", value=
|
18 |
output = gr.Textbox(label="Slot Value", value="")
|
19 |
with gr.Row():
|
20 |
predict_button = gr.Button("Predict")
|
21 |
predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
demo.queue(concurrency_count=3)
|
24 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
from models import DST_MODELS, NLG_MODELS, PIPELINES
|
4 |
|
5 |
|
6 |
def predict(text: str, model_name: str) -> str:
|
|
|
9 |
|
10 |
with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:
|
11 |
gr.Markdown("Dialogue State Tracking Modules")
|
12 |
+
for model_name in DST_MODELS:
|
13 |
with gr.Row():
|
14 |
gr.Markdown(f"## {model_name}")
|
15 |
model_name_component = gr.Textbox(value=model_name, visible=False)
|
16 |
with gr.Row():
|
17 |
+
text_input = gr.Textbox(label="Input Text", value=DST_MODELS[model_name]["default_input"])
|
18 |
output = gr.Textbox(label="Slot Value", value="")
|
19 |
with gr.Row():
|
20 |
predict_button = gr.Button("Predict")
|
21 |
predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
|
22 |
|
23 |
+
|
24 |
+
gr.Markdown("Natural Language Generation / Paraphrasing Modules")
|
25 |
+
for model_name in NLG_MODELS:
|
26 |
+
with gr.Row():
|
27 |
+
gr.Markdown(f"## {model_name}")
|
28 |
+
model_name_component = gr.Textbox(value=model_name, visible=False)
|
29 |
+
with gr.Row():
|
30 |
+
text_input = gr.Textbox(label="Input Text", value=NLG_MODELS[model_name]["default_input"])
|
31 |
+
output = gr.Textbox(label="Slot Value", value="")
|
32 |
+
with gr.Row():
|
33 |
+
predict_button = gr.Button("Predict")
|
34 |
+
predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
|
35 |
+
|
36 |
+
|
37 |
demo.queue(concurrency_count=3)
|
38 |
demo.launch()
|
models.py
CHANGED
@@ -2,11 +2,12 @@ import os
|
|
2 |
from typing import Any, Dict
|
3 |
|
4 |
from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
|
5 |
-
pipeline)
|
6 |
|
7 |
auth_token = os.environ.get("CLARIN_KNEXT")
|
8 |
|
9 |
-
|
|
|
10 |
"polish": (
|
11 |
"[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. "
|
12 |
"[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
|
@@ -19,47 +20,91 @@ DEFAULT_INPUTS: Dict[str, str] = {
|
|
19 |
),
|
20 |
}
|
21 |
|
22 |
-
|
|
|
23 |
"plt5-small": {
|
24 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
|
25 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
|
26 |
-
"default_input":
|
27 |
},
|
28 |
"plt5-base": {
|
29 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
|
30 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
|
31 |
-
"default_input":
|
32 |
},
|
33 |
"plt5-base-poquad-dst-v2": {
|
34 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-base-poquad-dst-v2", use_auth_token=auth_token),
|
35 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-base-poquad-dst-v2", use_auth_token=auth_token),
|
36 |
-
"default_input":
|
37 |
},
|
38 |
"t5-small": {
|
39 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/t5-small-dst", use_auth_token=auth_token),
|
40 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/t5-small-dst", use_auth_token=auth_token),
|
41 |
-
"default_input":
|
42 |
},
|
43 |
"t5-base": {
|
44 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/t5-base-dst", use_auth_token=auth_token),
|
45 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/t5-base-dst", use_auth_token=auth_token),
|
46 |
-
"default_input":
|
47 |
},
|
48 |
"flant5-small [EN/PL]": {
|
49 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/flant5-small-dst", use_auth_token=auth_token),
|
50 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/flant5-small-dst", use_auth_token=auth_token),
|
51 |
-
"default_input":
|
52 |
},
|
53 |
"flant5-base [EN/PL]": {
|
54 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/flant5-base-dst", use_auth_token=auth_token),
|
55 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/flant5-base-dst", use_auth_token=auth_token),
|
56 |
-
"default_input":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
},
|
58 |
}
|
59 |
|
60 |
PIPELINES: Dict[str, Pipeline] = {
|
61 |
model_name: pipeline(
|
62 |
-
"text2text-generation", model=
|
63 |
)
|
64 |
-
for model_name in
|
65 |
}
|
|
|
2 |
from typing import Any, Dict
|
3 |
|
4 |
from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
|
5 |
+
pipeline, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer)
|
6 |
|
7 |
auth_token = os.environ.get("CLARIN_KNEXT")
|
8 |
|
9 |
+
|
10 |
+
DEFAULT_DST_INPUTS: Dict[str, str] = {
|
11 |
"polish": (
|
12 |
"[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. "
|
13 |
"[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
|
|
|
20 |
),
|
21 |
}
|
22 |
|
23 |
+
|
24 |
+
DST_MODELS: Dict[str, Dict[str, Any]] = {
|
25 |
"plt5-small": {
|
26 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
|
27 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
|
28 |
+
"default_input": DEFAULT_DST_INPUTS["polish"],
|
29 |
},
|
30 |
"plt5-base": {
|
31 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
|
32 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
|
33 |
+
"default_input": DEFAULT_DST_INPUTS["polish"],
|
34 |
},
|
35 |
"plt5-base-poquad-dst-v2": {
|
36 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-base-poquad-dst-v2", use_auth_token=auth_token),
|
37 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-base-poquad-dst-v2", use_auth_token=auth_token),
|
38 |
+
"default_input": DEFAULT_DST_INPUTS["polish"],
|
39 |
},
|
40 |
"t5-small": {
|
41 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/t5-small-dst", use_auth_token=auth_token),
|
42 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/t5-small-dst", use_auth_token=auth_token),
|
43 |
+
"default_input": DEFAULT_DST_INPUTS["english"],
|
44 |
},
|
45 |
"t5-base": {
|
46 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/t5-base-dst", use_auth_token=auth_token),
|
47 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/t5-base-dst", use_auth_token=auth_token),
|
48 |
+
"default_input": DEFAULT_DST_INPUTS["english"],
|
49 |
},
|
50 |
"flant5-small [EN/PL]": {
|
51 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/flant5-small-dst", use_auth_token=auth_token),
|
52 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/flant5-small-dst", use_auth_token=auth_token),
|
53 |
+
"default_input": DEFAULT_DST_INPUTS["english"],
|
54 |
},
|
55 |
"flant5-base [EN/PL]": {
|
56 |
"model": T5ForConditionalGeneration.from_pretrained("clarin-knext/flant5-base-dst", use_auth_token=auth_token),
|
57 |
"tokenizer": T5Tokenizer.from_pretrained("clarin-knext/flant5-base-dst", use_auth_token=auth_token),
|
58 |
+
"default_input": DEFAULT_DST_INPUTS["english"],
|
59 |
+
},
|
60 |
+
}
|
61 |
+
|
62 |
+
|
63 |
+
DEFAULT_ENCODER_DECODER_INPUT_EN = "The alarm is set for 6 am. The alarm's name is name \"Get up\"."
|
64 |
+
DEFAULT_DECODER_ONLY_INPUT_EN = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_EN}[SEP]"
|
65 |
+
DEFAULT_ENCODER_DECODER_INPUT_PL = "Alarm jest o godzinie 6 rano. Alarm ma nazwę \"Obudź się\"."
|
66 |
+
DEFAULT_DECODER_ONLY_INPUT_PL = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_PL}[SEP]"
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
NLG_MODELS: Dict[str, Dict[str, Any]] = {
|
71 |
+
# English
|
72 |
+
"t5-large": {
|
73 |
+
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-t5-large", use_auth_token=auth_token),
|
74 |
+
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-t5-large", use_auth_token=auth_token),
|
75 |
+
"default_input": DEFAULT_ENCODER_DECODER_INPUT_EN,
|
76 |
+
},
|
77 |
+
"en-mt5-large": {
|
78 |
+
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-en-mt5-large", use_auth_token=auth_token),
|
79 |
+
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-en-mt5-large", use_auth_token=auth_token),
|
80 |
+
"default_input": DEFAULT_ENCODER_DECODER_INPUT_EN,
|
81 |
+
},
|
82 |
+
"gpt2": {
|
83 |
+
"model": AutoModelForCausalLM.from_pretrained("clarin-knext/utterance-rewriting-gpt2", use_auth_token=auth_token),
|
84 |
+
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-gpt2", use_auth_token=auth_token),
|
85 |
+
"default_input": DEFAULT_DECODER_ONLY_INPUT_EN,
|
86 |
+
},
|
87 |
+
|
88 |
+
"pt5-large": {
|
89 |
+
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-pt5-large", use_auth_token=auth_token),
|
90 |
+
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-pt5-large", use_auth_token=auth_token),
|
91 |
+
"default_input": DEFAULT_ENCODER_DECODER_INPUT_PL,
|
92 |
+
},
|
93 |
+
"pl-mt5-large": {
|
94 |
+
"model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-pl-mt5-large", use_auth_token=auth_token),
|
95 |
+
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-pl-mt5-large", use_auth_token=auth_token),
|
96 |
+
"default_input": DEFAULT_ENCODER_DECODER_INPUT_PL,
|
97 |
+
},
|
98 |
+
"polish-gpt2": {
|
99 |
+
"model": AutoModelForCausalLM.from_pretrained("clarin-knext/utterance-rewriting-polish-gpt2", use_auth_token=auth_token),
|
100 |
+
"tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-polish-gpt2", use_auth_token=auth_token),
|
101 |
+
"default_input": DEFAULT_DECODER_ONLY_INPUT_PL,
|
102 |
},
|
103 |
}
|
104 |
|
105 |
PIPELINES: Dict[str, Pipeline] = {
|
106 |
model_name: pipeline(
|
107 |
+
"text2text-generation", model=DST_MODELS[model_name]["model"], tokenizer=DST_MODELS[model_name]["tokenizer"]
|
108 |
)
|
109 |
+
for model_name in DST_MODELS
|
110 |
}
|