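"""Cached loading of the MBART-to-AMR model, tokenizer, and constrained
logits processor for the Streamlit demo."""
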
import streamlit as st
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.nn import Parameter
from torch.quantization import quantize_dynamic
from transformers import (
    MBartForConditionalGeneration,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.tokenization import AMRMBartTokenizer


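# Streamlit cannot hash transformers/torch objects natively, so tell the cache
# how to derive a stable identity for each type it may encounter.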
st_hash_funcs = {PreTrainedModel: lambda model: model.name_or_path,
                 PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
                 Parameter: lambda param: param.data}


@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
def get_resources(quantize: bool = True):
    """Load and cache the tokenizer, model, and constrained logits processor."""
    tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
    model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
    # Swap in optimum's BetterTransformer fast-path kernels for inference
    model = BetterTransformer.transform(model, keep_original_model=False)
    # Grow the embedding matrix to cover the AMR-specific tokens the tokenizer adds
    model.resize_token_embeddings(len(tokenizer))

    if quantize:
        # Dynamic int8 quantization for faster CPU inference; torch only converts
        # module types that have a dynamic quantized counterpart (here, nn.Linear)
        # and leaves the rest untouched
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    # Logits processor that enforces AMR linearization constraints during generation
    logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)

    return model, tokenizer, logits_processor
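

# Illustrative usage sketch: a minimal example of how the cached resources
# might be wired into generation. The helper name `translate_to_amr`, the
# generation settings, and plain `tokenizer.decode` are assumptions made for
# the sake of the example, not the project's established API.
def translate_to_amr(text: str) -> str:
    from transformers import LogitsProcessorList

    model, tokenizer, logits_processor = get_resources()
    encoded = tokenizer(text, return_tensors="pt")
    generated = model.generate(
        **encoded,
        # Apply the AMR constraints at every decoding step
        logits_processor=LogitsProcessorList([logits_processor]),
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)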