from typing import Dict, List, Tuple, Union

import penman
import spaces
import streamlit as st
import torch
from multi_amr.data.postprocessing_graph import ParsedStatus
from multi_amr.data.tokenization import AMRTokenizerWrapper
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.quantization import quantize_dynamic
from transformers import AutoConfig, MBartForConditionalGeneration

LANGUAGES = {
    "English": "en_XX",
    "Dutch": "nl_XX",
    "Spanish": "es_XX",
}


@st.cache_resource(show_spinner=False)
def get_resources(
    multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False
) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
    """Get the relevant model and tokenizer wrapper. The loaded model depends on whether the multilingual model is
    requested, or not. If not, a language-specific model is loaded. The model can be optionally quantized for
    better performance.

    :param multilingual: whether to load the multilingual model or not
    :param src_lang: source language
    :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
    :param no_cuda: whether to disable CUDA, even if it is available
    :return: the loaded model and tokenizer wrapper
    """
    model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
    if not multilingual:
        if src_lang == "English":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en"
        elif src_lang == "Spanish":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-es"
        elif src_lang == "Dutch":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-nl"
        else:
            raise ValueError(f"Language {src_lang} not supported")

    # The tokenizer's src_lang is reset to the right language during translation
    tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")

    config = AutoConfig.from_pretrained(model_name)
    config.decoder_start_token_id = tok_wrapper.amr_token_id
    model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
    model.eval()

    # Grow the embedding matrix if the tokenizer holds more tokens than the checkpoint's vocabulary
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tok_wrapper.tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tok_wrapper.tokenizer))

    model = BetterTransformer.transform(model, keep_original_model=False)

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization is not supported on CUDA
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    return model, tok_wrapper


@spaces.GPU
def translate(
    texts: Union[str, List[str]],
    src_lang: str,
    model: MBartForConditionalGeneration,
    tok_wrapper: AMRTokenizerWrapper,
    **gen_kwargs,
) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
    """Translates a given text of a given source language with a given model and tokenizer. The generation is
    guided by potential keyword arguments, which can include arguments such as max length, logits processors, etc.

    :param texts: source text to translate (potentially a batch)
    :param src_lang: source language
    :param model: MBART model
    :param tok_wrapper: MBART tokenizer wrapper
    :param gen_kwargs: potential keyword arguments for the generation process
    :return: a dictionary with "graph" and "status" lists, holding the best beam item per input sample
    """
    if isinstance(texts, str):
        texts = [texts]

    tok_wrapper.src_lang = LANGUAGES[src_lang]
    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)
    generated["sequences"] = generated["sequences"].cpu()
    generated["sequences_scores"] = generated["sequences_scores"].cpu()

    best_scoring_results = {"graph": [], "status": []}
    # Beam search is assumed: 'sequences_scores' is only returned by generate() when num_beams > 1
    beam_size = gen_kwargs["num_beams"]

    # Select the best item from the beam: the sequence with the best status and highest score
    for sample_idx in range(0, len(generated["sequences_scores"]), beam_size):
        sequences = generated["sequences"][sample_idx : sample_idx + beam_size]
        scores = generated["sequences_scores"][sample_idx : sample_idx + beam_size].tolist()
        outputs = tok_wrapper.batch_decode_amr_ids(sequences)
        statuses = outputs["status"]
        graphs = outputs["graph"]
        zipped = zip(statuses, scores, graphs)
        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second
        best = sorted(zipped, key=lambda item: (item[0].value, -item[1]))[0]
        best_scoring_results["graph"].append(best[2])
        best_scoring_results["status"].append(best[0])

    # Returns a dictionary with "graph" and "status" keys
    return best_scoring_results
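

# Usage sketch (an assumption, not part of the original app): load the English-only
# model and print the best graph for one sentence. This presumes the checkpoints above
# are reachable on the Hugging Face Hub, and passes num_beams > 1 because `translate`
# relies on the 'sequences_scores' that only beam search returns. The sentence and
# generation settings below are illustrative placeholders.
if __name__ == "__main__":
    model, tok_wrapper = get_resources(multilingual=False, src_lang="English", quantize=False)
    results = translate(
        ["The boy wants the girl to believe him."],
        "English",
        model,
        tok_wrapper,
        num_beams=5,
        max_length=1024,
    )
    for graph, status in zip(results["graph"], results["status"]):
        print(status)
        # Serialize the AMR graph in PENMAN notation when parsing succeeded
        print(penman.encode(graph) if isinstance(graph, penman.Graph) else graph)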