from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from diffusers import DiffusionPipeline
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import random
from utils import mpnet_embed_class, get_concreteness, Collate_t5, SentenceDataset
from torch.utils.data import DataLoader
class Summagery:
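    """Summarize a document and turn it into images.

    Pipeline as wired up below: a T5 checkpoint recursively summarizes the input
    document, ViPE (fittar/ViPE-M-CTX7, a fine-tuned GPT-2) rewrites the summary
    into visualizable prompts, and SDXL (base + refiner) renders one image per prompt.
    """
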
def __init__(self, t5_checkpoint, batch_size=5, abstractness=.4, max_d_length=1256, num_prompt=3, device='cuda'):
# ViPE: Visualize Pretty-much Everything
self.vipe_model = GPT2LMHeadModel.from_pretrained('fittar/ViPE-M-CTX7')
vipe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
vipe_tokenizer.pad_token = vipe_tokenizer.eos_token
self.vipe_tokenizer = vipe_tokenizer
# SDXL, load both base & refiner
self.basexl = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
self.refinerxl = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=self.basexl.text_encoder_2,
vae=self.basexl.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
)
self.device = device
self.max_d_length = max_d_length # maximum document length to handle before chunking
self.final_document_length = 60
self.num_prompt = num_prompt # how many prompts to generate per document
        self.abstractness = abstractness  # knob in [0, 1] that trades off semantic similarity against concreteness when ranking prompts
self.concreteness_dataset = './data/concreteness.csv'
self.batch_size = batch_size
# T5
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(t5_checkpoint)
self.t5_tokenizer = AutoTokenizer.from_pretrained(t5_checkpoint, model_max_length=max_d_length)
self.collate_t5 = Collate_t5(self.t5_tokenizer)
        # concreteness ratings used to score the generated prompts (tab-separated file)
        data = pd.read_csv(self.concreteness_dataset, header=0, delimiter='\t')
self.word2score = {w: s for w, s in zip(data['WORD'], data['RATING'])}

    # for large documents, split them into chunks of at most self.max_d_length words
def document_preprocess(self, document):
documents = []
words = document.split()
if len(words) <= self.max_d_length:
return [document]
start = 0
        while start < len(words):
if len(words) > start + self.max_d_length:
chunk = ' '.join(words[start:start + self.max_d_length])
else:
chunk = ' '.join(words[start:])
start += self.max_d_length
documents.append(chunk)
return documents
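
    # Recursively summarize the document: halve each chunk until the whole summary
    # fits within self.final_document_length words, then return it.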
def t5_summarize(self, document):
continue_summarization = True
if len(document.split()) <= self.final_document_length:
return document
self.t5_model.to(self.device)
documents = self.document_preprocess(document)
if len(documents) > self.batch_size:
# use batch inference to make things faster
            while continue_summarization:
dataset = SentenceDataset(documents)
dataloader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.collate_t5, num_workers=2)
summaries = ''
print('summarizing...')
for text_batch, batch in tqdm(dataloader):
if batch.input_ids.shape[1] > 5:
max_length = int(batch.input_ids.shape[1] / 2) # summarize the current chunk by half
if max_length < self.final_document_length: # unless max_length is too short
max_length = self.final_document_length
batch = batch.to(self.device)
generated_ids = self.t5_model.generate(input_ids=batch.input_ids,
attention_mask=batch.attention_mask, num_beams=3,
max_length=max_length,
repetition_penalty=2.5,
length_penalty=1.0, early_stopping=True)
                        preds = [self.t5_tokenizer.decode(g, skip_special_tokens=True,
                                                          clean_up_tokenization_spaces=True)
                                 for g in generated_ids]
for pred in preds:
summaries = summaries + pred + '. '
else:
for chunk in text_batch:
summaries = summaries + chunk + '. '
if len(summaries.split()) <= self.final_document_length:
continue_summarization = False
print('finished summarizing.')
else:
documents = self.document_preprocess(summaries)
else:
# skip batch inference since we only have a few documents
            while continue_summarization:
summaries = ''
print('summarizing...')
for chunk in tqdm(documents):
if len(chunk.split()) > 2:
max_length = int(len(chunk.split()) / 2) # summarize the current chunk by half
if max_length < self.final_document_length: # unless max_length is too short
max_length = self.final_document_length
input_ids = self.t5_tokenizer.encode('summarize: ' + chunk, return_tensors="pt",
add_special_tokens=True, padding='longest',
max_length=self.max_d_length)
input_ids = input_ids.to(self.device)
generated_ids = self.t5_model.generate(input_ids=input_ids, num_beams=3, max_length=max_length,
repetition_penalty=2.5,
length_penalty=1.0, early_stopping=True)
                    pred = [self.t5_tokenizer.decode(g, skip_special_tokens=True,
                                                     clean_up_tokenization_spaces=True)
                            for g in generated_ids][0]
summaries = summaries + pred + '. '
else:
summaries = summaries + chunk + '. '
if len(summaries.split()) <= self.final_document_length:
continue_summarization = False
print('finished summarizing.')
else:
documents = self.document_preprocess(summaries)
return summaries
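
    # Sample a batch of candidate prompts from ViPE conditioned on the summary, then rank
    # them by a blend of semantic similarity (MPNet) and word concreteness, keeping the
    # top self.num_prompt prompts.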
def vipe_generate(self, summary, do_sample=True, top_k=100, epsilon_cutoff=.00005, temperature=1):
batch_size = random.choice([20, 40, 60])
input_text = [summary] * batch_size
# mark the text with special tokens
input_text = [self.vipe_tokenizer.eos_token + i + self.vipe_tokenizer.eos_token for i in input_text]
batch = self.vipe_tokenizer(input_text, padding=True, return_tensors="pt")
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
self.vipe_model.to(self.device)
# how many new tokens to generate at max
max_prompt_length = 50
generated_ids = self.vipe_model.generate(input_ids=input_ids, attention_mask=attention_mask,
max_new_tokens=max_prompt_length, do_sample=do_sample, top_k=top_k,
epsilon_cutoff=epsilon_cutoff, temperature=temperature)
# return only the generated prompts
prompts = self.vipe_tokenizer.batch_decode(generated_ids[:, -(generated_ids.shape[1] - input_ids.shape[1]):],
skip_special_tokens=True)
# for semantic similarity
mpnet_object = mpnet_embed_class(device=self.device, nli=False)
similarities = mpnet_object.get_mpnet_embed_batch(prompts, [summary] * batch_size,
batch_size=batch_size).cpu().numpy()
concreteness_score = get_concreteness(prompts, self.word2score)
        final_scores = [(1 - self.abstractness) * sim + self.abstractness * con
                        for sim, con in zip(similarities, concreteness_score)]
# Get the indices that would sort the final_scores in descending order
sorted_indices = np.argsort(final_scores)[::-1]
        # keep the indices of the self.num_prompt highest-scoring prompts
        top_indices = sorted_indices[:self.num_prompt]
prompts = [prompts[i] for i in top_indices]
return prompts
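
    # Render each prompt with the SDXL base model (latent output) and hand the latents to
    # the refiner, using the base/refiner denoising split set by high_noise_frac.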
def sdxl_generate(self, prompts):
        # define how many denoising steps to run and what fraction to spend on each expert (80% base / 20% refiner)
n_steps = 50
high_noise_frac = 0.8
self.basexl.to(self.device)
self.refinerxl.to(self.device)
        images = []
for i, p in enumerate(prompts):
# torch.manual_seed(i)
image = self.basexl(
prompt=p,
num_inference_steps=n_steps,
denoising_end=high_noise_frac,
output_type="latent",
).images
image = self.refinerxl(
prompt=p,
num_inference_steps=n_steps,
denoising_start=high_noise_frac,
image=image,
).images[0]
images.append(image)
return images
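
    # End-to-end entry point: summarize the document, keep the summary as the first prompt,
    # add ViPE-generated prompts, and render one image per prompt.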
def ignite(self, document):
prompts = []
summary = self.t5_summarize(document)
prompts.append(summary)
summary = summary.replace('. ', '; ')
print(summary)
prompts.extend(self.vipe_generate(summary))
for p in prompts:
print(p + '\n')
        images = self.sdxl_generate(prompts)
        return images
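

if __name__ == '__main__':
    # Minimal usage sketch. The T5 checkpoint name below is a placeholder; point it at
    # whatever summarization checkpoint you actually use.
    summagery = Summagery(t5_checkpoint='t5-base', batch_size=5, abstractness=0.4,
                          num_prompt=3, device='cuda')
    document = 'Your long document goes here. ' * 100  # any long text
    images = summagery.ignite(document)
    # ignite() returns a list of PIL images (one for the summary prompt plus one per ViPE prompt)
    for idx, img in enumerate(images):
        img.save(f'summagery_{idx}.png')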