Bram Vanroy committed on
Commit
1e0a2f8
Β·
1 Parent(s): f3fd096

use cache and quantization

Browse files
Files changed (2) hide show
  1. app.py +10 -24
  2. utils.py +30 -0
app.py CHANGED
@@ -1,45 +1,31 @@
1
  from collections import Counter
2
 
3
  import graphviz
4
- from optimum.bettertransformer import BetterTransformer
5
  import penman
6
  from penman.models.noop import NoOpModel
7
- from mbart_amr.constraints.constraints import AMRLogitsProcessor
8
  from mbart_amr.data.linearization import linearized2penmanstr
9
- from mbart_amr.data.tokenization import AMRMBartTokenizer
10
- from transformers import MBartForConditionalGeneration, LogitsProcessorList
11
 
12
  import streamlit as st
13
 
14
- if "logits_processor" not in st.session_state:
15
- st.session_state["logits_processor"] = None
16
 
17
- if "tokenizer" not in st.session_state:
18
- st.session_state["tokenizer"] = None
19
-
20
- if "model" not in st.session_state:
21
- st.session_state["tokenizer"] = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
22
- st.session_state["model"] = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
23
- st.session_state["model"] = BetterTransformer.transform(st.session_state["model"], keep_original_model=False)
24
- st.session_state["model"].resize_token_embeddings(len(st.session_state["tokenizer"]))
25
- st.session_state["logits_processor"] = AMRLogitsProcessor(st.session_state["tokenizer"],
26
- st.session_state["model"].config.max_length)
27
 
28
  st.title("πŸ“ Parse text into AMR")
29
 
30
  text = st.text_input(label="Text to transform (en)")
31
 
32
- if text and "model" in st.session_state:
33
  gen_kwargs = {
34
- "max_length": st.session_state["model"].config.max_length,
35
- "num_beams": st.session_state["model"].config.num_beams,
36
- "logits_processor": LogitsProcessorList([st.session_state["logits_processor"]]) if st.session_state[
37
- "logits_processor"] else None
38
  }
39
 
40
- encoded = st.session_state["tokenizer"](text, return_tensors="pt")
41
- generated = st.session_state["model"].generate(**encoded, **gen_kwargs)
42
- linearized = st.session_state["tokenizer"].decode_and_fix(generated)[0]
43
  penman_str = linearized2penmanstr(linearized)
44
 
45
  try:
 
1
  from collections import Counter
2
 
3
  import graphviz
 
4
  import penman
5
  from penman.models.noop import NoOpModel
 
6
  from mbart_amr.data.linearization import linearized2penmanstr
7
+ from transformers import LogitsProcessorList
 
8
 
9
  import streamlit as st
10
 
11
+ from utils import get_resources
 
12
 
13
+ model, tokenizer, logitsprocessor = get_resources()
 
 
 
 
 
 
 
 
 
14
 
15
  st.title("πŸ“ Parse text into AMR")
16
 
17
  text = st.text_input(label="Text to transform (en)")
18
 
19
+ if text:
20
  gen_kwargs = {
21
+ "max_length": model.config.max_length,
22
+ "num_beams": model.config.num_beams,
23
+ "logits_processor": LogitsProcessorList([logitsprocessor])
 
24
  }
25
 
26
+ encoded = tokenizer(text, return_tensors="pt")
27
+ generated = model.generate(**encoded, **gen_kwargs)
28
+ linearized = tokenizer.decode_and_fix(generated)[0]
29
  penman_str = linearized2penmanstr(linearized)
30
 
31
  try:
utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from torch.quantization import quantize_dynamic
4
+ from torch import nn, qint8
5
+ from torch.nn import Parameter
6
+ from transformers import PreTrainedModel, PreTrainedTokenizer
7
+ from optimum.bettertransformer import BetterTransformer
8
+ from mbart_amr.constraints.constraints import AMRLogitsProcessor
9
+ from mbart_amr.data.tokenization import AMRMBartTokenizer
10
+ from transformers import MBartForConditionalGeneration
11
+
12
+
13
# Tell `st.cache` how to fingerprint objects it cannot hash natively:
# models and tokenizers are identified by the checkpoint path they were
# loaded from, and torch Parameters by their underlying tensor data.
st_hash_funcs = {
    PreTrainedModel: lambda model: model.name_or_path,
    PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
    Parameter: lambda param: param.data,
}
16
+
17
+
18
@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
def get_resources(quantize: bool = True):
    """Load and cache the English-to-AMR parsing resources.

    Downloads the MBART checkpoint and its AMR tokenizer, applies the
    BetterTransformer fast-path transform, optionally quantizes the model,
    and builds the AMR-constrained logits processor. Results are cached by
    Streamlit so repeated reruns of the app reuse the same objects.

    :param quantize: when True, apply torch dynamic int8 quantization
    :return: tuple of (model, tokenizer, logits_processor)
    """
    checkpoint = "BramVanroy/mbart-en-to-amr"
    tok = AMRMBartTokenizer.from_pretrained(checkpoint, src_lang="en_XX")
    mdl = MBartForConditionalGeneration.from_pretrained(checkpoint)
    # Swap in BetterTransformer's fused attention kernels; the original
    # (untransformed) model is not kept around.
    mdl = BetterTransformer.transform(mdl, keep_original_model=False)
    # The AMR tokenizer adds tokens beyond the base vocab, so the embedding
    # matrix must be resized to match.
    mdl.resize_token_embeddings(len(tok))

    if quantize:
        # NOTE(review): torch dynamic quantization primarily targets
        # nn.Linear; the Dropout/LayerNorm entries are presumably ignored
        # by quantize_dynamic — confirm against the torch version in use.
        mdl = quantize_dynamic(mdl, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    processor = AMRLogitsProcessor(tok, mdl.config.max_length)

    return mdl, tok, processor