from typing import Tuple

import streamlit as st
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.nn import Parameter
from torch.quantization import quantize_dynamic
from transformers import MBartForConditionalGeneration, PreTrainedModel, PreTrainedTokenizer

from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.tokenization import AMRMBartTokenizer

# Custom hash functions handed to ``st.cache`` below: models and tokenizers
# are represented by their ``name_or_path``, parameters by their ``data``.
st_hash_funcs = {
    PreTrainedModel: lambda model: model.name_or_path,
    PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
    Parameter: lambda param: param.data,
}


@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
def get_resources(
    quantize: bool = True,
) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
    """Load the text-to-AMR model, its tokenizer and the AMR logits processor.

    The result is wrapped in ``st.cache`` so that the resources are built once
    per distinct ``quantize`` value instead of on every call.

    Args:
        quantize: When true, pass the model through ``quantize_dynamic`` with
            ``qint8`` as the target dtype before returning it.

    Returns:
        A ``(model, tokenizer, logits_processor)`` tuple.
    """
    tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
    model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
    model = BetterTransformer.transform(model, keep_original_model=False)
    # Make the embedding matrix match the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    if quantize:
        # NOTE(review): PyTorch's default dynamic-quantization mappings appear
        # to cover Linear/RNN-style modules only, so the Dropout and LayerNorm
        # entries may be no-ops -- confirm before trimming the set.
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)

    return model, tokenizer, logits_processor