# Hugging Face Space - Running on A10G
# Device the steering vector is moved to (as float16) before it is pickled
# and sent to the model; image generation itself runs remotely via Replicate.
DEVICE = 'cpu'
import pickle
import random
import time
from io import BytesIO
from urllib.request import urlopen

import gradio as gr
import numpy as np
import pandas as pd
import replicate
import requests
import torch
from PIL import Image
from sklearn import preprocessing
from sklearn.svm import LinearSVC
# Pool of free-form prompts sampled while roaming: the distinct string values
# of the CSV's second column. Non-string cells (presumably NaN for empty
# cells) are dropped.
prompt_list = list({
    p for p in pd.read_csv('./twitter_prompts.csv').iloc[:, 1]
    if isinstance(p, str)
})

# Fixed prompts shown first, in this order; next_image() pops from the front
# until the list is empty and then switches to roaming.
calibrate_prompts = [
    "4k photo",
    'surrealist art',
    'a psychedelic, fractal view',
    'a beautiful collage',
    'an intricate portrait',
    'an impressionist painting',
    'abstract art',
    'an eldritch image',
    'a sketch',
    'a city full of darkness and graffiti',
    'a black & white photo',
    'a brilliant, timeless tarot card of the world',
    'a photo of a woman',
    '',
]

# Embeddings returned by the model, one per generated image (appended by
# next_image); the last entry belongs to the image currently on screen.
embs = []
# User ratings appended by choose(): 1 = Like, 0 = Dislike. Index i rates
# the image whose embedding is embs[i].
ys = []
# Process start time; names the file the learned steering vector is saved to.
start_time = time.time()
# Number of images requested so far; its parity alternates the roaming prompt.
glob_idx = 0
_ZAHIR_MODEL = "rynmurdock/zahir:a8f4d222537221ba7a52252b8faf53eedb530b135218c59f349a681e5f24c641"


def _load_remote_pickle(url):
    """Download the object pickled at *url* and return it."""
    # SECURITY: unpickling can execute arbitrary code. This is only
    # acceptable because the URL comes from our own model's output; never
    # point this at user-supplied URLs.
    with urlopen(url) as response:
        return pickle.load(response)


def _generate(prompt, im_emb=None):
    """Run the Replicate model once, record its embedding and return the image.

    Args:
        prompt: text prompt for the model.
        im_emb: optional steering tensor; when given it is pickled and sent
            as the model's ``im_emb`` input.

    Returns:
        The generated image, opened with PIL from the first URL the model
        returns. The second returned URL is unpickled and appended to
        ``embs``.
    """
    model_input = {"prompt": prompt}
    if im_emb is not None:
        model_input['im_emb'] = pickle.dumps(im_emb)
    image_url, emb_url = replicate.run(_ZAHIR_MODEL, input=model_input)
    response = requests.get(image_url)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    embs.append(_load_remote_pickle(emb_url))
    return image


def next_image():
    """Generate the next image to show the user and record its embedding.

    Calibration phase: while ``calibrate_prompts`` is non-empty, its first
    prompt is popped and sent to the model unchanged.

    Roaming phase: a linear SVM is fit on the embeddings rated so far, using
    equally many Likes and Dislikes. Its unit-normalised weight vector is sent
    to the model as ``im_emb``, and is also saved to ``./{start_time}.pt``.
    The prompt alternates between 'an image' and a random entry of
    ``prompt_list``. Until at least one Like and one Dislike exist, the SVM
    cannot be fit, so an unsteered random prompt is used instead.

    In every case the embedding returned by the model is appended to
    ``embs``.

    Returns:
        The generated PIL image.
    """
    global glob_idx
    glob_idx += 1
    with torch.no_grad():
        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            return _generate(prompt)

        print('######### Roaming #########')
        pos_indices = [i for i, y in enumerate(ys) if y == 1]
        neg_indices = [i for i, y in enumerate(ys) if y == 0]
        lower = min(len(pos_indices), len(neg_indices))
        if lower == 0:
            # LinearSVC needs at least one sample of each class (and
            # torch.stack rejects an empty list), so don't steer yet.
            prompt = random.choice(prompt_list)
            print('Need at least one Like and one Dislike; unsteered prompt:', prompt)
            return _generate(prompt)

        # Sample only as many negatives as there are positives (and vice versa).
        neg_indices = random.sample(neg_indices, lower)
        pos_indices = random.sample(pos_indices, lower)
        cut_indices = neg_indices + pos_indices
        cut_embs = [embs[i] for i in cut_indices]
        cut_ys = [ys[i] for i in cut_indices]

        # assumes each stored embedding is a tensor whose first row is the
        # feature vector — TODO confirm against the model's output format
        feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs])
        scaler = preprocessing.StandardScaler().fit(feature_embs)
        feature_embs = scaler.transform(feature_embs)
        print(np.array(feature_embs).shape, np.array(cut_ys).shape)
        lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced')
        lin_class.fit(np.array(feature_embs), np.array(cut_ys))

        # The hyperplane normal, scaled to unit length, is the steering direction.
        coef = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
        coef = (coef / coef.norm()).unsqueeze(0)

        rng_prompt = random.choice(prompt_list)
        im_emb = coef.to(device=DEVICE, dtype=torch.float16)
        prompt = 'an image' if glob_idx % 2 == 0 else rng_prompt
        print(prompt)
        image = _generate(prompt, im_emb=im_emb)
        torch.save(coef, f'./{start_time}.pt')
        return image
def start(_):
    """Unlock the rating buttons, lock Start, and show the first image.

    The returned list is positional and lines up with the Start button's
    outputs: [Like, Neither, Dislike, Start, image].
    """
    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like', 'Neither', 'Dislike')
    ]
    start_button = gr.Button(value='Start', interactive=False)
    return rating_buttons + [start_button, next_image()]
def choose(choice):
    """Record the user's verdict on the current image and return the next one.

    Args:
        choice: label of the clicked button. 'Like' records 1 in ``ys``;
            'Neither' discards the current image's embedding and records
            nothing; any other label (i.e. 'Dislike') records 0.
    """
    if choice == 'Neither':
        # Drop the unrated image's embedding so embs[i] keeps matching ys[i].
        embs.pop(-1)
    else:
        ys.append(1 if choice == 'Like' else 0)
    return next_image()
css = "div#output-image {height: 768px !important; width: 768px !important; margin:auto;}"

with gr.Blocks(css=css) as demo:
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:32px'>You will calibrate for several prompts and then roam.</div>''')
    # NOTE(review): the Row and the Image share elem_id 'output-image', so the
    # CSS rule above matches both and the page has a duplicate HTML id.
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image',)
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike', interactive=False,)
        b2 = gr.Button(value='Neither', interactive=False,)
        b1 = gr.Button(value='Like', interactive=False,)
    # Each rating button passes its own label to choose() and receives the
    # next image.
    for rating_button in (b1, b2, b3):
        rating_button.click(choose, [rating_button], [img])
    with gr.Row():
        b4 = gr.Button(value='Start')
        # Outputs must match the order of the list returned by start().
        b4.click(start, [b4], [b1, b2, b3, b4, img])

if __name__ == '__main__':
    demo.launch()  # pass share=True for a temporary public link