DEVICE = 'cpu' import gradio as gr import numpy as np from sklearn.svm import LinearSVC from sklearn import preprocessing import pandas as pd import random import time import replicate import torch from urllib.request import urlopen from PIL import Image import requests from io import BytesIO, StringIO prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] start_time = time.time() # TODO add to state instead of shared across all glob_idx = 0 def next_image(embs, ys, calibrate_prompts): global glob_idx glob_idx = glob_idx + 1 with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample prompts #########') prompt = calibrate_prompts.pop(0) print(prompt) output = replicate.run( "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42", input={"prompt": prompt,} ) response = requests.get(output['file1']) image = Image.open(BytesIO(response.content)) embs.append(torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0)) return image, embs, ys else: print('######### Roaming #########') # sample only as many negatives as there are positives indices = range(len(ys)) pos_indices = [i for i in indices if ys[i] == 1] neg_indices = [i for i in indices if ys[i] == 0] lower = min(len(pos_indices), len(neg_indices)) neg_indices = random.sample(neg_indices, lower) pos_indices = random.sample(pos_indices, lower) cut_embs = [embs[i] for i in neg_indices] + [embs[i] for i in pos_indices] cut_ys = [ys[i] for i in neg_indices] + [ys[i] for i in pos_indices] feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs]) scaler = preprocessing.StandardScaler().fit(feature_embs) feature_embs = scaler.transform(feature_embs) print(np.array(feature_embs).shape, np.array(ys).shape) lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(np.array(feature_embs), np.array(cut_ys)) lin_class.coef_ = torch.tensor(lin_class.coef_, dtype=torch.double) lin_class.coef_ = (lin_class.coef_.flatten() / (lin_class.coef_.flatten().norm())).unsqueeze(0) rng_prompt = random.choice(prompt_list) w = 1# if len(embs) % 2 == 0 else 0 im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16) prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt print(prompt) im_emb_st = str(im_emb[0].cpu().detach().tolist())[1:-1] output = replicate.run( "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42", input={"prompt": prompt, 'im_emb': im_emb_st} ) response = requests.get(output['file1']) image = Image.open(BytesIO(response.content)) im_emb = torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0) embs.append(im_emb) torch.save(lin_class.coef_, f'./{start_time}.pt') return image, embs, ys, calibrate_prompts def start(_, embs, ys, calibrate_prompts): image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return [ gr.Button(value='Like', interactive=True), gr.Button(value='Neither', interactive=True), gr.Button(value='Dislike', interactive=True), gr.Button(value='Start', interactive=False), image, embs, ys, calibrate_prompts ] def choose(choice, embs, ys, calibrate_prompts): if choice == 'Like': choice = 1 elif choice == 'Neither': _ = embs.pop(-1) img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) else: choice = 0 ys.append(choice) img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return img, embs, ys, calibrate_prompts css = "div#output-image {height: 768px !important; width: 768px !important; margin:auto;}" with gr.Blocks(css=css) as demo: embs = gr.State([]) ys = gr.State([]) calibrate_prompts = gr.State([ "4k photo", 'surrealist art', # 'a psychedelic, fractal view', 'a beautiful collage', 'abstract art', 'an eldritch image', 'a sketch', # 'a city full of darkness and graffiti', '', ]) with gr.Row(elem_id='output-image'): img = gr.Image(interactive=False, elem_id='output-image',) with gr.Row(equal_height=True): b3 = gr.Button(value='Dislike', interactive=False,) b2 = gr.Button(value='Neither', interactive=False,) b1 = gr.Button(value='Like', interactive=False,) b1.click( choose, [b1, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) b2.click( choose, [b2, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) b3.click( choose, [b3, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) with gr.Row(): b4 = gr.Button(value='Start') b4.click(start, [b4, embs, ys, calibrate_prompts], [b1, b2, b3, b4, img, calibrate_prompts]) with gr.Row(): html = gr.HTML('''