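# Zahir: preference-guided image exploration. Users rate generated images
# (Like / Neither / Dislike); a linear SVM over CLIP image embeddings turns the
# ratings into an IP-Adapter conditioning vector that steers the next image.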
DEVICE = 'cuda'
import gradio as gr
import numpy as np
from sklearn.svm import LinearSVC
from sklearn import preprocessing
import pandas as pd
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, AutoPipelineForText2Image
from diffusers.models import ImageProjection
import torch
import random
import time
from urllib.request import urlopen
from PIL import Image
import requests
from io import BytesIO, StringIO
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
prompt_list = [p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
               if isinstance(p, str)]
start_time = time.time()
####################### Setup Model
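# SDXL-Lightning's 2-step distilled UNet (hence num_inference_steps=2 and
# guidance_scale=0 in predict()), an IP-Adapter ViT-H image encoder so
# generation can be steered by image embeddings, the tiny TAESD VAE for fast
# decoding, and an Euler scheduler with trailing timestep spacing, which the
# Lightning checkpoints are trained for.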
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to(DEVICE, torch.float16)
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device=DEVICE))
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16,).to(DEVICE)
pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder).to(DEVICE)
pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
pipe.register_modules(image_encoder=image_encoder)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.to(device=DEVICE)
output_hidden_state = False
#######################
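
# predict() renders one image. An empty prompt zeroes out the text conditioning
# so the IP-Adapter image embedding alone steers generation; a missing image
# embedding falls back to a zero vector. The image is returned together with
# its CLIP embedding so callers can track what was shown.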
@spaces.GPU
def predict(
    prompt,
    im_emb=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Run a single prediction on the model."""
    with torch.no_grad():
        if im_emb is None:
            im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)
        im_emb = [im_emb.to(DEVICE).unsqueeze(0)]
        if prompt == '':
            image = pipe(
                prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
                pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
                ip_adapter_image_embeds=im_emb,
                height=1024,
                width=1024,
                num_inference_steps=2,
                guidance_scale=0,
            ).images[0]
        else:
            image = pipe(
                prompt=prompt,
                ip_adapter_image_embeds=im_emb,
                height=1024,
                width=1024,
                num_inference_steps=2,
                guidance_scale=0,
            ).images[0]
        im_emb, _ = pipe.encode_image(
            image, DEVICE, 1, output_hidden_state
        )
    return image, im_emb.to('cpu')
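
# A minimal usage sketch (not executed here; `emb` is a hypothetical
# (1, 1024) float16 CLIP image embedding):
#   image, new_emb = predict('', emb)       # image-embedding-only generation
#   image, new_emb = predict('a sketch')    # text-driven, zero image embedding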

# TODO: keep this in per-session gr.State instead of sharing across all users
glob_idx = 0

def next_image(embs, ys, calibrate_prompts):
    global glob_idx
    glob_idx = glob_idx + 1

    # if calibration is over but every rating so far is the same class,
    # seed both classes with small random embeddings so the SVM can be fit
    if len(calibrate_prompts) == 0 and len(set(ys)) <= 1:
        embs.append(.01 * torch.randn(1, 1024))
        embs.append(.01 * torch.randn(1, 1024))
        ys.append(0)
        ys.append(1)
    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_emb = predict(prompt)
            embs.append(img_emb)
            return image, embs, ys, calibrate_prompts
        else:
            print('######### Roaming #########')

            # sample 80% of the rated embeddings (at least two) for some stochasticity
            n_to_choose = max(int(len(embs) * .8), 2)
            indices = random.sample(range(len(embs)), n_to_choose)

            # also make sure the most recent 'dislike' (0) and 'like' (1) are included
            has_0 = False
            has_1 = False
            for i in reversed(range(len(ys))):
                if ys[i] == 0 and not has_0:
                    indices.append(i)
                    has_0 = True
                elif ys[i] == 1 and not has_1:
                    indices.append(i)
                    has_1 = True
                if has_0 and has_1:
                    break
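
            # Standardize the sampled embeddings and fit a linear SVM on the
            # like/dislike labels; the unit-normalized normal vector of its
            # decision boundary points from 'dislike' toward 'like' and serves
            # as the next IP-Adapter image embedding.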
            feature_embs = np.array(torch.cat([embs[i].to('cpu') for i in indices]))
            scaler = preprocessing.StandardScaler().fit(feature_embs)
            feature_embs = scaler.transform(feature_embs)

            lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(
                feature_embs, np.array([ys[i] for i in indices]))
            lin_class.coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
            lin_class.coef_ = (lin_class.coef_.flatten() / lin_class.coef_.flatten().norm()).unsqueeze(0)
            rng_prompt = random.choice(prompt_list)
            w = 1  # if len(embs) % 2 == 0 else 0
            im_emb = w * lin_class.coef_.to(dtype=torch.float16)

            prompt = 'an image' if glob_idx % 2 == 0 else rng_prompt
            print(prompt, len(ys))

            image, im_emb = predict(prompt, im_emb)
            embs.append(im_emb)

            # cap the rating history at 100 examples
            if len(embs) > 100:
                embs.pop(0)
                ys.pop(0)
            return image, embs, ys, calibrate_prompts

def start(_, embs, ys, calibrate_prompts):
    image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
    return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True),
        gr.Button(value='Dislike (A)', interactive=True),
        gr.Button(value='Start', interactive=False),
        image,
        embs,
        ys,
        calibrate_prompts,
    ]
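
# choose() maps the clicked button back to a label: 1 for 'Like', 0 for
# 'Dislike'; 'Neither' drops the last embedding and records no rating.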
def choose(choice, embs, ys, calibrate_prompts):
    if choice == 'Like (L)':
        choice = 1
    elif choice == 'Neither (Space)':
        # discard the last embedding and move on without recording a rating
        _ = embs.pop(-1)
        img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
        return img, embs, ys, calibrate_prompts
    else:
        choice = 0

    ys.append(choice)
    img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
    return img, embs, ys, calibrate_prompts
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% {
        background: var(--bg-color);
    }
    100% {
        background: var(--button-secondary-background-fill);
    }
}
'''
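
# Keyboard shortcuts (A / Space / L) and a brief background flash on click,
# wired to the elem_ids given to the buttons below.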
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
    if (event.key === 'a' || event.key === 'A') {
        // Trigger click on 'dislike' if 'A' is pressed
        document.getElementById('dislike').click();
    } else if (event.key === ' ' || event.keyCode === 32) {
        // Trigger click on 'neither' if Spacebar is pressed
        document.getElementById('neither').click();
    } else if (event.key === 'l' || event.key === 'L') {
        // Trigger click on 'like' if 'L' is pressed
        document.getElementById('like').click();
    }
});

function fadeInOut(button, color) {
    button.style.setProperty('--bg-color', color);
    button.classList.remove('fade-in-out');
    void button.offsetWidth; // This line forces a repaint by accessing a DOM property
    button.classList.add('fade-in-out');
    button.addEventListener('animationend', () => {
        button.classList.remove('fade-in-out'); // Reset the animation state
    }, {once: true});
}

document.body.addEventListener('click', function(event) {
    const target = event.target;
    if (target.id === 'dislike') {
        fadeInOut(target, '#ff1717');
    } else if (target.id === 'like') {
        fadeInOut(target, '#006500');
    } else if (target.id === 'neither') {
        fadeInOut(target, '#cccccc');
    }
});
</script>
'''
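
# UI: a short calibration pass over fixed prompts, then free roaming driven by
# the learned preference direction. State is per-session via gr.State.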
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''### Zahir: Generative Recommenders for Unprompted, Scalable Exploration
Explore the latent space without text prompts, based on your preferences. Learn more in [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")

    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        "4k photo",
        'surrealist art',
        # 'a psychedelic, fractal view',
        'a beautiful collage',
        'abstract art',
        'an eldritch image',
        'a sketch',
        # 'a city full of darkness and graffiti',
        '',  # empty prompt: one unprompted, image-embedding-only generation
    ])
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image', width=700)

    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
    b1.click(
        choose,
        [b1, embs, ys, calibrate_prompts],
        [img, embs, ys, calibrate_prompts]
    )
    b2.click(
        choose,
        [b2, embs, ys, calibrate_prompts],
        [img, embs, ys, calibrate_prompts]
    )
    b3.click(
        choose,
        [b3, embs, ys, calibrate_prompts],
        [img, embs, ys, calibrate_prompts]
    )

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam.</div>''')

demo.launch()