Add default texts and detect of language direction
Browse files- app.py +53 -31
- default_texts.py +23 -0
- generator.py +7 -3
app.py
CHANGED
@@ -5,64 +5,73 @@ import psutil
|
|
5 |
import streamlit as st
|
6 |
|
7 |
from generator import GeneratorFactory
|
|
|
|
|
8 |
|
9 |
device = torch.cuda.device_count() - 1
|
10 |
|
11 |
-
|
|
|
12 |
|
13 |
GENERATOR_LIST = [
|
14 |
-
{
|
15 |
-
"model_name": "Helsinki-NLP/opus-mt-en-nl",
|
16 |
-
"desc": "Opus MT en->nl",
|
17 |
-
"task": TRANSLATION_NL_TO_EN,
|
18 |
-
"split_sentences": True,
|
19 |
-
},
|
20 |
{
|
21 |
"model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
|
22 |
"desc": "T5 small nl24 ccmatrix en->nl",
|
23 |
-
"task":
|
24 |
"split_sentences": True,
|
25 |
},
|
26 |
{
|
27 |
-
"model_name": "yhavinga/
|
28 |
-
"desc": "
|
29 |
"task": TRANSLATION_NL_TO_EN,
|
30 |
-
"split_sentences":
|
31 |
},
|
32 |
{
|
33 |
-
"model_name": "
|
34 |
-
"desc": "
|
35 |
-
"task":
|
36 |
"split_sentences": True,
|
37 |
},
|
38 |
# {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
# "model_name": "yhavinga/t5-eff-large-8l-nedd-en-nl",
|
40 |
# "desc": "T5 eff large nl8 en->nl",
|
41 |
-
# "task":
|
42 |
# "split_sentences": True,
|
43 |
# },
|
44 |
# {
|
45 |
# "model_name": "yhavinga/t5-base-36L-ccmatrix-multi",
|
46 |
# "desc": "T5 base nl36 ccmatrix en->nl",
|
47 |
-
# "task":
|
48 |
# "split_sentences": True,
|
49 |
# },
|
50 |
# {
|
51 |
# "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
|
52 |
# "desc": "longT5 large nl8 512beta/512l en->nl",
|
53 |
-
# "task":
|
54 |
# "split_sentences": False,
|
55 |
# },
|
56 |
# {
|
57 |
# "model_name": "yhavinga/t5-base-36L-nedd-x-en-nl-300",
|
58 |
# "desc": "T5 base 36L nedd en->nl 300",
|
59 |
-
# "task":
|
60 |
# "split_sentences": True,
|
61 |
# },
|
62 |
# {
|
63 |
# "model_name": "yhavinga/long-t5-local-small-ccmatrix-en-nl",
|
64 |
# "desc": "longT5 small ccmatrix en->nl",
|
65 |
-
# "task":
|
66 |
# "split_sentences": True,
|
67 |
# },
|
68 |
]
|
@@ -88,20 +97,18 @@ def main():
|
|
88 |
Vertaal van en naar Engels"""
|
89 |
)
|
90 |
st.sidebar.title("Parameters:")
|
91 |
-
if "prompt_box" not in st.session_state:
|
92 |
-
# Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
|
93 |
-
st.session_state[
|
94 |
-
"prompt_box"
|
95 |
-
] = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.
|
96 |
|
97 |
-
|
|
|
|
|
|
|
|
|
98 |
|
99 |
-
|
|
|
100 |
|
101 |
-
|
102 |
-
st.session_state["text"] =
|
103 |
-
"Enter text", st.session_state.prompt_box, height=250
|
104 |
-
)
|
105 |
num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
|
106 |
num_beam_groups = st.sidebar.number_input(
|
107 |
"Num beam groups", min_value=1, max_value=10, value=1
|
@@ -109,6 +116,7 @@ It was a quite young girl, unknown to me, with a hood over her head, and with la
|
|
109 |
length_penalty = st.sidebar.number_input(
|
110 |
"Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
|
111 |
)
|
|
|
112 |
st.sidebar.markdown(
|
113 |
"""For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
|
114 |
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
|
@@ -125,7 +133,21 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
|
|
125 |
if st.button("Run"):
|
126 |
memory = psutil.virtual_memory()
|
127 |
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
st.markdown(f"๐งฎ **Model `{generator}`**")
|
130 |
time_start = time.time()
|
131 |
result, params_used = generator.generate(
|
|
|
5 |
import streamlit as st
|
6 |
|
7 |
from generator import GeneratorFactory
|
8 |
+
from langdetect import detect
|
9 |
+
from default_texts import default_texts
|
10 |
|
11 |
device = torch.cuda.device_count() - 1
|
12 |
|
13 |
+
TRANSLATION_EN_TO_NL = "translation_en_to_nl"
|
14 |
+
TRANSLATION_NL_TO_EN = "translation_nl_to_en"
|
15 |
|
16 |
GENERATOR_LIST = [
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
{
|
18 |
"model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
|
19 |
"desc": "T5 small nl24 ccmatrix en->nl",
|
20 |
+
"task": TRANSLATION_EN_TO_NL,
|
21 |
"split_sentences": True,
|
22 |
},
|
23 |
{
|
24 |
+
"model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
|
25 |
+
"desc": "T5 small nl24 ccmatrix nl-en",
|
26 |
"task": TRANSLATION_NL_TO_EN,
|
27 |
+
"split_sentences": True,
|
28 |
},
|
29 |
{
|
30 |
+
"model_name": "Helsinki-NLP/opus-mt-en-nl",
|
31 |
+
"desc": "Opus MT en->nl",
|
32 |
+
"task": TRANSLATION_EN_TO_NL,
|
33 |
"split_sentences": True,
|
34 |
},
|
35 |
# {
|
36 |
+
# "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
|
37 |
+
# "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
|
38 |
+
# "task": TRANSLATION_EN_TO_NL,
|
39 |
+
# "split_sentences": False,
|
40 |
+
# },
|
41 |
+
# {
|
42 |
+
# "model_name": "yhavinga/byt5-small-ccmatrix-en-nl",
|
43 |
+
# "desc": "ByT5 small ccmatrix en->nl",
|
44 |
+
# "task": TRANSLATION_EN_TO_NL,
|
45 |
+
# "split_sentences": True,
|
46 |
+
# },
|
47 |
+
# {
|
48 |
# "model_name": "yhavinga/t5-eff-large-8l-nedd-en-nl",
|
49 |
# "desc": "T5 eff large nl8 en->nl",
|
50 |
+
# "task": TRANSLATION_EN_TO_NL,
|
51 |
# "split_sentences": True,
|
52 |
# },
|
53 |
# {
|
54 |
# "model_name": "yhavinga/t5-base-36L-ccmatrix-multi",
|
55 |
# "desc": "T5 base nl36 ccmatrix en->nl",
|
56 |
+
# "task": TRANSLATION_EN_TO_NL,
|
57 |
# "split_sentences": True,
|
58 |
# },
|
59 |
# {
|
60 |
# "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
|
61 |
# "desc": "longT5 large nl8 512beta/512l en->nl",
|
62 |
+
# "task": TRANSLATION_EN_TO_NL,
|
63 |
# "split_sentences": False,
|
64 |
# },
|
65 |
# {
|
66 |
# "model_name": "yhavinga/t5-base-36L-nedd-x-en-nl-300",
|
67 |
# "desc": "T5 base 36L nedd en->nl 300",
|
68 |
+
# "task": TRANSLATION_EN_TO_NL,
|
69 |
# "split_sentences": True,
|
70 |
# },
|
71 |
# {
|
72 |
# "model_name": "yhavinga/long-t5-local-small-ccmatrix-en-nl",
|
73 |
# "desc": "longT5 small ccmatrix en->nl",
|
74 |
+
# "task": TRANSLATION_EN_TO_NL,
|
75 |
# "split_sentences": True,
|
76 |
# },
|
77 |
]
|
|
|
97 |
Vertaal van en naar Engels"""
|
98 |
)
|
99 |
st.sidebar.title("Parameters:")
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
+
default_text = st.sidebar.radio(
|
102 |
+
"Change default text",
|
103 |
+
tuple(default_texts.keys()),
|
104 |
+
index=0,
|
105 |
+
)
|
106 |
|
107 |
+
if default_text or "prompt_box" not in st.session_state:
|
108 |
+
st.session_state["prompt_box"] = default_texts[default_text]["text"]
|
109 |
|
110 |
+
text_area = st.text_area("Enter text", st.session_state.prompt_box, height=300)
|
111 |
+
st.session_state["text"] = text_area
|
|
|
|
|
112 |
num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
|
113 |
num_beam_groups = st.sidebar.number_input(
|
114 |
"Num beam groups", min_value=1, max_value=10, value=1
|
|
|
116 |
length_penalty = st.sidebar.number_input(
|
117 |
"Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
|
118 |
)
|
119 |
+
|
120 |
st.sidebar.markdown(
|
121 |
"""For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
|
122 |
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
|
|
|
133 |
if st.button("Run"):
|
134 |
memory = psutil.virtual_memory()
|
135 |
|
136 |
+
language = detect(st.session_state.text)
|
137 |
+
if language == "en":
|
138 |
+
task = TRANSLATION_EN_TO_NL
|
139 |
+
elif language == "nl":
|
140 |
+
task = TRANSLATION_NL_TO_EN
|
141 |
+
else:
|
142 |
+
st.error(f"Language {language} not supported")
|
143 |
+
return
|
144 |
+
|
145 |
+
# Num beam groups should be a divisor of num beams
|
146 |
+
if num_beams % num_beam_groups != 0:
|
147 |
+
st.error("Num beams should be a multiple of num beam groups")
|
148 |
+
return
|
149 |
+
|
150 |
+
for generator in generators.filter(task=task):
|
151 |
st.markdown(f"๐งฎ **Model `{generator}`**")
|
152 |
time_start = time.time()
|
153 |
result, params_used = generator.generate(
|
default_texts.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_texts = {
|
2 |
+
"The Invisible Censor": {
|
3 |
+
"url": "https://www.gutenberg.org/files/35091/35091-h/35091-h.html",
|
4 |
+
"year": 1921,
|
5 |
+
"text": """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.
|
6 |
+
|
7 |
+
And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.
|
8 |
+
|
9 |
+
It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.
|
10 |
+
|
11 |
+
โMy father is very ill,โ she said without a word of introduction. โThe nurse is frightened. Could you come in and help?โ""",
|
12 |
+
},
|
13 |
+
"Gedachten": {
|
14 |
+
"url": "https://www.dbnl.org/tekst/eede003geda01_01/eede003geda01_01_0001.php",
|
15 |
+
"year": 1920,
|
16 |
+
"text": """Verdraagzaamheid en gewetensvrijheid brengen voor velen mee, dat ze noodwendig hun overtuiging een beetje onderdrukken, dat ze hun ware meening verzwijgen.
|
17 |
+
|
18 |
+
Ieder gezond, jong mensch voelt op een zekeren leeftijd dat zijn hoogste aspiraties, zijn beste capaciteiten worden neergehouden en doodgedrukt door wat men maatschappelijke verplichting noemt.
|
19 |
+
|
20 |
+
Drijf uw kinderen niet met dweepzucht of bekrompenheid in dezen of genen afgepaalden weg, maar geef hen ruimte, zoek voor hun geest alle voedsel, dat zij van nature behoeven en waarbij zij gedijen en vreugde vinden.
|
21 |
+
""",
|
22 |
+
},
|
23 |
+
}
|
generator.py
CHANGED
@@ -124,7 +124,7 @@ class Generator:
|
|
124 |
return decoded_preds[0], generate_kwargs
|
125 |
|
126 |
def __str__(self):
|
127 |
-
return self.
|
128 |
|
129 |
|
130 |
class GeneratorFactory:
|
@@ -150,5 +150,9 @@ class GeneratorFactory:
|
|
150 |
def __iter__(self):
|
151 |
return iter(self.generators)
|
152 |
|
153 |
-
def
|
154 |
-
return [
|
|
|
|
|
|
|
|
|
|
124 |
return decoded_preds[0], generate_kwargs
|
125 |
|
126 |
def __str__(self):
|
127 |
+
return self.model_name
|
128 |
|
129 |
|
130 |
class GeneratorFactory:
|
|
|
150 |
def __iter__(self):
|
151 |
return iter(self.generators)
|
152 |
|
153 |
+
def filter(self, **kwargs):
|
154 |
+
return [
|
155 |
+
g
|
156 |
+
for g in self.generators
|
157 |
+
if all([g.__dict__.get(k) == v for k, v in kwargs.items()])
|
158 |
+
]
|