from typing import Iterator, List, Tuple

import streamlit as st
import torch
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.quantization import quantize_dynamic
from transformers import T5ForConditionalGeneration, T5Tokenizer


@st.cache_resource(show_spinner=False)
def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer]:
    """Load the T5 model and its (slow) tokenizer

    :param quantize: whether to apply dynamic int8 quantization (only used on CPU)
    :param no_cuda: whether to disable CUDA even if it is available
    """
    tokenizer = T5Tokenizer.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023", use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023")
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization is not supported on CUDA
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    model.eval()

    return model, tokenizer


def batchify(iterable, batch_size: int = 16):
    """Turn an iterable into a batch generator

    :param iterable: iterable to batchify
    :param batch_size: batch size
    """
    num_items = len(iterable)
    for idx in range(0, num_items, batch_size):
        yield iterable[idx : min(idx + batch_size, num_items)]


def simplify(
    texts: List[str], model: T5ForConditionalGeneration, tokenizer: T5Tokenizer, batch_size: int = 16
) -> Iterator[Tuple[List[str], List[str]]]:
    """Simplify a given set of texts with a given model and tokenizer.
    Yields tuples of (source texts, simplified texts) in batches of 'batch_size'

    :param texts: texts to simplify
    :param model: model to use for simplification
    :param tokenizer: tokenizer to use for simplification
    :param batch_size: batch size to yield results in
    """
    for batch_texts in batchify(texts, batch_size=batch_size):
        # Prepend the task prefix that the model expects for simplification
        nlg_batch_texts = ["[NLG] " + text for text in batch_texts]
        encoded = tokenizer(nlg_batch_texts, return_tensors="pt", padding=True, truncation=True)
        # Move the encoded inputs to the same device as the model
        encoded = {k: v.to(model.device) for k, v in encoded.items()}
        gen_kwargs = {
            "max_new_tokens": 128,
            "num_beams": 3,
        }

        with torch.no_grad():
            generated = model.generate(**encoded, **gen_kwargs).cpu()

        yield batch_texts, tokenizer.batch_decode(generated, skip_special_tokens=True)
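

# A minimal usage sketch, not part of the original app: it assumes the model can be downloaded
# from the Hugging Face Hub and that running on CPU with quantization is acceptable. The example
# sentence is made up for illustration; in the Streamlit app the results would be rendered in the
# UI rather than printed.
if __name__ == "__main__":
    example_texts = [
        "Dit is een voorbeeldzin die vereenvoudigd moet worden.",  # hypothetical example input
    ]
    model, tokenizer = get_resources(quantize=True, no_cuda=True)
    for source_batch, simplified_batch in simplify(example_texts, model, tokenizer, batch_size=8):
        for source, simplification in zip(source_batch, simplified_batch):
            print(f"SOURCE:     {source}")
            print(f"SIMPLIFIED: {simplification}")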