# Hugging Face Space app (hardware: A10G GPU).
DEVICE = 'cuda'

import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn import preprocessing
import pandas as pd
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, AutoPipelineForText2Image, DiffusionPipeline
from diffusers.models import ImageProjection
import torch

# Let PyTorch use faster reduced-precision kernels (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')

import random
import time
import spaces
from urllib.request import urlopen
from PIL import Image
import requests
from io import BytesIO, StringIO
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from safety_checker_improved import maybe_nsfw

# Unique prompts taken from the second column of the CSV. Non-string cells (e.g. NaN
# for empty cells) are skipped. next_image draws random seed prompts from this list.
prompt_list = [
    p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
start_time = time.time()
####################### Setup Model
# SDXL base architecture whose UNet weights are replaced by the 2-step SDXL-Lightning UNet.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"

unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True, device_map=DEVICE).to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))

# CLIP vision encoder matching the ViT-H SDXL IP-Adapter weights loaded below.
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=DEVICE)

pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)
# load_ip_adapter downloads the adapter checkpoint and installs it into the UNet's
# attention processors. A separate manual pass through the private
# `unet._load_ip_adapter_weights` (with an unrestricted torch.load of the same file)
# only duplicated that work, and its result was overwritten here, so it is not done.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
# Make sure the fp16 encoder loaded above is the one registered on the pipeline.
pipe.register_modules(image_encoder = image_encoder)
pipe.set_ip_adapter_scale(0.8)
# Tiny autoencoder (TAESD for SDXL) for fast, approximate latent decoding.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)
# "trailing" timestep spacing, as recommended for SDXL-Lightning few-step sampling.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.to(device=DEVICE)
def compile_em(enabled=False):
    """Optionally wrap the UNet, VAE and text-autoencoder forward in torch.compile.

    Compilation is off by default (TODO: turn back on once verified in this
    deployment). Pass ``enabled=True`` to compile. This relies on the module-level
    ``pipe`` and ``autoencoder`` objects already existing.

    Returns:
        None.
    """
    if not enabled:
        return None
    pipe.unet = torch.compile(pipe.unet, mode='reduce-overhead')
    pipe.vae = torch.compile(pipe.vae, mode='reduce-overhead')
    autoencoder.model.forward = torch.compile(autoencoder.model.forward, backend='inductor', dynamic=True)
    return None


# Passed as the fourth positional argument of `pipe.encode_image` in predict
# (its output-hidden-states flag).
output_hidden_state = False
#######################
####################### Setup autoencoder
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


class BottleneckT5Autoencoder:
    """Wrapper around a bottleneck-T5 text autoencoder: text -> latent -> text.

    The model is loaded with ``trust_remote_code=True``. ``embed`` calls it with
    ``encode_only=True``, and ``generate_from_latent`` steers decoding by setting the
    model's ``perturb_vector`` attribute. Both hooks come from that remote code.

    Args:
        model_path: Hugging Face model id or local path.
        device: Device used for both the model and the tokenized inputs.
    """

    def __init__(self, model_path: str, device='cuda'):
        self.device = device
        # Tokenizers have no dtype, so no torch_dtype is passed here.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        # Place the model on `device` (not a hard-coded 'cuda') so it matches the inputs.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, low_cpu_mem_usage=True
        ).to(device)
        self.model.eval()

    @torch.no_grad()
    def embed(self, text: str) -> torch.FloatTensor:
        """Return what the model produces for ``text`` in encode-only mode (its latent).

        An empty-string decoder input is passed alongside ``encode_only=True``.
        This is presumably required by the remote model's forward signature (confirm).
        """
        inputs = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device)
        return self.model(
            **inputs,
            decoder_input_ids=decoder_inputs['input_ids'],
            encode_only=True,
        )

    @torch.no_grad()
    def generate_from_latent(self, latent: torch.FloatTensor, max_length=20, temperature=1., top_p=.8, min_new_tokens=30) -> str:
        """Sample text steered toward ``latent``.

        The offset between ``latent`` and the embedding of a dummy '.' prompt is
        stored on the model as ``perturb_vector``. Sampling then starts from that
        dummy prompt.

        Args:
            latent: Target latent (callers in this file pass a bfloat16 tensor).
            max_length, temperature, top_p, min_new_tokens: Forwarded to
                ``model.generate``, which samples one sequence with top-p. The
                defaults ask for 30 new tokens under a max_length of 20; callers
                here override ``min_new_tokens``.

        Returns:
            The decoded text with special tokens removed.
        """
        dummy_text = '.'
        dummy = self.embed(dummy_text)
        self.model.perturb_vector = latent - dummy
        input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids
        output = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            min_new_tokens=min_new_tokens,
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t5-xl-wikipedia')
# Both `pipe` and `autoencoder` exist now, so compile_em can run (compilation is off by default).
compile_em()
#######################
def generate(prompt, in_embs=None,):
    """Decode text from a latent with the text autoencoder.

    Args:
        prompt: Seed text. If non-empty, its embedding is used on its own, or blended
            as ``.9 * in_embs + .5 * prompt_emb`` when ``in_embs`` is given. If empty,
            ``in_embs`` alone is decoded.
        in_embs: Optional latent tensor (e.g. the preference direction from
            get_coeff). It is rescaled so its largest absolute entry is 0.15.

    Returns:
        ``(text, latent)``: the generated text and the latent it was decoded from,
        moved to CPU.

    Raises:
        ValueError: if ``prompt`` is empty and ``in_embs`` is None.
    """
    if prompt != '':
        print(prompt)
        prompt_emb = autoencoder.embed(prompt).to(DEVICE)
        if in_embs is not None:
            in_embs = in_embs / in_embs.abs().max() * .15
            in_embs = .9 * in_embs.to(DEVICE) + .5 * prompt_emb
        else:
            in_embs = prompt_emb
    else:
        print('From embeds.')
        if in_embs is None:
            raise ValueError('generate() needs a non-empty prompt or in_embs to decode from.')
        in_embs = in_embs / in_embs.abs().max() * .15
    text = autoencoder.generate_from_latent(in_embs.to(DEVICE).to(dtype=torch.bfloat16), temperature=.3, top_p=.99, min_new_tokens=5)
    return text, in_embs.to('cpu')
def predict(
    prompt,
    im_emb=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Render one 1024x1024 image with the 2-step Lightning + IP-Adapter pipeline.

    Args:
        prompt: Text prompt. An empty string conditions on all-zero text embeddings,
            so the image is driven by ``im_emb`` alone.
        im_emb: IP-Adapter image embedding, expected to be (1, 1024) like the
            all-zeros default used when it is None.
        progress: Gradio progress tracker that mirrors the pipeline's tqdm bars.

    Returns:
        ``(image, im_emb)``: the generated image (None if ``maybe_nsfw`` flags it) and
        the image-encoder embedding of that image, on CPU.
    """
    with torch.no_grad():
        if im_emb is None:
            im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)
        # One embedding tensor per loaded IP-Adapter. unsqueeze(0) adds a leading axis.
        ip_embeds = [im_emb.to(DEVICE).unsqueeze(0)]
        if prompt == '':
            # Zero text conditioning: 2048-d prompt embeds and 1280-d pooled embeds.
            text_kwargs = dict(
                prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
                pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
            )
        else:
            text_kwargs = dict(prompt=prompt)
        image = pipe(
            **text_kwargs,
            ip_adapter_image_embeds=ip_embeds,
            height=1024,
            width=1024,
            num_inference_steps=2,
            guidance_scale=0,
        ).images[0]
        im_emb, _ = pipe.encode_image(
            image, DEVICE, 1, output_hidden_state
        )
        nsfw = maybe_nsfw(image)
        if nsfw:
            return None, im_emb.to('cpu')
        return image, im_emb.to('cpu')
def get_coeff(embs_local, ys):
    """Fit a linear SVM on rated embeddings and return its unit-norm weight vector.

    For some stochasticity, it trains on a random ~80% of the rated embeddings (at
    least two, never more than exist). The most recent example of each class is
    added so both labels are represented.

    Args:
        embs_local: List of embedding tensors (presumably each (1, D)), aligned
            index-by-index with ``ys``.
        ys: List of 0/1 ratings. Mutated in place: trailing ratings are dropped while
            it is longer than ``embs_local`` (see note below).

    Returns:
        A (1, D) float64 tensor: the SVM coefficients scaled to unit L2 norm.
    """
    # We may have just hit a rare multi-threading diffusers issue
    # (https://github.com/huggingface/diffusers/issues/5749) that adds a rating but
    # loses its embedding. Drop surplus ratings so indexing stays aligned.
    while len(ys) > len(embs_local):
        print('ys are longer than embs; popping latest rating')
        ys.pop(-1)

    n_aligned = min(len(embs_local), len(ys))
    # Clamp so random.sample never asks for more items than exist.
    n_to_choose = min(max(int(n_aligned * .8), 2), n_aligned)
    indices = random.sample(range(n_aligned), n_to_choose)

    # Always include the latest 0 and the latest 1.
    latest = {}
    for i in reversed(range(n_aligned)):
        label = ys[i]
        if label in (0, 1) and label not in latest:
            latest[label] = i
            indices.append(i)
            if len(latest) == 2:
                break

    feature_embs = np.array(torch.cat([embs_local[i].to('cpu') for i in indices]))
    scaler = preprocessing.StandardScaler().fit(feature_embs)
    feature_embs = scaler.transform(feature_embs)
    print(len(feature_embs), len(ys))
    labels = np.array([ys[i] for i in indices])
    lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, labels)
    coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
    return (coef_ / coef_.norm()).unsqueeze(0)
# TODO add to state instead of shared across all
# Call counter shared by every session (module global), cycling 0..11. When it is a
# multiple of 3, roaming seeds text generation with a random prompt from prompt_list.
glob_idx = 0


def next_image(embs, img_embs, ys, calibrate_prompts):
    """Produce the next image to rate and return the updated session state.

    While ``calibrate_prompts`` is non-empty, the next fixed prompt is rendered.
    After that ("roaming"), linear SVMs are fit on the ratings:
    - one gives a direction in text-autoencoder space, which is decoded into a prompt;
    - one gives a direction in IP-Adapter image-embedding space.
    The image is rendered from both.

    Args:
        embs: Text-autoencoder embeddings, one per shown image.
        img_embs: IP-Adapter image embeddings aligned with ``embs``.
        ys: 0/1 ratings aligned with ``embs``.
        calibrate_prompts: Remaining calibration prompts, consumed from the front.

    Returns:
        ``(image, embs, img_embs, ys, calibrate_prompts, prompt)``. When roaming, the
        returned lists are new filtered lists, not the ones passed in.
    """
    global glob_idx
    glob_idx = glob_idx + 1
    if glob_idx >= 12:
        glob_idx = 0

    # If calibration ended with at most one distinct label (all Neither, all Like, or
    # all Dislike), add one small random example of each class so the SVMs can fit.
    if len(calibrate_prompts) == 0 and len(set(ys)) <= 1:
        embs.append(.01*torch.randn(1, 2048))
        embs.append(.01*torch.randn(1, 2048))
        img_embs.append(.01*torch.randn(1, 1024))
        img_embs.append(.01*torch.randn(1, 1024))
        ys.append(0)
        ys.append(1)

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_emb = predict(prompt)
            text_emb = autoencoder.embed(prompt)
            embs.append(text_emb)
            img_embs.append(img_emb)
            return image, embs, img_embs, ys, calibrate_prompts, prompt

        print('######### Roaming #########')
        # Keep every positive but only the 40 most recent negatives. Older surplus
        # negatives, and their embeddings, drop out of the session state.
        max_negatives = 40
        pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
        neg_indices = [i for i in range(len(embs)) if ys[i] == 0][-max_negatives:]
        indices = pos_indices + neg_indices
        embs = [embs[i] for i in indices]
        img_embs = [img_embs[i] for i in indices]
        ys = [ys[i] for i in indices]

        # Preference direction in the text-autoencoder latent space.
        im_s = get_coeff(embs, ys)
        rng_prompt = random.choice(prompt_list)
        w = 1.4  # scale applied to the learned image-embedding direction
        prompt = rng_prompt if glob_idx % 3 == 0 else ''
        prompt, _ = generate(prompt, in_embs=im_s)
        print(prompt)
        text_emb = autoencoder.embed(prompt)
        embs.append(text_emb)
        # Preference direction in IP-Adapter image-embedding space drives the image.
        learn_emb = get_coeff(img_embs, ys)
        img_emb = w * learn_emb.to(dtype=torch.float16)
        image, img_emb = predict(prompt, im_emb=img_emb)
        img_embs.append(img_emb)
        return image, embs, img_embs, ys, calibrate_prompts, prompt
def start(_, embs, img_embs, ys, calibrate_prompts):
    """Handle the Start button: show the first image and enable rating.

    The first argument (the Start button's value) is ignored. Returns button updates
    that enable Like, Neither and Dislike and disable Start. These are followed by
    the image, the updated state lists and the prompt.
    """
    image, embs, img_embs, ys, calibrate_prompts, prompt = next_image(
        embs, img_embs, ys, calibrate_prompts
    )
    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like (L)', 'Neither (Space)', 'Dislike (A)')
    ]
    start_button = gr.Button(value='Start', interactive=False)
    return [*rating_buttons, start_button, image, embs, img_embs, ys, calibrate_prompts, prompt]
def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
    """Record the rating for the current image and return the next one.

    ``choice`` is the clicked button's label:
    - 'Neither (Space)' discards the current image's embeddings and records nothing.
    - 'Like (L)' records 1; any other label records 0.
    An image withheld as NSFW (``img`` is None) is always recorded as 0.
    """
    if choice == 'Neither (Space)':
        embs.pop(-1)
        img_embs.pop(-1)
        return next_image(embs, img_embs, ys, calibrate_prompts)

    rating = 1 if choice == 'Like (L)' else 0
    if img is None:
        print('NSFW -- choice is disliked')
        rating = 0
    ys.append(rating)
    return next_image(embs, img_embs, ys, calibrate_prompts)
# Page styling: a narrow layout, a centered description, and the `fade-in-out` flash
# played on the rating buttons. The flash's start colour is read from the --bg-color
# custom property, which the page script sets before adding the class.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
  0% {
    background: var(--bg-color);
  }
  100% {
    background: var(--button-secondary-background-fill);
  }
}
'''
# Injected into the page <head>. It provides:
# - keyboard shortcuts for the rating buttons (A = Dislike, Space = Neither, L = Like);
# - a colour flash on the button after each rating click.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
    // Ignore auto-repeat so holding a key down does not submit a burst of ratings.
    if (event.repeat) {
        return;
    }
    if (event.key === 'a' || event.key === 'A') {
        // Trigger click on 'dislike' if 'A' is pressed
        document.getElementById('dislike').click();
    } else if (event.key === ' ' || event.keyCode === 32) {
        // Trigger click on 'neither' if Spacebar is pressed, and stop the
        // browser from also scrolling the page.
        event.preventDefault();
        document.getElementById('neither').click();
    } else if (event.key === 'l' || event.key === 'L') {
        // Trigger click on 'like' if 'L' is pressed
        document.getElementById('like').click();
    }
});
function fadeInOut(button, color) {
    button.style.setProperty('--bg-color', color);
    button.classList.remove('fade-in-out');
    void button.offsetWidth; // This line forces a repaint by accessing a DOM property
    button.classList.add('fade-in-out');
    button.addEventListener('animationend', () => {
        button.classList.remove('fade-in-out'); // Reset the animation state
    }, {once: true});
}
// Listen on `document`, not `document.body`: when this script runs from <head>
// the body may not exist yet. Click events bubble up to `document` anyway.
document.addEventListener('click', function(event) {
    const target = event.target;
    if (target.id === 'dislike') {
        fadeInOut(target, '#ff1717');
    } else if (target.id === 'like') {
        fadeInOut(target, '#006500');
    } else if (target.id === 'neither') {
        fadeInOut(target, '#cccccc');
    }
});
</script>
'''
# Gradio UI: a prompt box, the current image, and three rating buttons. All per-session
# rating state lives in gr.State objects that are threaded through start/choose.
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''### Zahir: Generative Recommenders for Unprompted, Scalable Exploration
Explore the latent space without prompting based on your feedback. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")
    embs = gr.State([])
    img_embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])
    with gr.Row():
        prompt = gr.Textbox(interactive=False, elem_id="text")
    # HTML ids must be unique; the image inside already uses 'output-image'.
    with gr.Row(elem_id='output-image-row'):
        img = gr.Image(interactive=False, elem_id='output-image', width=700)
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
    # Every rating button sends its own label to `choose` as the choice.
    # The wiring is otherwise identical for all three.
    for rating_button in (b1, b2, b3):
        rating_button.click(
            choose,
            [img, rating_button, embs, img_embs, ys, calibrate_prompts],
            [img, embs, img_embs, ys, calibrate_prompts, prompt],
        )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, img_embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts, prompt])
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the SDXL model is unlikely to produce NSFW images, it still may be possible, and users should avoid NSFW content when rating.
</div>''')

demo.launch(share=True)  # Also request a public Gradio share link.