Spaces:
Running
Running
from typing import Tuple, Union, Dict, List | |
from multi_amr.data.postprocessing_graph import ParsedStatus | |
from multi_amr.data.tokenization import AMRTokenizerWrapper | |
from optimum.bettertransformer import BetterTransformer | |
import penman | |
import streamlit as st | |
import torch | |
from torch.quantization import quantize_dynamic | |
from torch import nn, qint8 | |
from transformers import MBartForConditionalGeneration, AutoConfig | |
def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
    """Load the AMR parsing model and its tokenizer wrapper.

    When ``multilingual`` is True the English/Spanish/Dutch checkpoint is loaded; otherwise the monolingual
    checkpoint for ``src_lang`` is used. The model is put in eval mode and converted with BetterTransformer.
    It is then either moved to CUDA (when available and not disabled) or, on CPU, optionally quantized with
    PyTorch's dynamic quantization for better performance.

    :param multilingual: whether to load the multilingual model instead of a single-language one
    :param src_lang: human-readable source language ("English", "Spanish" or "Dutch"); only used to pick the
        checkpoint when ``multilingual`` is False
    :param quantize: whether to quantize the model's linear layers with PyTorch's 'quantize_dynamic'. Has no
        effect when the model is placed on CUDA
    :param no_cuda: whether to disable CUDA, even if it is available
    :return: the loaded model, and tokenizer wrapper
    :raises ValueError: if ``multilingual`` is False and there is no monolingual checkpoint for ``src_lang``
    """
    if multilingual:
        model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
    else:
        monolingual_models = {
            "English": "BramVanroy/mbart-large-cc25-ft-amr30-en",
            "Spanish": "BramVanroy/mbart-large-cc25-ft-amr30-es",
            "Dutch": "BramVanroy/mbart-large-cc25-ft-amr30-nl",
        }
        try:
            model_name = monolingual_models[src_lang]
        except KeyError:
            raise ValueError(f"Language {src_lang} not supported") from None

    # Tokenizer src_lang is reset during translation to the right language
    tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")
    config = AutoConfig.from_pretrained(model_name)
    # Decoding starts from the wrapper's AMR token rather than the checkpoint's default start token
    config.decoder_start_token_id = tok_wrapper.amr_token_id
    model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
    model.eval()

    # Grow the embedding matrix if the tokenizer knows more tokens than the checkpoint has embeddings for
    # (presumably the wrapper's added AMR tokens — confirm against AMRTokenizerWrapper)
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tok_wrapper.tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tok_wrapper.tokenizer))

    model = BetterTransformer.transform(model, keep_original_model=False)

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization not supported on CUDA
        # Dynamic quantization only has replacement modules for layers such as nn.Linear; Dropout and LayerNorm
        # (previously listed here) were left untouched anyway, so only Linear is requested
        model = quantize_dynamic(model, {nn.Linear}, dtype=qint8)

    return model, tok_wrapper
def translate(texts: Union[str, List[str]], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
    """Parse one or more texts into AMR graphs with the given model and tokenizer wrapper.

    Generation is guided by ``gen_kwargs`` (e.g. ``num_beams``, ``num_return_sequences``, ``max_length``,
    logits processors). For every input text, all candidate sequences returned by ``generate`` are decoded. The
    single best candidate is kept: the one with the best (lowest) parse status and, among equal statuses, the
    highest sequence score.

    :param texts: a single text or a batch of texts to parse
    :param src_lang: human-readable source language, a key of ``LANGUAGES`` (e.g. "English")
    :param model: MBART model
    :param tok_wrapper: MBART tokenizer wrapper
    :param gen_kwargs: keyword arguments forwarded to ``model.generate``
    :return: dictionary with keys "graph" and "status", each holding one entry per input text
    :raises KeyError: if ``src_lang`` is not a key of ``LANGUAGES``
    """
    if isinstance(texts, str):
        texts = [texts]

    tok_wrapper.src_lang = LANGUAGES[src_lang]
    best_scoring_results = {"graph": [], "status": []}
    if not texts:
        return best_scoring_results

    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)

    sequences = generated["sequences"].cpu()
    # `sequences_scores` is only produced by beam search. Greedy decoding yields exactly one candidate per text,
    # so a constant score does not affect the selection there.
    sequences_scores = getattr(generated, "sequences_scores", None)
    if sequences_scores is not None:
        scores = sequences_scores.cpu().tolist()
    else:
        scores = [0.0] * len(sequences)

    # generate() returns `num_return_sequences` candidates per input, stored contiguously. Deriving the group size
    # from the output (instead of from `num_beams`) keeps groups aligned with inputs even when
    # num_return_sequences != num_beams, and does not require `num_beams` to be passed at all.
    group_size = len(sequences) // len(texts)

    # Select the best item from each group: the sequence with best status and highest score
    for start in range(0, len(sequences), group_size):
        end = start + group_size
        outputs = tok_wrapper.batch_decode_amr_ids(sequences[start:end])
        candidates = zip(outputs["status"], scores[start:end], outputs["graph"])
        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second; min() keeps the first of ties
        best_status, _, best_graph = min(candidates, key=lambda item: (item[0].value, -item[1]))
        best_scoring_results["graph"].append(best_graph)
        best_scoring_results["status"].append(best_status)

    # Returns dictionary with "graph" and "status" keys
    return best_scoring_results
# Human-readable source-language names mapped to the mBART language codes set on the tokenizer wrapper.
# Insertion order (English, Dutch, Spanish) is preserved in case a caller relies on key order.
LANGUAGES: Dict[str, str] = dict(
    English="en_XX",
    Dutch="nl_XX",
    Spanish="es_XX",
)