"""Gradio demo: preference-guided image generation.

Sets up an SDXL-Lightning pipeline with an IP-Adapter and a bottleneck T5
text autoencoder, then exposes helpers that turn embeddings into prompts
(`generate`) and prompts/image embeddings into images (`predict`).
"""
import random
import time
from io import BytesIO, StringIO
from urllib.request import urlopen

# NOTE: this import shadows the builtin `compile` for the rest of the module.
from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)
import gradio as gr
import numpy as np
from sklearn.svm import LinearSVC
from sklearn import preprocessing
import pandas as pd
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, AutoPipelineForText2Image
from diffusers.models import ImageProjection
import torch
from PIL import Image
import requests
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
#import spaces

DEVICE = 'cuda'

config = CompilationConfig.Default()

# Unique, string-only entries from the second column of the prompt CSV
# (non-string cells such as NaN are dropped).
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]

start_time = time.time()

####################### Setup Model
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"

# Build the UNet from the base model's config, then load the checkpoint weights.
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to(DEVICE, torch.float16)
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device=DEVICE))

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to(DEVICE)
pipe = AutoPipelineForText2Image.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    image_encoder=image_encoder,
).to(DEVICE)

# NOTE(security): torch.load unpickles the downloaded .bin file; this is only
# safe as long as the upstream repository is trusted.
pipe.unet._load_ip_adapter_weights(
    torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin'))
)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
pipe.register_modules(image_encoder=image_encoder)

pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.to(device=DEVICE)

pipe = compile(pipe, config=config)

# Run the compiled pipeline once with all-zero embeddings
# (presumably a warm-up so the first user request is not slowed by compilation).
image = pipe(
    prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
    pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
    ip_adapter_image_embeds=[torch.zeros(1, 1, 1024, dtype=torch.float16, device=DEVICE)],
    height=1024,
    width=1024,
    num_inference_steps=2,
    guidance_scale=0,
).images[0]

# Passed through to pipe.encode_image in predict().
output_hidden_state = False
#######################

####################### Setup autoencoder


class BottleneckT5Autoencoder:
    """Wrapper around a remote-code causal LM used as a text autoencoder.

    `embed` maps text to whatever the model returns with `encode_only=True`
    (presumably the bottleneck latent -- confirm against the model's remote
    code); `generate_from_latent` samples text back from such a latent.
    """

    def __init__(self, model_path: str, device='cuda'):
        """Load the tokenizer and model from `model_path` and move the model to `device`."""
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512, torch_dtype=torch.bfloat16)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device)
        self.model.eval()
        # self.model = torch.compile(self.model)

    def embed(self, text: str) -> torch.FloatTensor:
        """Return the model's `encode_only=True` output for `text`."""
        inputs = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        # The decoder is fed the tokenization of the empty string.
        decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device)
        return self.model(
            **inputs,
            decoder_input_ids=decoder_inputs['input_ids'],
            encode_only=True,
        )

    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1., top_p=.8, length_penalty=10, min_new_tokens=30) -> str:
        """Sample text from `latent` and return the decoded string.

        The model is driven with a dummy input ('.'); the difference between
        `latent` and the dummy's own embedding is stored on the model as
        `perturb_vector` before sampling.
        """
        dummy_text = '.'
        dummy = self.embed(dummy_text)
        perturb_vector = latent - dummy
        self.model.perturb_vector = perturb_vector
        input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids
        output = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            length_penalty=length_penalty,
            min_new_tokens=min_new_tokens,
            # num_beams=8,
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t5-xl-wikipedia')
#######################


def generate(prompt, in_embs=None,):
    """Produce a text prompt from an embedding, a seed prompt, or both.

    Args:
        prompt: Seed text. When non-empty, its autoencoder embedding is used,
            either alone or blended with `in_embs`.
        in_embs: Optional embedding tensor. It is rescaled so its largest
            absolute value is 0.15 before use.

    Returns:
        A `(text, embedding)` tuple: the decoded text and the embedding it
        was decoded from, moved to the CPU.

    Raises:
        ValueError: If `prompt` is empty and `in_embs` is None, since there
            is then nothing to decode from.
    """
    if in_embs is not None:
        in_embs = in_embs / in_embs.abs().max() * .15

    if prompt != '':
        print(prompt)
        prompt_emb = autoencoder.embed(prompt).to(DEVICE)
        if in_embs is None:
            in_embs = prompt_emb
        else:
            in_embs = .9 * in_embs.to(DEVICE) + .5 * prompt_emb
    else:
        print('From embeds.')
        if in_embs is None:
            raise ValueError('generate() needs a non-empty prompt or in_embs')

    text = autoencoder.generate_from_latent(
        in_embs.to(DEVICE).to(dtype=torch.bfloat16),
        temperature=.3,
        top_p=.99,
        min_new_tokens=5,
    )
    return text, in_embs.to('cpu')


#@spaces.GPU
def predict(
    prompt,
    im_emb=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Run a single prediction on the model.

    Args:
        prompt: Text prompt. An empty string runs the pipeline with all-zero
            prompt embeddings instead of text.
        im_emb: Optional image embedding passed to the IP-Adapter; defaults
            to a zero tensor of shape (1, 1024).
        progress: Not read in the body; the `gr.Progress(track_tqdm=True)`
            default is presumably what lets Gradio show tqdm progress.

    Returns:
        A `(image, embedding)` tuple: the generated image and the result of
        `pipe.encode_image` on it, moved to the CPU.
    """
    with torch.no_grad():
        if im_emb is None:
            im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)

        # Arguments shared by both the text-free and the text-prompted call.
        shared_kwargs = dict(
            ip_adapter_image_embeds=[im_emb.to(DEVICE).unsqueeze(0)],
            height=1024,
            width=1024,
            num_inference_steps=2,
            guidance_scale=0,
        )
        if prompt == '':
            text_kwargs = dict(
                prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
                pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
            )
        else:
            text_kwargs = dict(prompt=prompt)

        image = pipe(**text_kwargs, **shared_kwargs).images[0]
        im_emb, _ = pipe.encode_image(
            image, DEVICE, 1, output_hidden_state
        )
    return image, im_emb.to('cpu')


# sample a .8 of rated embeddings for some
# stochasticity, or at least two embeddings.
def get_coeff(embs_local, ys):
    """Fit a linear classifier on rated embeddings and return its unit-norm direction.

    Args:
        embs_local: List of embedding tensors, one per rated item.
        ys: List of 0/1 ratings aligned with `embs_local`. Mutated in place
            when it is longer than `embs_local` (excess ratings are dropped).

    Returns:
        A `torch.double` tensor of shape (1, n_features): the LinearSVC
        coefficient vector, flattened and normalised to unit length.
    """
    # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
    # this ends up adding a rating but losing an embedding, it seems.
    # let's take off a rating if so to continue without indexing errors.
    if len(ys) > len(embs_local):
        print('ys are longer than embs; popping latest rating')
        del ys[len(embs_local):]

    # Only indices valid for BOTH lists can be sampled, and the sample size
    # must not exceed the population or random.sample raises ValueError.
    n_available = min(len(embs_local), len(ys))
    n_to_choose = min(max(int(len(embs_local) * .8), 2), n_available)
    indices = random.sample(range(n_available), n_to_choose)

    # also add the latest 0 and the latest 1
    # (they may duplicate sampled indices, as in the original behaviour)
    found_labels = set()
    for i in range(len(ys) - 1, -1, -1):
        label = ys[i]
        if label in (0, 1) and label not in found_labels:
            indices.append(i)
            found_labels.add(label)
            if len(found_labels) == 2:
                break

    feature_embs = np.array(torch.cat([embs_local[i].to('cpu') for i in indices]))
    feature_embs = preprocessing.StandardScaler().fit_transform(feature_embs)
    print(len(feature_embs), len(ys))

    labels = np.array([ys[i] for i in indices])
    lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(feature_embs, labels)
    coef = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
    return (coef / coef.norm()).unsqueeze(0)


# TODO add to state instead of shared across all
glob_idx = 0


def next_image(embs, img_embs, ys, calibrate_prompts):
    """Generate the next image, either from a calibration prompt or by roaming.

    Args:
        embs: List of text embeddings for the items shown so far.
        img_embs: List of image embeddings for the items shown so far.
        ys: List of 0/1 ratings.
        calibrate_prompts: Remaining calibration prompts; one is consumed per
            call until the list is empty.

    Returns:
        `(image, embs, img_embs, ys, calibrate_prompts)` with the lists
        updated in place.
    """
    global glob_idx
    # Module-level counter cycling through 0..11; its parity decides whether
    # roaming starts from a random prompt or from embeddings only.
    glob_idx = (glob_idx + 1) % 12

    # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
    if not calibrate_prompts and len(set(ys)) <= 1:
        # Add one small random example of each class so the classifier has
        # two classes to fit.
        # NOTE(review): assumes both embedding kinds are (1, 1024) -- confirm
        # against autoencoder.embed's output width.
        embs.append(.01*torch.randn(1, 1024))
        embs.append(.01*torch.randn(1, 1024))
        img_embs.append(.01*torch.randn(1, 1024))
        img_embs.append(.01*torch.randn(1, 1024))
        ys.append(0)
        ys.append(1)

    with torch.no_grad():
        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_emb = predict(prompt)
            im_emb = autoencoder.embed(prompt)
            embs.append(im_emb)
            img_embs.append(img_emb)
            return image, embs, img_embs, ys, calibrate_prompts

        print('######### Roaming #########')
        im_s = get_coeff(embs, ys)
        rng_prompt = random.choice(prompt_list)
        w = 1.4# if len(embs) % 2 == 0 else 0

        # Alternate between decoding purely from the learned direction and
        # blending it with a random prompt.
        prompt = '' if glob_idx % 2 == 0 else rng_prompt
        prompt, _ = generate(prompt, in_embs=im_s)
        print(prompt)
        im_emb = autoencoder.embed(prompt)
        embs.append(im_emb)

        learn_emb = get_coeff(img_embs, ys)
        img_emb = w * learn_emb.to(dtype=torch.float16)
        image, img_emb = predict(prompt, im_emb=img_emb)
        img_embs.append(img_emb)

        # Keep at most 20 items of history by dropping the oldest entry.
        if len(embs) > 20:
            embs.pop(0)
            img_embs.pop(0)
            ys.pop(0)
        return image, embs, img_embs, ys, calibrate_prompts


def start(_, embs, img_embs, ys, calibrate_prompts):
    """Handle the Start button: show the first image and enable the rating buttons.

    Returns the updated Like/Neither/Dislike/Start buttons followed by the
    image and the state lists, in the order the click handler's outputs expect.
    """
    image, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
    return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True),
        gr.Button(value='Dislike (A)', interactive=True),
        gr.Button(value='Start', interactive=False),
        image,
        embs,
        img_embs,
        ys,
        calibrate_prompts
    ]


def choose(choice, embs, img_embs, ys, calibrate_prompts):
    """Record a rating for the current image and produce the next one.

    `choice` is the clicked button's label. 'Neither (Space)' discards the
    latest embeddings without recording a rating; 'Like (L)' records 1 and
    anything else records 0.
    """
    if choice == 'Neither (Space)':
        embs.pop(-1)
        img_embs.pop(-1)
    else:
        ys.append(1 if choice == 'Like (L)' else 0)
    return next_image(embs, img_embs, ys, calibrate_prompts)


# NOTE(review): the original line structure of the strings below appears to
# have been lost; their contents are kept exactly as found.
css = (
    ".gradio-container{max-width: 700px !important} "
    "#description{text-align: center} "
    "#description h1, #description h3{display: block} "
    "#description p{margin-top: 0} "
    ".fade-in-out {animation: fadeInOut 3s forwards} "
    "@keyframes fadeInOut { "
    "0% { background: var(--bg-color); } "
    "100% { background: \nvar(--button-secondary-background-fill); } "
    "} "
)
js_head = ''' '''

with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown(
        '### Zahir: Generative Recommenders for Unprompted, Scalable Exploration '
        'Explore the latent space without text prompts, based on your preferences. '
        'Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/). ',
        elem_id="description")

    # Per-session state: text embeddings, image embeddings, ratings, and the
    # queue of calibration prompts.
    embs = gr.State([])
    img_embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])

    # NOTE(review): the row and the image share the same elem_id.
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image', width=700)

    # Rating buttons stay disabled until Start has produced the first image.
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")

    b1.click(
        choose,
        [b1, embs, img_embs, ys, calibrate_prompts],
        [img, embs, img_embs, ys, calibrate_prompts]
    )
    b2.click(
        choose,
        [b2, embs, img_embs, ys, calibrate_prompts],
        [img, embs, img_embs, ys, calibrate_prompts]
    )
    b3.click(
        choose,
        [b3, embs, img_embs, ys, calibrate_prompts],
        [img, embs, img_embs, ys, calibrate_prompts]
    )

    with gr.Row():
        b4 = gr.Button(value='Start')
    b4.click(start,
             [b4, embs, img_embs, ys, calibrate_prompts],
             [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts])

    with gr.Row():
        html = gr.HTML('''
You will calibrate for several prompts and then roam.


Note that while the SDXL model is unlikely to produce NSFW images, it still may be possible, and users should avoid NSFW content when rating. ''')

demo.launch(share=True)  # Share your demo with just 1 extra parameter 🚀