from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from diffusers import DiffusionPipeline
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import random
from torch.utils.data import DataLoader
from utils import mpnet_embed_class, get_concreteness, Collate_t5, SentenceDataset


class Summagery:
    def __init__(self, t5_checkpoint, batch_size=5, abstractness=.4, max_d_length=1256, num_prompt=3, device='cuda'):

        # ViPE: Visualize Pretty-much Everything
        self.vipe_model = GPT2LMHeadModel.from_pretrained('fittar/ViPE-M-CTX7')
        vipe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
        vipe_tokenizer.pad_token = vipe_tokenizer.eos_token
        self.vipe_tokenizer = vipe_tokenizer

        # SDXL: load both the base and the refiner pipelines
        self.basexl = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        self.refinerxl = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.basexl.text_encoder_2,
            vae=self.basexl.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )

        self.device = device
        self.max_d_length = max_d_length  # maximum document length (in words) before chunking
        self.final_document_length = 60  # target summary length in words
        self.num_prompt = num_prompt  # how many prompts to keep per document
        self.abstractness = abstractness  # handle from 0 to 1: trades semantic similarity against concreteness when ranking prompts
        self.concreteness_dataset = './data/concreteness.csv'
        self.batch_size = batch_size

        # T5 summarizer
        self.t5_model = AutoModelWithLMHead.from_pretrained(t5_checkpoint)
        self.t5_tokenizer = AutoTokenizer.from_pretrained(t5_checkpoint, model_max_length=max_d_length)
        self.collate_t5 = Collate_t5(self.t5_tokenizer)

        # for concreteness rating of the prompts
        data = pd.read_csv(self.concreteness_dataset, header=0, delimiter='\t')
        self.word2score = {w: s for w, s in zip(data['WORD'], data['RATING'])}

    # for large documents, divide them into chunks of self.max_d_length words
    def document_preprocess(self, document):
        documents = []
        words = document.split()

        if len(words) <= self.max_d_length:
            return [document]

        start = 0
        while len(words) > start:
            if len(words) > start + self.max_d_length:
                chunk = ' '.join(words[start:start + self.max_d_length])
            else:
                chunk = ' '.join(words[start:])
            start += self.max_d_length
            documents.append(chunk)

        return documents

    def t5_summarize(self, document):
        continue_summarization = True

        if len(document.split()) <= self.final_document_length:
            return document

        self.t5_model.to(self.device)
        documents = self.document_preprocess(document)

        if len(documents) > self.batch_size:
            # use batch inference to make things faster
            while continue_summarization:
                dataset = SentenceDataset(documents)
                dataloader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.collate_t5, num_workers=2)
                summaries = ''
                print('summarizing...')
                for text_batch, batch in tqdm(dataloader):
                    if batch.input_ids.shape[1] > 5:
                        max_length = int(batch.input_ids.shape[1] / 2)  # summarize the current chunk by half
                        if max_length < self.final_document_length:  # unless max_length is too short
                            max_length = self.final_document_length

                        batch = batch.to(self.device)
                        generated_ids = self.t5_model.generate(input_ids=batch.input_ids,
                                                               attention_mask=batch.attention_mask,
                                                               num_beams=3,
                                                               max_length=max_length,
                                                               repetition_penalty=2.5,
                                                               length_penalty=1.0,
                                                               early_stopping=True)
                        preds = [self.t5_tokenizer.decode(g, skip_special_tokens=True,
                                                          clean_up_tokenization_spaces=True)
                                 for g in generated_ids]
                        for pred in preds:
                            summaries = summaries + pred + '. '
                    else:
                        for chunk in text_batch:
                            summaries = summaries + chunk + '. '

                if len(summaries.split()) <= self.final_document_length:
                    continue_summarization = False
                    print('finished summarizing.')
                else:
                    documents = self.document_preprocess(summaries)

        else:
            # skip batch inference since we only have a few documents
            while continue_summarization:
                summaries = ''
                print('summarizing...')
                for chunk in tqdm(documents):
                    if len(chunk.split()) > 2:
                        max_length = int(len(chunk.split()) / 2)  # summarize the current chunk by half
                        if max_length < self.final_document_length:  # unless max_length is too short
                            max_length = self.final_document_length

                        input_ids = self.t5_tokenizer.encode('summarize: ' + chunk, return_tensors="pt",
                                                             add_special_tokens=True, padding='longest',
                                                             max_length=self.max_d_length)
                        input_ids = input_ids.to(self.device)
                        generated_ids = self.t5_model.generate(input_ids=input_ids,
                                                               num_beams=3,
                                                               max_length=max_length,
                                                               repetition_penalty=2.5,
                                                               length_penalty=1.0,
                                                               early_stopping=True)
                        pred = [self.t5_tokenizer.decode(g, skip_special_tokens=True,
                                                         clean_up_tokenization_spaces=True)
                                for g in generated_ids][0]
                        summaries = summaries + pred + '. '
                    else:
                        summaries = summaries + chunk + '. '

                if len(summaries.split()) <= self.final_document_length:
                    continue_summarization = False
                    print('finished summarizing.')
                else:
                    documents = self.document_preprocess(summaries)

        return summaries

    def vipe_generate(self, summary, do_sample=True, top_k=100, epsilon_cutoff=.00005, temperature=1):
        batch_size = random.choice([20, 40, 60])
        input_text = [summary] * batch_size

        # mark the text with special tokens
        input_text = [self.vipe_tokenizer.eos_token + i + self.vipe_tokenizer.eos_token for i in input_text]
        batch = self.vipe_tokenizer(input_text, padding=True, return_tensors="pt")

        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        self.vipe_model.to(self.device)

        # how many new tokens to generate at most
        max_prompt_length = 50

        generated_ids = self.vipe_model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                                 max_new_tokens=max_prompt_length, do_sample=do_sample,
                                                 top_k=top_k, epsilon_cutoff=epsilon_cutoff, temperature=temperature)

        # keep only the generated prompts, dropping the input tokens
        prompts = self.vipe_tokenizer.batch_decode(generated_ids[:, -(generated_ids.shape[1] - input_ids.shape[1]):],
                                                   skip_special_tokens=True)

        # semantic similarity between each prompt and the summary
        mpnet_object = mpnet_embed_class(device=self.device, nli=False)
        similarities = mpnet_object.get_mpnet_embed_batch(prompts, [summary] * batch_size,
                                                          batch_size=batch_size).cpu().numpy()
        concreteness_score = get_concreteness(prompts, self.word2score)

        # blend similarity and concreteness according to the abstractness handle
        final_scores = [i * (1 - self.abstractness) + self.abstractness * j
                        for i, j in zip(similarities, concreteness_score)]

        # Get the indices that would sort the final_scores in descending order
        sorted_indices = np.argsort(final_scores)[::-1]

        # Extract the indices of the num_prompt highest scores
        top_indices = sorted_indices[:self.num_prompt]
        prompts = [prompts[i] for i in top_indices]

        return prompts

    def sdxl_generate(self, prompts):
        # Define how many steps and what fraction of steps to run on each expert (80/20)
        n_steps = 50
        high_noise_frac = 0.8

        self.basexl.to(self.device)
        self.refinerxl.to(self.device)

        images = []
        for i, p in enumerate(prompts):
            # torch.manual_seed(i)
            image = self.basexl(
                prompt=p,
                num_inference_steps=n_steps,
                denoising_end=high_noise_frac,
                output_type="latent",
            ).images
            image = self.refinerxl(
                prompt=p,
                num_inference_steps=n_steps,
                denoising_start=high_noise_frac,
                image=image,
            ).images[0]
            images.append(image)

        return images

    def ignite(self, document):
        prompts = []
        summary = self.t5_summarize(document)
        prompts.append(summary)
        summary = summary.replace('. ', '; ')
        print(summary)
        prompts.extend(self.vipe_generate(summary))

        for p in prompts:
            print(p + '\n')

        images = self.sdxl_generate(prompts)
        return images
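

# Minimal usage sketch (an assumption, not part of the original pipeline): the 't5-base'
# checkpoint and the example document below are placeholders; pass whatever summarization
# checkpoint you actually use as t5_checkpoint. It shows the intended flow: instantiate
# Summagery, call ignite() on a raw document, and save the returned PIL images (one for
# the summary plus one per selected ViPE prompt).
if __name__ == '__main__':
    summagery = Summagery(t5_checkpoint='t5-base', batch_size=5, abstractness=.4, num_prompt=3, device='cuda')

    document = (
        "Artificial intelligence has moved from research labs into everyday products. "
        "Language models now summarize long documents, answer questions, and help turn text into images."
    )

    images = summagery.ignite(document)
    for idx, img in enumerate(images):
        img.save('image_{}.png'.format(idx))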