# Hugging Face Space -- status banner captured with the source: "Running on A10G".
# Device on which image embeddings are kept between generations: `predict`
# moves the encoded embedding here before returning it, and `next_image`
# places the classifier direction here before handing it back to `predict`.
DEVICE = 'cpu'
import gradio as gr | |
import numpy as np | |
from sklearn.svm import LinearSVC | |
from sklearn import preprocessing | |
import pandas as pd | |
from transformers import CLIPVisionModelWithProjection | |
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, StableDiffusionXLPipeline | |
from diffusers.models import ImageProjection | |
from patch_sdxl import SDEmb | |
import torch | |
import spaces | |
import random | |
import time | |
import torch | |
from urllib.request import urlopen | |
from PIL import Image | |
import requests | |
from io import BytesIO, StringIO | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
# Unique text prompts taken from the second column of the CSV. Non-string
# cells (e.g. NaN produced by empty rows) are dropped. The set removes
# duplicates, so the resulting order is arbitrary; prompts are only ever
# drawn at random from this list.
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]

# Wall-clock timestamp taken at startup.
start_time = time.time()
####################### Setup Model
# Base SDXL repository and the SDXL-Lightning 2-step UNet checkpoint.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"

# Instantiate the UNet from the base config (no weights), then fill it with
# the Lightning checkpoint. Only the file path is given a name; the state
# dict itself stays a temporary so it is not kept alive after loading.
lightning_ckpt_path = hf_hub_download(sdxl_lightening, ckpt)
unet = UNet2DConditionModel.from_config(
    model_id,
    subfolder="unet",
).to("cuda", torch.float16)
unet.load_state_dict(load_file(lightning_ckpt_path, device="cuda"))

# Pipeline (project-local SDEmb subclass) built around the Lightning UNet.
pipe = SDEmb.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Swap in the tiny autoencoder and a trailing-timestep Euler scheduler.
pipe.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesdxl",
    torch_dtype=torch.float16,
)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)
# The VAE was replaced after the first `.to("cuda")`, so move the pipeline again.
pipe.to(device='cuda')

# IP-Adapter: CLIP image encoder plus the adapter weights for the UNet.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    'h94/IP-Adapter',
    subfolder='sdxl_models/image_encoder',
    torch_dtype=torch.float16,
).to("cuda")
ip_adapter_path = hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl.bin')
# NOTE(review): torch.load unpickles the .bin file; this relies on the
# source repository being trusted.
pipe.unet._load_ip_adapter_weights(torch.load(ip_adapter_path, map_location="cpu"))
pipe.register_modules(image_encoder=image_encoder)
# Alternative left unused by the author:
# pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

# Passed straight through to `pipe.encode_image` in `predict`.
output_hidden_state = False
#######################
def predict(
    prompt,
    im_emb=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one image and return it together with its image embedding.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        im_emb: Embedding passed to the pipeline as ``ip_adapter_emb``.
            When omitted, a zero tensor of shape ``(1, 1280)`` in float16
            is used instead.
        progress: Not read in the body; declared so Gradio can track tqdm
            progress for this call.

    Returns:
        A tuple ``(image, embedding)``: the first image produced by the
        pipeline, and the first value returned by ``pipe.encode_image`` for
        that image, moved to ``DEVICE``.
    """
    with torch.no_grad():
        # Identity check: `== None` on a tensor dispatches to tensor
        # comparison rather than testing for a missing argument.
        if im_emb is None:
            im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')

        image = pipe(
            prompt=prompt,
            ip_adapter_emb=[im_emb.to('cuda')],
            height=1024,
            width=1024,
            num_inference_steps=2,
            guidance_scale=0,
        ).images[0]

        # Re-encode the generated image so the caller can store its embedding.
        im_emb, _ = pipe.encode_image(
            image, 'cuda', 1, output_hidden_state
        )
    return image, im_emb.to(DEVICE)
# TODO add to state instead of shared across all
# Module-level call counter: every session increments the same value.
glob_idx = 0


def next_image(embs, ys, calibrate_prompts):
    """Produce the next image to rate, calibrating first and then roaming.

    While ``calibrate_prompts`` is non-empty, the next prompt is popped from
    its front and rendered without an image embedding. Afterwards a linear
    SVM is fitted on a class-balanced sample of the rated embeddings, and its
    unit-normalised weight vector is used as the image embedding for the next
    generation.

    Args:
        embs: List of image embeddings; mutated in place. Entries are indexed
            with the same indices as ``ys``, so the two lists must be aligned
            on entry.
        ys: List of labels (1 = liked, 0 = disliked); may be mutated in place.
        calibrate_prompts: Remaining calibration prompts; mutated in place.

    Returns:
        ``(image, embs, ys, calibrate_prompts)`` with the new image's
        embedding appended to ``embs``.
    """
    global glob_idx
    glob_idx = glob_idx + 1

    # Calibration is over but all ratings fall in one class (or there are
    # none): add one small random embedding per class so a classifier can
    # still be fitted.
    if not calibrate_prompts and len(set(ys)) <= 1:
        embs.append(.01 * torch.randn(1, 1280))
        embs.append(.01 * torch.randn(1, 1280))
        ys.append(0)
        ys.append(1)

    with torch.no_grad():
        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_emb = predict(prompt)
            embs.append(img_emb)
            return image, embs, ys, calibrate_prompts

        print('######### Roaming #########')
        # sample only as many negatives as there are positives
        pos_indices = [i for i, y in enumerate(ys) if y == 1]
        neg_indices = [i for i, y in enumerate(ys) if y == 0]
        lower = min(len(pos_indices), len(neg_indices))
        neg_indices = random.sample(neg_indices, lower)
        pos_indices = random.sample(pos_indices, lower)

        chosen = neg_indices + pos_indices
        cut_embs = [embs[i] for i in chosen]
        cut_ys = [ys[i] for i in chosen]

        # Standardise the features before fitting the linear classifier.
        feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs])
        scaler = preprocessing.StandardScaler().fit(feature_embs)
        feature_embs = scaler.transform(feature_embs)
        # Report the shapes of what is actually fitted (cut_ys, not all of ys).
        print(np.array(feature_embs).shape, np.array(cut_ys).shape)

        lin_class = LinearSVC(
            max_iter=50000, dual='auto', class_weight='balanced'
        ).fit(np.array(feature_embs), np.array(cut_ys))

        # Unit-normalised weight vector, shaped (1, n_features).
        direction = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
        direction = (direction / direction.norm()).unsqueeze(0)

        rng_prompt = random.choice(prompt_list)
        im_emb = direction.to(device=DEVICE, dtype=torch.float16)

        # Alternate between no text prompt and a random prompt from the list.
        prompt = '' if glob_idx % 2 == 0 else rng_prompt
        print(prompt)

        image, im_emb = predict(prompt, im_emb)
        embs.append(im_emb)
        return image, embs, ys, calibrate_prompts
def start(_, embs, ys, calibrate_prompts):
    """Handle the Start button: generate the first image and enable rating.

    Args:
        _: Value of the Start button; ignored.
        embs: Embedding list passed through to ``next_image``.
        ys: Label list passed through to ``next_image``.
        calibrate_prompts: Calibration prompts passed through to ``next_image``.

    Returns:
        A list of: the three rating buttons (Like, Neither, Dislike) set
        interactive, the Start button set non-interactive, the generated
        image, and the updated ``embs``, ``ys`` and ``calibrate_prompts``.
    """
    first_image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)

    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like (L)', 'Neither (Space)', 'Dislike (A)')
    ]
    start_button = gr.Button(value='Start', interactive=False)

    return [*rating_buttons, start_button, first_image, embs, ys, calibrate_prompts]
def choose(choice, embs, ys, calibrate_prompts):
    """Record the user's rating of the current image and fetch the next one.

    ``choice`` is the value of the clicked button. The buttons in this file
    are labelled ``'Like (L)'``, ``'Neither (Space)'`` and ``'Dislike (A)'``,
    so the label is matched by prefix; comparing against the bare words
    ``'Like'`` / ``'Neither'`` never matched and recorded every click as a
    dislike. The bare words are still accepted.

    Args:
        choice: Label of the clicked button.
        embs: Embedding list; its last entry belongs to the image being rated.
        ys: Label list; receives 1 for a like and 0 for anything else except
            "Neither".
        calibrate_prompts: Calibration prompts passed through to ``next_image``.

    Returns:
        ``(image, embs, ys, calibrate_prompts)`` as returned by ``next_image``.
    """
    label = str(choice).strip()

    if label.startswith('Neither'):
        # No rating: drop the embedding of the skipped image so that embs
        # and ys keep the same length.
        embs.pop(-1)
    else:
        # 'Dislike (A)' does not start with 'Like', so it falls through to 0.
        ys.append(1 if label.startswith('Like') else 0)

    img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
    return img, embs, ys, calibrate_prompts
# Stylesheet passed to gr.Blocks: caps the container width and centres the
# element with id "description".
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1{display: block}
#description p{margin-top: 0}
'''
# Script injected into the page head via gr.Blocks(head=...): maps the A,
# Space and L keys to clicks on the elements with ids "dislike", "neither"
# and "like".
js = '''
<script>
document.addEventListener('keydown', function(event) {
    if (event.key === 'a' || event.key === 'A') {
        // Trigger click on 'dislike' if 'A' is pressed
        document.getElementById('dislike').click();
    } else if (event.key === ' ' || event.keyCode === 32) {
        // Trigger click on 'neither' if Spacebar is pressed
        document.getElementById('neither').click();
    } else if (event.key === 'l' || event.key === 'L') {
        // Trigger click on 'like' if 'L' is pressed
        document.getElementById('like').click();
    }
});
</script>
'''
# UI layout and event wiring.
with gr.Blocks(css=css, head=js) as demo:
    # The second line of the Markdown must stay unindented: indenting it by
    # four spaces would turn it into a Markdown code block.
    gr.Markdown('''# Generative Recommenders
Explore the latent space without text prompts, based on your preferences. [Learn more on the blog](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/)
''', elem_id="description")

    # State threaded through every callback: embeddings of the images shown,
    # their labels, and the calibration prompts still to be shown.
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        "4k photo",
        'surrealist art',
        # 'a psychedelic, fractal view',
        'a beautiful collage',
        'abstract art',
        'an eldritch image',
        'a sketch',
        # 'a city full of darkness and graffiti',
        '',
    ])

    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image', width=700)

    # Rating buttons start disabled; `start` enables them. The elem_ids are
    # the targets of the keyboard shortcuts in the injected script.
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")

        # Each button passes its own value to `choose` as the first input.
        for rating_button in (b1, b2, b3):
            rating_button.click(
                choose,
                [rating_button, embs, ys, calibrate_prompts],
                [img, embs, ys, calibrate_prompts]
            )

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])

    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam.</div>''')

demo.launch()  # Share your demo with just 1 extra parameter