# Hugging Face Space (Streamlit) — English-to-AMR demo: model/tokenizer loading utilities.
# (Page-status residue "Spaces: Running Running" from the scraped source replaced by this header.)
import streamlit as st
from torch import nn, qint8
from torch.nn import Parameter
from torch.quantization import quantize_dynamic
from transformers import (
    MBartForConditionalGeneration,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from optimum.bettertransformer import BetterTransformer

from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.tokenization import AMRMBartTokenizer

# Custom hash functions for Streamlit's caching decorator: identify models and
# tokenizers by their hub name (`name_or_path`) instead of hashing their large
# weight tensors, and hash a Parameter via its underlying tensor data.
# NOTE(review): nothing in this chunk applies these to a cache decorator —
# presumably used as `@st.cache(hash_funcs=st_hash_funcs)` elsewhere; confirm.
st_hash_funcs = {
    PreTrainedModel: lambda model: model.name_or_path,
    PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
    Parameter: lambda param: param.data,
}
def get_resources(quantize: bool = True):
    """Load the English-to-AMR model, tokenizer, and constrained-decoding processor.

    Downloads ``BramVanroy/mbart-en-to-amr`` from the Hugging Face Hub, converts
    the model with BetterTransformer for faster inference, and resizes the token
    embeddings to match the AMR-extended tokenizer vocabulary.

    :param quantize: when True, apply dynamic int8 quantization to the model
        (CPU inference speed-up at a small accuracy cost)
    :return: a ``(model, tokenizer, logits_processor)`` tuple
    """
    tokenizer = AMRMBartTokenizer.from_pretrained(
        "BramVanroy/mbart-en-to-amr", src_lang="en_XX"
    )
    model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
    # Replace attention blocks with PyTorch fastpath kernels; the original
    # (untransformed) model is discarded to save memory.
    model = BetterTransformer.transform(model, keep_original_model=False)
    # The AMR tokenizer adds special tokens, so the embedding matrix must be
    # resized to the new vocabulary size.
    model.resize_token_embeddings(len(tokenizer))
    if quantize:
        # NOTE(review): torch dynamic quantization only quantizes nn.Linear of
        # the types listed here; the Dropout/LayerNorm entries are effectively
        # no-ops — confirm intent before removing them.
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)
    logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)
    return model, tokenizer, logits_processor