DEVICE = 'cuda' import gradio as gr import numpy as np from sklearn.svm import LinearSVC from sklearn import preprocessing import pandas as pd from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, AutoPipelineForText2Image, DiffusionPipeline from diffusers.models import ImageProjection import torch torch.set_float32_matmul_precision('high') import random import time # TODO put back import spaces from urllib.request import urlopen from PIL import Image import requests from io import BytesIO, StringIO from transformers import CLIPVisionModelWithProjection from huggingface_hub import hf_hub_download from safetensors.torch import load_file from safety_checker_improved import maybe_nsfw prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] start_time = time.time() ####################### Setup Model model_id = "stabilityai/stable-diffusion-xl-base-1.0" sdxl_lightening = "ByteDance/SDXL-Lightning" ckpt = "sdxl_lightning_2step_unet.safetensors" unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to(DEVICE, torch.float16) unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device=DEVICE)) image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16,).to(DEVICE) pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder).to(DEVICE) pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin'))) pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin") pipe.register_modules(image_encoder = image_encoder) pipe.set_ip_adapter_scale(0.8) pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16) pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") pipe.to(device=DEVICE) # TODO put back @spaces.GPU def compile_em(): pipe.unet = torch.compile(pipe.unet) pipe.vae = torch.compile(pipe.vae, mode='reduce-overhead') autoencoder.model.forward = torch.compile(autoencoder.model.forward, backend='inductor', dynamic=True) output_hidden_state = False ####################### ####################### Setup autoencoder from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM class BottleneckT5Autoencoder: def __init__(self, model_path: str, device='cuda'): self.device = device self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512, torch_dtype=torch.bfloat16) self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device) self.model.eval() def embed(self, text: str) -> torch.FloatTensor: inputs = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device) decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device) return self.model( **inputs, decoder_input_ids=decoder_inputs['input_ids'], encode_only=True, ) def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1., top_p=.8, min_new_tokens=30) -> str: dummy_text = '.' dummy = self.embed(dummy_text) perturb_vector = latent - dummy self.model.perturb_vector = perturb_vector input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids output = self.model.generate( input_ids=input_ids, max_length=max_length, do_sample=True, temperature=temperature, top_p=top_p, num_return_sequences=1, min_new_tokens=min_new_tokens, # num_beams=8, ) return self.tokenizer.decode(output[0], skip_special_tokens=True) autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t5-xl-wikipedia') compile_em() ####################### # TODO put back @spaces.GPU def generate(prompt, in_embs=None,): if prompt != '': print(prompt) in_embs = in_embs / in_embs.abs().max() * .15 if in_embs != None else None in_embs = .9 * in_embs.to('cuda') + .5 * autoencoder.embed(prompt).to('cuda') if in_embs != None else autoencoder.embed(prompt).to('cuda') else: print('From embeds.') in_embs = in_embs / in_embs.abs().max() * .15 text = autoencoder.generate_from_latent(in_embs.to('cuda').to(dtype=torch.bfloat16), temperature=.8, top_p=.94, min_new_tokens=5) return text, in_embs.to('cpu') # TODO put back @spaces.GPU def predict( prompt, im_emb=None, progress=gr.Progress(track_tqdm=True) ): """Run a single prediction on the model""" with torch.no_grad(): if im_emb == None: im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE) im_emb = [im_emb.to(DEVICE).unsqueeze(0)] if prompt == '': image = pipe( prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE), pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE), ip_adapter_image_embeds=im_emb, height=1024, width=1024, num_inference_steps=2, guidance_scale=0, # timesteps=[800], ).images[0] else: image = pipe( prompt=prompt, ip_adapter_image_embeds=im_emb, height=1024, width=1024, num_inference_steps=2, guidance_scale=0, # timesteps=[800], ).images[0] im_emb, _ = pipe.encode_image( image, DEVICE, 1, output_hidden_state ) nsfw = maybe_nsfw(image) if nsfw: return None, im_emb.to('cpu') return image, im_emb.to('cpu') # sample a .8 of rated embeddings for some stochasticity, or at least two embeddings. def get_coeff(embs_local, ys): n_to_choose = max(int(len(embs_local)*.8), 2) indices = random.sample(range(len(embs_local)), n_to_choose) # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749); # this ends up adding a rating but losing an embedding, it seems. # let's take off a rating if so to continue without indexing errors. if len(ys) > len(embs_local): print('ys are longer than embs; popping latest rating') ys.pop(-1) # also add the latest 0 and the latest 1 has_0 = False has_1 = False for i in reversed(range(len(ys))): if ys[i] == 0 and has_0 == False: indices.append(i) has_0 = True elif ys[i] == 1 and has_1 == False: indices.append(i) has_1 = True if has_0 and has_1: break feature_embs = np.array(torch.cat([embs_local[i].to('cpu') for i in indices]).to('cpu')) scaler = preprocessing.StandardScaler().fit(feature_embs) feature_embs = scaler.transform(feature_embs) print(len(feature_embs), len(ys)) lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(feature_embs, np.array([ys[i] for i in indices])) lin_class.coef_ = torch.tensor(lin_class.coef_, dtype=torch.double) lin_class.coef_ = (lin_class.coef_.flatten() / (lin_class.coef_.flatten().norm())).unsqueeze(0) return lin_class.coef_ # TODO add to state instead of shared across all glob_idx = 0 def next_image(embs, img_embs, ys, calibrate_prompts): global glob_idx glob_idx = glob_idx + 1 if glob_idx >= 12: glob_idx = 0 # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike' if len(calibrate_prompts) == 0 and len(list(set(ys))) <= 1: embs.append(.01*torch.randn(1, 2048)) embs.append(.01*torch.randn(1, 2048)) img_embs.append(.01*torch.randn(1, 1024)) img_embs.append(.01*torch.randn(1, 1024)) ys.append(0) ys.append(1) with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample prompts #########') prompt = calibrate_prompts.pop(0) print(prompt) image, img_emb = predict(prompt) im_emb = autoencoder.embed(prompt) embs.append(im_emb) img_embs.append(img_emb) return image, embs, img_embs, ys, calibrate_prompts else: print('######### Roaming #########') im_s = get_coeff(embs, ys) rng_prompt = random.choice(prompt_list) w = 1.4# if len(embs) % 2 == 0 else 0 prompt= '' if not glob_idx % 3 == 0 else rng_prompt prompt, _ = generate(prompt, in_embs=im_s) print(prompt) im_emb = autoencoder.embed(prompt) embs.append(im_emb) learn_emb = get_coeff(img_embs, ys) img_emb = w * learn_emb.to(dtype=torch.float16) image, img_emb = predict(prompt, im_emb=img_emb) img_embs.append(img_emb) if len(embs) > 100: embs.pop(0) img_embs.pop(0) ys.pop(0) return image, embs, img_embs, ys, calibrate_prompts def start(_, embs, img_embs, ys, calibrate_prompts): image, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts) return [ gr.Button(value='Like (L)', interactive=True), gr.Button(value='Neither (Space)', interactive=True), gr.Button(value='Dislike (A)', interactive=True), gr.Button(value='Start', interactive=False), image, embs, img_embs, ys, calibrate_prompts ] def choose(img, choice, embs, img_embs, ys, calibrate_prompts): if choice == 'Like (L)': choice = 1 elif choice == 'Neither (Space)': _ = embs.pop(-1) _ = img_embs.pop(-1) img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts) return img, embs, img_embs, ys, calibrate_prompts else: choice = 0 print(img, 'img') if img is None: print('NSFW -- choice is disliked') choice = 0 ys.append(choice) img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts) return img, embs, img_embs, ys, calibrate_prompts css = '''.gradio-container{max-width: 700px !important} #description{text-align: center} #description h1, #description h3{display: block} #description p{margin-top: 0} .fade-in-out {animation: fadeInOut 3s forwards} @keyframes fadeInOut { 0% { background: var(--bg-color); } 100% { background: var(--button-secondary-background-fill); } } ''' js_head = ''' ''' with gr.Blocks(css=css, head=js_head) as demo: gr.Markdown('''### Zahir: Generative Recommenders for Unprompted, Scalable Exploration Explore the latent space without prompting based on your feedback. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/). ''', elem_id="description") embs = gr.State([]) img_embs = gr.State([]) ys = gr.State([]) calibrate_prompts = gr.State([ 'the moon is melting into my glass of tea', 'a sea slug -- pair of claws scuttling -- jelly fish glowing', 'an adorable creature. It may be a goblin or a pig or a slug.', 'an animation about a gorgeous nebula', 'a sketch of an impressive mountain by da vinci', 'a watercolor painting: the octopus writhes', ]) with gr.Row(elem_id='output-image'): img = gr.Image(interactive=False, elem_id='output-image', width=700) with gr.Row(equal_height=True): b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike") b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither") b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like") b1.click( choose, [img, b1, embs, img_embs, ys, calibrate_prompts], [img, embs, img_embs, ys, calibrate_prompts] ) b2.click( choose, [img, b2, embs, img_embs, ys, calibrate_prompts], [img, embs, img_embs, ys, calibrate_prompts] ) b3.click( choose, [img, b3, embs, img_embs, ys, calibrate_prompts], [img, embs, img_embs, ys, calibrate_prompts] ) with gr.Row(): b4 = gr.Button(value='Start') b4.click(start, [b4, embs, img_embs, ys, calibrate_prompts], [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts]) with gr.Row(): html = gr.HTML('''