Bram Vanroy committed on
Commit
11f8e5f
·
1 Parent(s): 1e0a2f8

add multilingual support and cache translations

Browse files
Files changed (2) hide show
  1. app.py +11 -9
  2. utils.py +26 -3
app.py CHANGED
@@ -8,24 +8,26 @@ from transformers import LogitsProcessorList
8
 
9
  import streamlit as st
10
 
11
- from utils import get_resources
12
 
13
- model, tokenizer, logitsprocessor = get_resources()
14
 
15
- st.title("📝 Parse text into AMR")
 
 
 
 
16
 
17
- text = st.text_input(label="Text to transform (en)")
18
-
19
- if text:
20
  gen_kwargs = {
21
  "max_length": model.config.max_length,
22
  "num_beams": model.config.num_beams,
23
  "logits_processor": LogitsProcessorList([logitsprocessor])
24
  }
25
 
26
- encoded = tokenizer(text, return_tensors="pt")
27
- generated = model.generate(**encoded, **gen_kwargs)
28
- linearized = tokenizer.decode_and_fix(generated)[0]
29
  penman_str = linearized2penmanstr(linearized)
30
 
31
  try:
 
8
 
9
  import streamlit as st
10
 
11
+ from utils import get_resources, LANGUAGES, translate
12
 
13
+ st.title("👩‍💻 Generate AMR from multilingual text")
14
 
15
+ with st.form("input data"):
16
+ text_col, lang_col = st.columns((4, 1))
17
+ text = text_col.text_input(label="Input text")
18
+ src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
19
+ submitted = st.form_submit_button("Submit")
20
 
21
+ if submitted:
22
+ multilingual = src_lang != "English"
23
+ model, tokenizer, logitsprocessor = get_resources(multilingual)
24
  gen_kwargs = {
25
  "max_length": model.config.max_length,
26
  "num_beams": model.config.num_beams,
27
  "logits_processor": LogitsProcessorList([logitsprocessor])
28
  }
29
 
30
+ linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
 
 
31
  penman_str = linearized2penmanstr(linearized)
32
 
33
  try:
utils.py CHANGED
@@ -16,9 +16,15 @@ st_hash_funcs = {PreTrainedModel: lambda model: model.name_or_path,
16
 
17
 
18
  @st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
19
- def get_resources(quantize: bool = True):
20
- tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
21
- model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
 
 
 
 
 
 
22
  model = BetterTransformer.transform(model, keep_original_model=False)
23
  model.resize_token_embeddings(len(tokenizer))
24
 
@@ -28,3 +34,20 @@ def get_resources(quantize: bool = True):
28
  logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)
29
 
30
  return model, tokenizer, logits_processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  @st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
19
+ def get_resources(multilingual: bool, quantize: bool = True):
20
+ if multilingual:
21
+ # Tokenizer src_lang is reset during translation to the right language
22
+ tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr", src_lang="nl_XX")
23
+ model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr")
24
+ else:
25
+ tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
26
+ model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
27
+
28
  model = BetterTransformer.transform(model, keep_original_model=False)
29
  model.resize_token_embeddings(len(tokenizer))
30
 
 
34
  logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)
35
 
36
  return model, tokenizer, logits_processor
37
+
38
+
39
+ @st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
40
+ def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs):
41
+ tokenizer.src_lang = LANGUAGES[src_lang]
42
+ encoded = tokenizer(text, return_tensors="pt")
43
+ print(tokenizer.convert_ids_to_tokens(encoded.input_ids[0]))
44
+ print(model.name_or_path)
45
+ generated = model.generate(**encoded, **gen_kwargs)
46
+ return tokenizer.decode_and_fix(generated)[0]
47
+
48
+
49
+ LANGUAGES = {
50
+ "English": "en_XX",
51
+ "Dutch": "nl_XX",
52
+ "Spanish": "es_XX",
53
+ }