Bram Vanroy committed on
Commit
f8b0e70
β€’
1 Parent(s): 55fbc57

add docstrings

Browse files
Files changed (1) hide show
  1. utils.py +22 -4
utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
 
3
  from torch.quantization import quantize_dynamic
@@ -16,7 +18,15 @@ st_hash_funcs = {PreTrainedModel: lambda model: model.name_or_path,
16
 
17
 
18
  @st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
19
- def get_resources(multilingual: bool, quantize: bool = True):
 
 
 
 
 
 
 
 
20
  if multilingual:
21
  # Tokenizer src_lang is reset during translation to the right language
22
  tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr", src_lang="nl_XX")
@@ -37,11 +47,19 @@ def get_resources(multilingual: bool, quantize: bool = True):
37
 
38
 
39
@st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
    """Translate a given text of a given source language with a given model and tokenizer.

    The generation is guided by potential keyword arguments, which can include
    arguments such as max length, logits processors, etc.

    :param text: source text to translate
    :param src_lang: source language; a key into the module-level LANGUAGES mapping
    :param model: MBART model used for generation
    :param tokenizer: MBART tokenizer
    :param gen_kwargs: potential keyword arguments forwarded to model.generate
    :return: the translation (linearized AMR graph)
    """
    # The tokenizer's source language must match the input text's language
    # before encoding; it is reset on every call.
    tokenizer.src_lang = LANGUAGES[src_lang]
    encoded = tokenizer(text, return_tensors="pt")
    # NOTE: removed leftover debug print() calls (token dump and model path)
    # that polluted stdout on every cached call.
    generated = model.generate(**encoded, **gen_kwargs)
    return tokenizer.decode_and_fix(generated)[0]
47
 
 
1
+ from typing import Tuple
2
+
3
  import streamlit as st
4
 
5
  from torch.quantization import quantize_dynamic
 
18
 
19
 
20
  @st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
21
+ def get_resources(multilingual: bool, quantize: bool = True) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
22
+ """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
23
+ model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
24
+ for better performance.
25
+
26
+ :param multilingual: whether or not to load the multilingual model. If not, loads the English-only model
27
+ :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
28
+ :return: the loaded model, tokenizer, and logits processor
29
+ """
30
  if multilingual:
31
  # Tokenizer src_lang is reset during translation to the right language
32
  tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr", src_lang="nl_XX")
 
47
 
48
 
49
@st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
    """Translate ``text`` from ``src_lang`` into a linearized AMR graph.

    Generation can be steered through ``gen_kwargs``, e.g. maximum length,
    logits processors, and other ``generate`` options.

    :param text: source text to translate
    :param src_lang: source language
    :param model: MBART model
    :param tokenizer: MBART tokenizer
    :param gen_kwargs: potential keyword arguments for the generation process
    :return: the translation (linearized AMR graph)
    """
    # Point the tokenizer at the correct source language before encoding.
    tokenizer.src_lang = LANGUAGES[src_lang]
    batch = tokenizer(text, return_tensors="pt")
    output_ids = model.generate(**batch, **gen_kwargs)
    return tokenizer.decode_and_fix(output_ids)[0]
65