# Residue that appears to come from a web file viewer (not program code):
# rynmurdock's picture . b3aaea2 raw history blame 15.2 kB
DEVICE = 'cuda'
import gradio as gr
import numpy as np
from sklearn.svm import LinearSVC
from sklearn import preprocessing
import pandas as pd
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, AutoPipelineForText2Image, DiffusionPipeline
from diffusers.models import ImageProjection
import torch
torch.set_float32_matmul_precision('high')
import random
import time
# TODO put back
import spaces
from urllib.request import urlopen
from PIL import Image
import requests
from io import BytesIO, StringIO
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from safety_checker_improved import maybe_nsfw
# Unique prompts taken from the second column of the CSV; cells that are not
# strings are dropped. Iterating a set means the order is not stable.
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
start_time = time.time()
####################### Setup Model
# Build the text-to-image pipeline at import time: an SDXL-base pipeline whose
# UNet weights are replaced by the 2-step SDXL-Lightning checkpoint, with an
# IP-Adapter image encoder attached.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"
# Create the UNet from the base model's config in fp16, then load the
# downloaded Lightning state dict into it.
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True).to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True)
pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)
# NOTE(review): the same IP-Adapter weight file is loaded twice -- first straight
# into the UNet through a private method, then through load_ip_adapter. Confirm
# that both calls are needed.
pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
pipe.register_modules(image_encoder = image_encoder)
pipe.set_ip_adapter_scale(0.8)
# Replace the pipeline's VAE with the tiny autoencoder and switch to an Euler
# scheduler configured with trailing timestep spacing.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.to(device=DEVICE)
# TODO put back
@spaces.GPU
def compile_em():
    """Wrap the pipeline's UNet and VAE and the autoencoder's forward with ``torch.compile``."""
    for component in ('unet', 'vae'):
        compiled = torch.compile(getattr(pipe, component), mode='reduce-overhead')
        setattr(pipe, component, compiled)

    autoencoder.model.forward = torch.compile(
        autoencoder.model.forward,
        backend='inductor',
        dynamic=True,
        mode='reduce-overhead',
    )


# Forwarded to ``pipe.encode_image`` when an image is re-encoded.
output_hidden_state = False
#######################
####################### Setup autoencoder
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
class BottleneckT5Autoencoder:
    """Wrapper around a bottleneck-T5 checkpoint that embeds text and decodes text from a latent."""

    def __init__(self, model_path: str, device='cuda'):
        """Load tokenizer and model from ``model_path``; move the model to ``device`` in eval mode."""
        self.device = device
        # NOTE(review): ``torch_dtype`` is passed to the tokenizer, not the model --
        # confirm this is intended.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            model_max_length=512,
            torch_dtype=torch.bfloat16,
        )
        loaded = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
        self.model = loaded.to(self.device)
        self.model.eval()

    def embed(self, text: str) -> torch.FloatTensor:
        """Return the model's ``encode_only`` output for ``text``."""
        encoder_batch = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        # The decoder side is fed the tokenization of the empty string.
        empty_batch = self.tokenizer('', return_tensors='pt').to(self.device)
        return self.model(
            decoder_input_ids=empty_batch['input_ids'],
            encode_only=True,
            **encoder_batch,
        )

    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1., top_p=.8, min_new_tokens=30) -> str:
        """Sample text from ``latent``.

        The difference between ``latent`` and the embedding of a placeholder
        '.' is stored on the model as ``perturb_vector``; generation is then
        started from that same placeholder and the first sequence is decoded.
        """
        placeholder = '.'
        self.model.perturb_vector = latent - self.embed(placeholder)
        seed_ids = self.tokenizer(placeholder, return_tensors='pt').to(self.device).input_ids
        sampling_args = {
            'input_ids': seed_ids,
            'max_length': max_length,
            'do_sample': True,
            'temperature': temperature,
            'top_p': top_p,
            'num_return_sequences': 1,
            'min_new_tokens': min_new_tokens,
            # num_beams=8,
        }
        sequences = self.model.generate(**sampling_args)
        return self.tokenizer.decode(sequences[0], skip_special_tokens=True)
# Instantiate the text autoencoder (its device defaults to 'cuda'), then run
# compile_em(), which needs both `pipe` and `autoencoder` to exist.
autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t5-xl-wikipedia')
compile_em()
#######################
# TODO put back
@spaces.GPU
def generate(prompt, in_embs=None,):
    """Produce a text prompt from a seed prompt, a latent direction, or both.

    Args:
        prompt: seed text; the empty string means "decode from ``in_embs`` only".
        in_embs: optional tensor in the autoencoder's latent space. It is
            rescaled so its largest absolute entry is 0.15 before use.

    Returns:
        ``(text, latent)`` where ``text`` is sampled from the final latent and
        ``latent`` is that latent moved to the CPU.

    Raises:
        ValueError: if ``prompt`` is empty and ``in_embs`` is ``None``.
    """
    if prompt != '':
        print(prompt)
        prompt_emb = autoencoder.embed(prompt).to('cuda')
        if in_embs is None:
            in_embs = prompt_emb
        else:
            in_embs = in_embs / in_embs.abs().max() * .15
            # Blend the learned direction with the prompt's own embedding.
            in_embs = .9 * in_embs.to('cuda') + .5 * prompt_emb
    else:
        if in_embs is None:
            # Previously this surfaced as an AttributeError on None.
            raise ValueError('generate() needs a non-empty prompt or in_embs.')
        print('From embeds.')
        in_embs = in_embs / in_embs.abs().max() * .15

    text = autoencoder.generate_from_latent(
        in_embs.to('cuda').to(dtype=torch.bfloat16),
        temperature=.8,
        top_p=.94,
        min_new_tokens=5,
    )
    return text, in_embs.to('cpu')
# TODO put back
@spaces.GPU
def predict(
    prompt,
    im_emb=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Run a single prediction on the model.

    Args:
        prompt: text prompt; the empty string switches to all-zero prompt
            embeddings so that only the image embedding conditions the output.
        im_emb: optional IP-Adapter image embedding; a zero tensor of shape
            (1, 1024) is used when omitted.
        progress: Gradio progress tracker (not referenced in the body).

    Returns:
        ``(image, embedding)`` where ``embedding`` is the generated image
        re-encoded by the pipeline and moved to the CPU. ``image`` is ``None``
        when ``maybe_nsfw`` flags the result.
    """
    with torch.no_grad():
        if im_emb is None:
            im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)
        ip_embeds = [im_emb.to(DEVICE).unsqueeze(0)]

        # Only the text-conditioning arguments differ between the two modes.
        if prompt == '':
            text_kwargs = {
                'prompt_embeds': torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
                'pooled_prompt_embeds': torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
            }
        else:
            text_kwargs = {'prompt': prompt}

        image = pipe(
            **text_kwargs,
            ip_adapter_image_embeds=ip_embeds,
            height=1024,
            width=1024,
            num_inference_steps=2,
            guidance_scale=0,
            # timesteps=[800],
        ).images[0]

        im_emb, _ = pipe.encode_image(
            image, DEVICE, 1, output_hidden_state
        )
        nsfw = maybe_nsfw(image)

    if nsfw:
        return None, im_emb.to('cpu')
    return image, im_emb.to('cpu')
# sample a .8 of rated embeddings for some stochasticity, or at least two embeddings.
def get_coeff(embs_local, ys):
    """Fit a linear SVM on a random subset of rated embeddings and return its unit-norm weights.

    Args:
        embs_local: list of embedding tensors, one per rating; they are
            concatenated along dim 0 to form the feature matrix.
        ys: list of 0/1 ratings aligned with ``embs_local``. If it is longer
            than ``embs_local`` its last entry is removed in place.

    Returns:
        A ``torch.double`` tensor holding the classifier coefficients,
        flattened, divided by their norm, and given a leading dimension of 1.
    """
    n_embs = len(embs_local)
    # Clamp to the population size: random.sample raises ValueError when asked
    # for more items than exist.
    n_to_choose = min(max(int(n_embs * .8), 2), n_embs)
    indices = random.sample(range(n_embs), n_to_choose)

    # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
    # this ends up adding a rating but losing an embedding, it seems.
    # let's take off a rating if so to continue without indexing errors.
    if len(ys) > n_embs:
        print('ys are longer than embs; popping latest rating')
        ys.pop(-1)

    # also add the latest 0 and the latest 1
    latest_seen = set()
    for idx in range(len(ys) - 1, -1, -1):
        label = ys[idx]
        if label in (0, 1) and label not in latest_seen:
            indices.append(idx)
            latest_seen.add(label)
        if len(latest_seen) == 2:
            break

    feature_embs = np.array(torch.cat([embs_local[i].to('cpu') for i in indices]).to('cpu'))
    scaler = preprocessing.StandardScaler().fit(feature_embs)
    feature_embs = scaler.transform(feature_embs)
    print(len(feature_embs), len(ys))

    labels = np.array([ys[i] for i in indices])
    lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(feature_embs, labels)

    # Work on a copy of the weights instead of overwriting the fitted estimator.
    coef = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
    return (coef / coef.norm()).unsqueeze(0)
# TODO add to state instead of shared across all
glob_idx = 0


def next_image(embs, img_embs, ys, calibrate_prompts):
    """Generate the next image and extend the rating history lists.

    While ``calibrate_prompts`` is non-empty its first prompt is consumed and
    rendered. Afterwards ("roaming") a text direction and an image direction
    are fitted from the ratings and used to produce the next prompt and image.

    Returns:
        ``(image, embs, img_embs, ys, calibrate_prompts)``; the lists are the
        same objects that were passed in, mutated in place.
    """
    global glob_idx
    advanced = glob_idx + 1
    glob_idx = 0 if advanced >= 12 else advanced

    # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
    if not calibrate_prompts and len(set(ys)) <= 1:
        embs.extend(.01 * torch.randn(1, 2048) for _ in range(2))
        img_embs.extend(.01 * torch.randn(1, 1024) for _ in range(2))
        ys.extend((0, 1))

    with torch.no_grad():
        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_emb = predict(prompt)
            embs.append(autoencoder.embed(prompt))
            img_embs.append(img_emb)
            return image, embs, img_embs, ys, calibrate_prompts

        print('######### Roaming #########')
        text_direction = get_coeff(embs, ys)
        rng_prompt = random.choice(prompt_list)
        w = 1.4  # scale applied to the learned image direction
        # Every third step seeds generation with a random prompt from the list.
        prompt = rng_prompt if glob_idx % 3 == 0 else ''
        prompt, _ = generate(prompt, in_embs=text_direction)
        print(prompt)
        embs.append(autoencoder.embed(prompt))

        image_direction = get_coeff(img_embs, ys)
        scaled_direction = w * image_direction.to(dtype=torch.float16)
        image, img_emb = predict(prompt, im_emb=scaled_direction)
        img_embs.append(img_emb)

        # Drop the oldest entry of each history once it grows past 100.
        if len(embs) > 100:
            for history in (embs, img_embs, ys):
                history.pop(0)
        return image, embs, img_embs, ys, calibrate_prompts
def start(_, embs, img_embs, ys, calibrate_prompts):
    """Produce the first image, enable the three rating buttons and disable 'Start'.

    Returns a list ordered as: Like, Neither, Dislike and Start button
    updates, then the image and the four state values.
    """
    image, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like (L)', 'Neither (Space)', 'Dislike (A)')
    ]
    start_button = gr.Button(value='Start', interactive=False)
    return [*rating_buttons, start_button, image, embs, img_embs, ys, calibrate_prompts]
def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
    """Record the user's rating of the current image and fetch the next one.

    ``choice`` is the label of the clicked button. 'Neither (Space)' discards
    the latest text and image embeddings without recording a rating. 'Like (L)'
    records 1 and anything else records 0. A missing image (``img is None``)
    is always recorded as 0.
    """
    if choice == 'Neither (Space)':
        embs.pop(-1)
        img_embs.pop(-1)
        return next_image(embs, img_embs, ys, calibrate_prompts)

    rating = 1 if choice == 'Like (L)' else 0

    print(img, 'img')
    if img is None:
        print('NSFW -- choice is disliked')
        rating = 0

    ys.append(rating)
    return next_image(embs, img_embs, ys, calibrate_prompts)
# Stylesheet passed to gr.Blocks: caps the container width, centres the
# description, and defines the `fade-in-out` animation used by the script below.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
0% {
background: var(--bg-color);
}
100% {
background: var(--button-secondary-background-fill);
}
}
'''
# Script passed to gr.Blocks as `head`: maps the A / Space / L keys to clicks on
# the elements with ids 'dislike' / 'neither' / 'like', and flashes a colour on
# whichever of those elements is clicked.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
if (event.key === 'a' || event.key === 'A') {
// Trigger click on 'dislike' if 'A' is pressed
document.getElementById('dislike').click();
} else if (event.key === ' ' || event.keyCode === 32) {
// Trigger click on 'neither' if Spacebar is pressed
document.getElementById('neither').click();
} else if (event.key === 'l' || event.key === 'L') {
// Trigger click on 'like' if 'L' is pressed
document.getElementById('like').click();
}
});
function fadeInOut(button, color) {
button.style.setProperty('--bg-color', color);
button.classList.remove('fade-in-out');
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
button.classList.add('fade-in-out');
button.addEventListener('animationend', () => {
button.classList.remove('fade-in-out'); // Reset the animation state
}, {once: true});
}
document.body.addEventListener('click', function(event) {
const target = event.target;
if (target.id === 'dislike') {
fadeInOut(target, '#ff1717');
} else if (target.id === 'like') {
fadeInOut(target, '#006500');
} else if (target.id === 'neither') {
fadeInOut(target, '#cccccc');
}
});
</script>
'''
# Build the UI: a description, the output image, three rating buttons, a Start
# button and a short notice, all wired to `start` and `choose`.
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''### Zahir: Generative Recommenders for Unprompted, Scalable Exploration
Explore the latent space without prompting based on your feedback. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")

    # State threaded through every callback: text embeddings, image
    # embeddings, ratings, and the calibration prompts still to be shown.
    embs = gr.State([])
    img_embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])

    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image', width=700)

    with gr.Row(equal_height=True):
        # The rating buttons start disabled; `start` enables them.
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")

        # Each rating button passes its own label to `choose` as `choice`.
        _session_state = [embs, img_embs, ys, calibrate_prompts]
        for _rating_button in (b1, b2, b3):
            _rating_button.click(
                choose,
                [img, _rating_button, *_session_state],
                [img, *_session_state],
            )

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, img_embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts])

    with gr.Row():
        # The closing tags were written '</ div>', which HTML parsers do not
        # treat as an end tag; they are now '</div>'.
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the SDXL model is unlikely to produce NSFW images, it still may be possible, and users should avoid NSFW content when rating.
</div>''')

demo.launch(share=True)  # Share your demo with just 1 extra parameter 🚀