Spaces:
Runtime error
Runtime error
rynmurdock
commited on
Commit
•
9435075
1
Parent(s):
9239164
lfs and sync with blue-tigers github
Browse files- .gitattributes +20 -0
- app.py +78 -108
- eigth.gemb_.pt +3 -0
- eigth.im_.pt +3 -0
- fifth.gemb_.pt +3 -0
- fifth.im_.pt +3 -0
- first.gemb_.pt +3 -0
- first.im_.pt +3 -0
- fourth.gemb_.pt +3 -0
- fourth.im_.pt +3 -0
- lightning_app.py +0 -452
- ninth.gemb_.pt +3 -0
- ninth.im_.pt +3 -0
- requirements.txt +1 -3
- second.gemb_.pt +3 -0
- second.im_.pt +3 -0
- seventh.gemb_.pt +3 -0
- seventh.im_.pt +3 -0
- sixth.gemb_.pt +3 -0
- sixth.im_.pt +3 -0
- tenth.gemb_.pt +3 -0
- tenth.im_.pt +3 -0
- third.gemb_.pt +3 -0
- third.im_.pt +3 -0
- twitter_prompts.csv +0 -72
.gitattributes
CHANGED
@@ -1 +1,21 @@
|
|
1 |
nsfweffnetv2-b02-3epochs.h5 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
nsfweffnetv2-b02-3epochs.h5 filter=lfs diff=lfs merge=lfs -text
|
2 |
+
fifth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
3 |
+
ninth.im_.pt filter=lfs diff=lfs merge=lfs -text
|
4 |
+
tenth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
5 |
+
third.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
eigth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
7 |
+
first.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
8 |
+
fourth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
9 |
+
ninth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
10 |
+
sixth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
11 |
+
tenth.im_.pt filter=lfs diff=lfs merge=lfs -text
|
12 |
+
eigth.im_.pt filter=lfs diff=lfs merge=lfs -text
|
13 |
+
seventh.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
14 |
+
sixth.im_.pt filter=lfs diff=lfs merge=lfs -text
|
15 |
+
third.im_.pt filter=lfs diff=lfs merge=lfs -text
|
16 |
+
fifth.im_.pt filter=lfs diff=lfs merge=lfs -text
|
17 |
+
first.im_.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
fourth.im_.pt filter=lfs diff=lfs merge=lfs -text
|
19 |
+
second.gemb_.pt filter=lfs diff=lfs merge=lfs -text
|
20 |
+
second.im_.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
+
seventh.im_.pt filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -10,12 +10,9 @@ STEPS = 6
|
|
10 |
output_hidden_state = False
|
11 |
device = "cuda"
|
12 |
dtype = torch.bfloat16
|
|
|
13 |
|
14 |
-
import matplotlib.pyplot as plt
|
15 |
-
import matplotlib
|
16 |
import logging
|
17 |
-
|
18 |
-
|
19 |
import os
|
20 |
import imageio
|
21 |
import gradio as gr
|
@@ -24,8 +21,6 @@ from sklearn.svm import SVC
|
|
24 |
from sklearn import preprocessing
|
25 |
import pandas as pd
|
26 |
from apscheduler.schedulers.background import BackgroundScheduler
|
27 |
-
import sched
|
28 |
-
import threading
|
29 |
|
30 |
import random
|
31 |
import time
|
@@ -104,7 +99,7 @@ pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", mot
|
|
104 |
unet=unet, text_encoder=text_encoder)
|
105 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
|
106 |
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
|
107 |
-
pipe.set_adapters(["lcm-lora"], [.
|
108 |
pipe.fuse_lora()
|
109 |
|
110 |
|
@@ -121,6 +116,7 @@ pipe.unet.fuse_qkv_projections()
|
|
121 |
#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
|
122 |
|
123 |
pipe.to(device=DEVICE)
|
|
|
124 |
#pipe.unet = torch.compile(pipe.unet)
|
125 |
#pipe.vae = torch.compile(pipe.vae)
|
126 |
|
@@ -130,9 +126,10 @@ pipe.to(device=DEVICE)
|
|
130 |
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
|
131 |
|
132 |
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
133 |
-
pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype,
|
134 |
processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
|
135 |
|
|
|
136 |
|
137 |
@spaces.GPU()
|
138 |
def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
|
@@ -148,19 +145,34 @@ def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None
|
|
148 |
return inputs_embeds
|
149 |
|
150 |
|
|
|
151 |
@spaces.GPU()
|
152 |
-
def generate_pali(
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
-
generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
|
163 |
-
decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
164 |
return decoded
|
165 |
|
166 |
|
@@ -182,7 +194,7 @@ def generate_gpu(in_im_embs, prompt='the scene'):
|
|
182 |
im = torchvision.transforms.ToTensor()(output.frames[0][len(output.frames[0])//2]).unsqueeze(0)
|
183 |
im = torch.nn.functional.interpolate(im, (224, 224))
|
184 |
im = (im - .5) * 2
|
185 |
-
gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32)
|
186 |
return output, im_emb, gemb
|
187 |
|
188 |
|
@@ -210,10 +222,10 @@ def generate(in_im_embs, prompt='the scene'):
|
|
210 |
def get_user_emb(embs, ys):
|
211 |
# handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
|
212 |
|
213 |
-
if len(list(ys)) <=
|
214 |
-
aways = [
|
215 |
embs += aways
|
216 |
-
awal = [0 for i in range(
|
217 |
ys += awal
|
218 |
|
219 |
indices = list(range(len(embs)))
|
@@ -241,9 +253,10 @@ def get_user_emb(embs, ys):
|
|
241 |
feature_embs = feature_embs / feature_embs.norm()
|
242 |
|
243 |
#lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
|
244 |
-
|
|
|
245 |
coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
|
246 |
-
coef_ = coef_ / coef_.abs().max()
|
247 |
|
248 |
w = 1# if len(embs) % 2 == 0 else 0
|
249 |
im_emb = w * coef_.to(dtype=dtype)
|
@@ -273,7 +286,7 @@ def background_next_image():
|
|
273 |
# only let it get N (maybe 3) ahead of the user
|
274 |
#not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
|
275 |
rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
|
276 |
-
while len(rated_rows) <
|
277 |
# not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
|
278 |
rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
|
279 |
time.sleep(.01)
|
@@ -290,25 +303,21 @@ def background_next_image():
|
|
290 |
rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
|
291 |
|
292 |
# we pop previous ratings if there are > n
|
293 |
-
if len(rated_from_user) >=
|
294 |
oldest = rated_from_user.iloc[0]['paths']
|
295 |
prevs_df = prevs_df[prevs_df['paths'] != oldest]
|
296 |
# we don't compute more after n are in the queue for them
|
297 |
-
if len(unrated_from_user) >=
|
298 |
-
continue
|
299 |
-
|
300 |
-
if len(rated_rows) < 5:
|
301 |
continue
|
302 |
|
303 |
embs, ys, gembs = pluck_embs_ys(uid)
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
text = generate_pali(user_gem)
|
310 |
else:
|
311 |
-
text =
|
312 |
img, embs, new_gem = generate(user_emb, text)
|
313 |
|
314 |
if img:
|
@@ -351,60 +360,16 @@ def next_image(calibrate_prompts, user_id):
|
|
351 |
if len(calibrate_prompts) > 0:
|
352 |
cal_video = calibrate_prompts.pop(0)
|
353 |
image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
|
354 |
-
|
355 |
return image, calibrate_prompts, ''
|
356 |
else:
|
357 |
embs, ys, gembs = pluck_embs_ys(user_id)
|
358 |
-
user_emb = get_user_emb(embs, ys)
|
359 |
image, text = pluck_img(user_id, user_emb)
|
360 |
return image, calibrate_prompts, text
|
361 |
|
362 |
|
363 |
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
done_init = False
|
368 |
-
|
369 |
def start(_, calibrate_prompts, user_id, request: gr.Request):
|
370 |
-
global done_init
|
371 |
-
global prevs_df
|
372 |
-
|
373 |
-
if not done_init:
|
374 |
-
# prep our calibration videos
|
375 |
-
for im in [
|
376 |
-
'./first.mp4',
|
377 |
-
# './second.mp4',
|
378 |
-
# './third.mp4',
|
379 |
-
# './fourth.mp4',
|
380 |
-
# './fifth.mp4',
|
381 |
-
# './sixth.mp4',
|
382 |
-
# './seventh.mp4',
|
383 |
-
# './eigth.mp4',
|
384 |
-
# './ninth.mp4',
|
385 |
-
# './tenth.mp4',
|
386 |
-
]:
|
387 |
-
tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
|
388 |
-
tmp_df['paths'] = [im]
|
389 |
-
image = list(imageio.imiter(im))
|
390 |
-
image = image[len(image)//2]
|
391 |
-
|
392 |
-
im = torchvision.transforms.ToTensor()(image).unsqueeze(0)
|
393 |
-
im = torch.nn.functional.interpolate(im, (224, 224))
|
394 |
-
im = (im - .5) * 2
|
395 |
-
|
396 |
-
im_emb, gemb = encode_space(image, im)
|
397 |
-
im_emb = im_emb.to('cpu')
|
398 |
-
gemb = gemb.to('cpu')
|
399 |
-
|
400 |
-
tmp_df['embeddings'] = [im_emb]
|
401 |
-
tmp_df['gemb'] = [gemb]
|
402 |
-
tmp_df['user:rating'] = [{' ': ' '}]
|
403 |
-
prevs_df = pd.concat((prevs_df, tmp_df))
|
404 |
-
done_init = True
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
user_id = int(str(time.time())[-7:].replace('.', ''))
|
409 |
image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
|
410 |
return [
|
@@ -436,6 +401,7 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
|
|
436 |
print('NSFW -- choice is disliked')
|
437 |
choice = 0
|
438 |
|
|
|
439 |
row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
|
440 |
# if it's still in the dataframe, add the choice
|
441 |
if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
|
@@ -506,11 +472,11 @@ Explore the latent space without text prompts based on your preferences. Learn m
|
|
506 |
# calibration videos -- this is a misnomer now :D
|
507 |
calibrate_prompts = gr.State([
|
508 |
'./first.mp4',
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
])
|
515 |
def l():
|
516 |
return None
|
@@ -569,26 +535,30 @@ scheduler = BackgroundScheduler()
|
|
569 |
scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
|
570 |
scheduler.start()
|
571 |
|
572 |
-
#thread = threading.Thread(target=background_next_image,)
|
573 |
-
#thread.start()
|
574 |
|
575 |
-
#
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
|
594 |
|
|
|
10 |
output_hidden_state = False
|
11 |
device = "cuda"
|
12 |
dtype = torch.bfloat16
|
13 |
+
N_IMG_EMBS = 3
|
14 |
|
|
|
|
|
15 |
import logging
|
|
|
|
|
16 |
import os
|
17 |
import imageio
|
18 |
import gradio as gr
|
|
|
21 |
from sklearn import preprocessing
|
22 |
import pandas as pd
|
23 |
from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
|
|
24 |
|
25 |
import random
|
26 |
import time
|
|
|
99 |
unet=unet, text_encoder=text_encoder)
|
100 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
|
101 |
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
|
102 |
+
pipe.set_adapters(["lcm-lora"], [.95])
|
103 |
pipe.fuse_lora()
|
104 |
|
105 |
|
|
|
116 |
#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
|
117 |
|
118 |
pipe.to(device=DEVICE)
|
119 |
+
|
120 |
#pipe.unet = torch.compile(pipe.unet)
|
121 |
#pipe.vae = torch.compile(pipe.vae)
|
122 |
|
|
|
126 |
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
|
127 |
|
128 |
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
129 |
+
pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype, quantization_config=quantization_config).eval()
|
130 |
processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
|
131 |
|
132 |
+
#pali = torch.compile(pali)
|
133 |
|
134 |
@spaces.GPU()
|
135 |
def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
|
|
|
145 |
return inputs_embeds
|
146 |
|
147 |
|
148 |
+
# TODO cache descriptions?
|
149 |
@spaces.GPU()
|
150 |
+
def generate_pali(n_embs):
|
151 |
+
prompt = 'caption en'
|
152 |
+
model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
|
153 |
+
# we need to get im_embs taken in here.
|
154 |
+
|
155 |
+
descs = ''
|
156 |
+
for n, emb in enumerate(n_embs):
|
157 |
+
if n < len(n_embs)-1:
|
158 |
+
input_len = model_inputs["input_ids"].shape[-1]
|
159 |
+
input_embeds = to_wanted_embs(emb,
|
160 |
+
model_inputs["input_ids"].to(device),
|
161 |
+
model_inputs["attention_mask"].to(device))
|
162 |
+
generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
|
163 |
+
decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
164 |
+
descs += f'Description: {decoded}\n'
|
165 |
+
else:
|
166 |
+
prompt = f'en {descs} Describe a new image that is similar.'
|
167 |
+
print(prompt)
|
168 |
+
model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
|
169 |
+
input_len = model_inputs["input_ids"].shape[-1]
|
170 |
+
input_embeds = to_wanted_embs(emb,
|
171 |
+
model_inputs["input_ids"].to(device),
|
172 |
+
model_inputs["attention_mask"].to(device))
|
173 |
+
generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
|
174 |
+
decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
175 |
|
|
|
|
|
176 |
return decoded
|
177 |
|
178 |
|
|
|
194 |
im = torchvision.transforms.ToTensor()(output.frames[0][len(output.frames[0])//2]).unsqueeze(0)
|
195 |
im = torch.nn.functional.interpolate(im, (224, 224))
|
196 |
im = (im - .5) * 2
|
197 |
+
gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32)
|
198 |
return output, im_emb, gemb
|
199 |
|
200 |
|
|
|
222 |
def get_user_emb(embs, ys):
|
223 |
# handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
|
224 |
|
225 |
+
if len(list(ys)) <= 10:
|
226 |
+
aways = [torch.zeros_like(embs[0]) for i in range(10)]
|
227 |
embs += aways
|
228 |
+
awal = [0 for i in range(5)] + [1 for i in range(5)]
|
229 |
ys += awal
|
230 |
|
231 |
indices = list(range(len(embs)))
|
|
|
253 |
feature_embs = feature_embs / feature_embs.norm()
|
254 |
|
255 |
#lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
|
256 |
+
#class_weight='balanced'
|
257 |
+
lin_class = SVC(max_iter=500, kernel='linear', C=.1, ).fit(feature_embs.squeeze(), chosen_y)
|
258 |
coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
|
259 |
+
coef_ = coef_ / coef_.abs().max()
|
260 |
|
261 |
w = 1# if len(embs) % 2 == 0 else 0
|
262 |
im_emb = w * coef_.to(dtype=dtype)
|
|
|
286 |
# only let it get N (maybe 3) ahead of the user
|
287 |
#not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
|
288 |
rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
|
289 |
+
while len(rated_rows) < 5:
|
290 |
# not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
|
291 |
rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
|
292 |
time.sleep(.01)
|
|
|
303 |
rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
|
304 |
|
305 |
# we pop previous ratings if there are > n
|
306 |
+
if len(rated_from_user) >= 25:
|
307 |
oldest = rated_from_user.iloc[0]['paths']
|
308 |
prevs_df = prevs_df[prevs_df['paths'] != oldest]
|
309 |
# we don't compute more after n are in the queue for them
|
310 |
+
if len(unrated_from_user) >= 20:
|
|
|
|
|
|
|
311 |
continue
|
312 |
|
313 |
embs, ys, gembs = pluck_embs_ys(uid)
|
314 |
+
user_emb = get_user_emb(embs, ys) * 3
|
315 |
+
pos_gembs = [g for g, y in zip(gembs, ys) if y == 1]
|
316 |
+
if len(pos_gembs) > 4:
|
317 |
+
hist_gem = random.sample(pos_gembs, N_IMG_EMBS) # rng n embeddings
|
318 |
+
text = generate_pali(hist_gem)
|
|
|
319 |
else:
|
320 |
+
text = 'the scene'
|
321 |
img, embs, new_gem = generate(user_emb, text)
|
322 |
|
323 |
if img:
|
|
|
360 |
if len(calibrate_prompts) > 0:
|
361 |
cal_video = calibrate_prompts.pop(0)
|
362 |
image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
|
|
|
363 |
return image, calibrate_prompts, ''
|
364 |
else:
|
365 |
embs, ys, gembs = pluck_embs_ys(user_id)
|
366 |
+
user_emb = get_user_emb(embs, ys) * 3
|
367 |
image, text = pluck_img(user_id, user_emb)
|
368 |
return image, calibrate_prompts, text
|
369 |
|
370 |
|
371 |
|
|
|
|
|
|
|
|
|
|
|
372 |
def start(_, calibrate_prompts, user_id, request: gr.Request):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
user_id = int(str(time.time())[-7:].replace('.', ''))
|
374 |
image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
|
375 |
return [
|
|
|
401 |
print('NSFW -- choice is disliked')
|
402 |
choice = 0
|
403 |
|
404 |
+
print(prevs_df['paths'].to_list(), img)
|
405 |
row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
|
406 |
# if it's still in the dataframe, add the choice
|
407 |
if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
|
|
|
472 |
# calibration videos -- this is a misnomer now :D
|
473 |
calibrate_prompts = gr.State([
|
474 |
'./first.mp4',
|
475 |
+
'./second.mp4',
|
476 |
+
'./third.mp4',
|
477 |
+
'./fourth.mp4',
|
478 |
+
'./fifth.mp4',
|
479 |
+
'./sixth.mp4',
|
480 |
])
|
481 |
def l():
|
482 |
return None
|
|
|
535 |
scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
|
536 |
scheduler.start()
|
537 |
|
|
|
|
|
538 |
|
539 |
+
# prep our calibration videos
|
540 |
+
for im in [
|
541 |
+
'./first.mp4',
|
542 |
+
'./second.mp4',
|
543 |
+
'./third.mp4',
|
544 |
+
'./fourth.mp4',
|
545 |
+
'./fifth.mp4',
|
546 |
+
'./sixth.mp4',
|
547 |
+
'./seventh.mp4',
|
548 |
+
'./eigth.mp4',
|
549 |
+
'./ninth.mp4',
|
550 |
+
'./tenth.mp4',
|
551 |
+
]:
|
552 |
+
tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
|
553 |
+
tmp_df['paths'] = [im]
|
554 |
+
image = list(imageio.imiter(im))
|
555 |
+
image = image[len(image)//2]
|
556 |
+
tmp_df['embeddings'] = [torch.load(im.replace('mp4', 'im_.pt'))]
|
557 |
+
tmp_df['gemb'] = [torch.load(im.replace('mp4', 'gemb_.pt'))]
|
558 |
+
tmp_df['user:rating'] = [{' ': ' '}]
|
559 |
+
prevs_df = pd.concat((prevs_df, tmp_df))
|
560 |
+
|
561 |
+
|
562 |
+
demo.launch(share=True, server_port=8443)
|
563 |
|
564 |
|
eigth.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:313d2e918194715ad1da5e0dbbd567ef086bb5365920b1d0ec8f727187611be2
|
3 |
+
size 1180848
|
eigth.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6804672b5692563a7a6886a6e4010ab983dc6c1699cb6e41375776842fe4f2c7
|
3 |
+
size 6310
|
fifth.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b55ac2c8c3b7109e2673d7ee6c631597832c0f78331ab116d6130e77c2323587
|
3 |
+
size 1180848
|
fifth.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6165a30db650d03f17925a6590f4e0313d9c9c3ba2e4e4ce51fe00012d0efdff
|
3 |
+
size 6310
|
first.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:caca2ad20ebeefa19efbda52c60610521552759214a47fc36bf85c3ce2c7237d
|
3 |
+
size 1180848
|
first.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1728d8e50da01013a81c681200e9a1568663b6b48bc824b1ad0f3894a7e06aa0
|
3 |
+
size 6310
|
fourth.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dc346bd7447c7f84119cc9275c74b5f41509b357358f2dff4aa5b63a246442ce
|
3 |
+
size 1180853
|
fourth.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6b60957c398b1a3ccfe29c9fde226570a69ed2a127c9db82136fdc872e10a26
|
3 |
+
size 6315
|
lightning_app.py
DELETED
@@ -1,452 +0,0 @@
|
|
1 |
-
|
2 |
-
import torch
|
3 |
-
|
4 |
-
# lol
|
5 |
-
sidel = 512
|
6 |
-
DEVICE = 'cuda'
|
7 |
-
STEPS = 4
|
8 |
-
output_hidden_state = False
|
9 |
-
device = "cuda"
|
10 |
-
dtype = torch.float16
|
11 |
-
|
12 |
-
import matplotlib.pyplot as plt
|
13 |
-
import matplotlib
|
14 |
-
matplotlib.use('TkAgg')
|
15 |
-
|
16 |
-
from sklearn.linear_model import LinearRegression
|
17 |
-
from sfast.compilers.diffusion_pipeline_compiler import (compile, compile_unet,
|
18 |
-
CompilationConfig)
|
19 |
-
config = CompilationConfig.Default()
|
20 |
-
|
21 |
-
try:
|
22 |
-
import triton
|
23 |
-
config.enable_triton = True
|
24 |
-
except ImportError:
|
25 |
-
print('Triton not installed, skip')
|
26 |
-
config.enable_cuda_graph = True
|
27 |
-
|
28 |
-
config.enable_jit = True
|
29 |
-
config.enable_jit_freeze = True
|
30 |
-
|
31 |
-
config.enable_cnn_optimization = True
|
32 |
-
config.preserve_parameters = False
|
33 |
-
config.prefer_lowp_gemm = True
|
34 |
-
|
35 |
-
import imageio
|
36 |
-
import gradio as gr
|
37 |
-
import numpy as np
|
38 |
-
from sklearn.svm import SVC
|
39 |
-
from sklearn.inspection import permutation_importance
|
40 |
-
from sklearn import preprocessing
|
41 |
-
import pandas as pd
|
42 |
-
|
43 |
-
import random
|
44 |
-
import time
|
45 |
-
from PIL import Image
|
46 |
-
from safety_checker_improved import maybe_nsfw
|
47 |
-
|
48 |
-
|
49 |
-
torch.set_grad_enabled(False)
|
50 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
51 |
-
torch.backends.cudnn.allow_tf32 = True
|
52 |
-
|
53 |
-
# TODO put back?
|
54 |
-
# import spaces
|
55 |
-
|
56 |
-
prompt_list = [p for p in list(set(
|
57 |
-
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
58 |
-
|
59 |
-
start_time = time.time()
|
60 |
-
|
61 |
-
####################### Setup Model
|
62 |
-
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, ConsistencyDecoderVAE, AutoencoderTiny
|
63 |
-
from hyper_tile import split_attention, flush
|
64 |
-
from huggingface_hub import hf_hub_download
|
65 |
-
from safetensors.torch import load_file
|
66 |
-
from PIL import Image
|
67 |
-
from transformers import CLIPVisionModelWithProjection
|
68 |
-
import uuid
|
69 |
-
import av
|
70 |
-
|
71 |
-
def write_video(file_name, images, fps=10):
|
72 |
-
print('Saving')
|
73 |
-
container = av.open(file_name, mode="w")
|
74 |
-
|
75 |
-
stream = container.add_stream("h264", rate=fps)
|
76 |
-
stream.width = sidel
|
77 |
-
stream.height = sidel
|
78 |
-
stream.pix_fmt = "yuv420p"
|
79 |
-
|
80 |
-
for img in images:
|
81 |
-
img = np.array(img)
|
82 |
-
img = np.round(img).astype(np.uint8)
|
83 |
-
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
|
84 |
-
for packet in stream.encode(frame):
|
85 |
-
container.mux(packet)
|
86 |
-
# Flush stream
|
87 |
-
for packet in stream.encode():
|
88 |
-
container.mux(packet)
|
89 |
-
# Close the file
|
90 |
-
container.close()
|
91 |
-
print('Saved')
|
92 |
-
|
93 |
-
bases = {
|
94 |
-
#"basem": "emilianJR/epiCRealism"
|
95 |
-
#SG161222/Realistic_Vision_V6.0_B1_noVAE
|
96 |
-
#runwayml/stable-diffusion-v1-5
|
97 |
-
#frankjoshua/realisticVisionV51_v51VAE
|
98 |
-
#Lykon/dreamshaper-7
|
99 |
-
}
|
100 |
-
|
101 |
-
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=dtype).to(DEVICE)
|
102 |
-
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
|
103 |
-
|
104 |
-
# vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
|
105 |
-
# vae = compile_unet(vae, config=config)
|
106 |
-
|
107 |
-
#adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
|
108 |
-
#pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype)
|
109 |
-
#pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
|
110 |
-
#pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
|
111 |
-
#pipe.set_adapters(["lcm-lora"], [1])
|
112 |
-
#pipe.fuse_lora()
|
113 |
-
|
114 |
-
pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder, vae=vae)
|
115 |
-
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
116 |
-
repo = "ByteDance/AnimateDiff-Lightning"
|
117 |
-
ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
|
118 |
-
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device='cpu'), strict=False)
|
119 |
-
|
120 |
-
|
121 |
-
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", map_location='cpu')
|
122 |
-
pipe.set_ip_adapter_scale(.8)
|
123 |
-
# pipe.unet.fuse_qkv_projections()
|
124 |
-
#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
|
125 |
-
|
126 |
-
pipe = compile(pipe, config=config)
|
127 |
-
pipe.to(device=DEVICE)
|
128 |
-
|
129 |
-
|
130 |
-
# THIS WOULD NEED PATCHING TODO
|
131 |
-
with split_attention(pipe.vae, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
|
132 |
-
# ! Change the tile_size and disable to see their effects
|
133 |
-
with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
|
134 |
-
im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
|
135 |
-
output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
|
136 |
-
leave_im_emb, _ = pipe.encode_image(
|
137 |
-
output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
|
138 |
-
)
|
139 |
-
assert len(output.frames[0]) == 16
|
140 |
-
leave_im_emb.to('cpu')
|
141 |
-
|
142 |
-
|
143 |
-
# TODO put back
|
144 |
-
# @spaces.GPU()
|
145 |
-
def generate(prompt, in_im_embs=None, base='basem'):
|
146 |
-
|
147 |
-
if in_im_embs == None:
|
148 |
-
in_im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
|
149 |
-
#in_im_embs = in_im_embs / torch.norm(in_im_embs)
|
150 |
-
else:
|
151 |
-
in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
|
152 |
-
#im_embs = torch.cat((torch.zeros(1, 1024, device=DEVICE, dtype=dtype), in_im_embs), 0)
|
153 |
-
|
154 |
-
with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
|
155 |
-
# ! Change the tile_size and disable to see their effects
|
156 |
-
with split_attention(pipe.vae, tile_size=128, disable=False, aspect_ratio=1):
|
157 |
-
output = pipe(prompt=prompt, guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
|
158 |
-
|
159 |
-
im_emb, _ = pipe.encode_image(
|
160 |
-
output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
|
161 |
-
)
|
162 |
-
|
163 |
-
nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
|
164 |
-
|
165 |
-
name = str(uuid.uuid4()).replace("-", "")
|
166 |
-
path = f"/tmp/{name}.mp4"
|
167 |
-
|
168 |
-
if nsfw:
|
169 |
-
gr.Warning("NSFW content detected.")
|
170 |
-
# TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
|
171 |
-
return None, im_emb
|
172 |
-
|
173 |
-
plt.close('all')
|
174 |
-
plt.hist(np.array(im_emb.to('cpu')).flatten(), bins=5)
|
175 |
-
plt.savefig('real_im_emb_plot.jpg')
|
176 |
-
|
177 |
-
write_video(path, output.frames[0])
|
178 |
-
return path, im_emb.to('cpu')
|
179 |
-
|
180 |
-
|
181 |
-
#######################
|
182 |
-
|
183 |
-
# TODO add to state instead of shared across all
|
184 |
-
glob_idx = 0
|
185 |
-
|
186 |
-
def next_image(embs, ys, calibrate_prompts):
|
187 |
-
global glob_idx
|
188 |
-
glob_idx = glob_idx + 1
|
189 |
-
|
190 |
-
with torch.no_grad():
|
191 |
-
if len(calibrate_prompts) > 0:
|
192 |
-
print('######### Calibrating with sample prompts #########')
|
193 |
-
prompt = calibrate_prompts.pop(0)
|
194 |
-
print(prompt)
|
195 |
-
image, img_embs = generate(prompt)
|
196 |
-
embs += img_embs
|
197 |
-
print(len(embs))
|
198 |
-
return image, embs, ys, calibrate_prompts
|
199 |
-
else:
|
200 |
-
print('######### Roaming #########')
|
201 |
-
|
202 |
-
# sample a .8 of rated embeddings for some stochasticity, or at least two embeddings.
|
203 |
-
# could take a sample < len(embs)
|
204 |
-
#n_to_choose = max(int((len(embs))), 2)
|
205 |
-
#indices = random.sample(range(len(embs)), n_to_choose)
|
206 |
-
|
207 |
-
# sample only as many negatives as there are positives
|
208 |
-
#pos_indices = [i for i in indices if ys[i] == 1]
|
209 |
-
#neg_indices = [i for i in indices if ys[i] == 0]
|
210 |
-
#lower = min(len(pos_indices), len(neg_indices))
|
211 |
-
#neg_indices = random.sample(neg_indices, lower)
|
212 |
-
#pos_indices = random.sample(pos_indices, lower)
|
213 |
-
#indices = neg_indices + pos_indices
|
214 |
-
|
215 |
-
pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
|
216 |
-
neg_indices = [i for i in range(len(embs)) if ys[i] == 0]
|
217 |
-
|
218 |
-
# the embs & ys stay tied by index but we shuffle to drop randomly
|
219 |
-
random.shuffle(pos_indices)
|
220 |
-
random.shuffle(neg_indices)
|
221 |
-
|
222 |
-
#if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
|
223 |
-
# pos_indices = pos_indices[32:]
|
224 |
-
if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 120/16:
|
225 |
-
pos_indices = pos_indices[1:]
|
226 |
-
if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 200/16:
|
227 |
-
neg_indices = neg_indices[2:]
|
228 |
-
|
229 |
-
|
230 |
-
print(len(pos_indices), len(neg_indices))
|
231 |
-
indices = pos_indices + neg_indices
|
232 |
-
|
233 |
-
embs = [embs[i] for i in indices]
|
234 |
-
ys = [ys[i] for i in indices]
|
235 |
-
indices = list(range(len(embs)))
|
236 |
-
|
237 |
-
|
238 |
-
# handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
|
239 |
-
if len(list(set(ys))) <= 1:
|
240 |
-
embs.append(.01*torch.randn(1024))
|
241 |
-
embs.append(.01*torch.randn(1024))
|
242 |
-
ys.append(0)
|
243 |
-
ys.append(1)
|
244 |
-
|
245 |
-
|
246 |
-
# also add the latest 0 and the latest 1
|
247 |
-
has_0 = False
|
248 |
-
has_1 = False
|
249 |
-
for i in reversed(range(len(ys))):
|
250 |
-
if ys[i] == 0 and has_0 == False:
|
251 |
-
indices.append(i)
|
252 |
-
has_0 = True
|
253 |
-
elif ys[i] == 1 and has_1 == False:
|
254 |
-
indices.append(i)
|
255 |
-
has_1 = True
|
256 |
-
if has_0 and has_1:
|
257 |
-
break
|
258 |
-
|
259 |
-
# we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
|
260 |
-
# this ends up adding a rating but losing an embedding, it seems.
|
261 |
-
# let's take off a rating if so to continue without indexing errors.
|
262 |
-
if len(ys) > len(embs):
|
263 |
-
print('ys are longer than embs; popping latest rating')
|
264 |
-
ys.pop(-1)
|
265 |
-
|
266 |
-
feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices] + [leave_im_emb[0].to('cpu')]).to('cpu'))
|
267 |
-
scaler = preprocessing.StandardScaler().fit(feature_embs)
|
268 |
-
feature_embs = scaler.transform(feature_embs)
|
269 |
-
chosen_y = np.array([ys[i] for i in indices] + [0])
|
270 |
-
|
271 |
-
print('Gathering coefficients')
|
272 |
-
#lin_class = LinearRegression(fit_intercept=False).fit(feature_embs, chosen_y)
|
273 |
-
lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=1).fit(feature_embs, chosen_y)
|
274 |
-
coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
|
275 |
-
coef_ = coef_ / coef_.abs().max() * 3
|
276 |
-
print(coef_.shape, 'COEF')
|
277 |
-
|
278 |
-
plt.close('all')
|
279 |
-
plt.hist(np.array(coef_).flatten(), bins=5)
|
280 |
-
plt.savefig('plot.jpg')
|
281 |
-
print(coef_)
|
282 |
-
print('Gathered')
|
283 |
-
|
284 |
-
rng_prompt = random.choice(prompt_list)
|
285 |
-
w = 1# if len(embs) % 2 == 0 else 0
|
286 |
-
im_emb = w * coef_.to(dtype=dtype)
|
287 |
-
|
288 |
-
prompt= 'the scene' if glob_idx % 2 == 0 else rng_prompt
|
289 |
-
print(prompt)
|
290 |
-
image, im_emb = generate(prompt, im_emb)
|
291 |
-
embs += im_emb
|
292 |
-
|
293 |
-
if len(embs) > 700/16:
|
294 |
-
embs = embs[1:]
|
295 |
-
ys = ys[1:]
|
296 |
-
|
297 |
-
return image, embs, ys, calibrate_prompts
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
def start(_, embs, ys, calibrate_prompts):
|
308 |
-
image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
309 |
-
return [
|
310 |
-
gr.Button(value='Like (L)', interactive=True),
|
311 |
-
gr.Button(value='Neither (Space)', interactive=True),
|
312 |
-
gr.Button(value='Dislike (A)', interactive=True),
|
313 |
-
gr.Button(value='Start', interactive=False),
|
314 |
-
image,
|
315 |
-
embs,
|
316 |
-
ys,
|
317 |
-
calibrate_prompts
|
318 |
-
]
|
319 |
-
|
320 |
-
|
321 |
-
def choose(img, choice, embs, ys, calibrate_prompts):
|
322 |
-
if choice == 'Like (L)':
|
323 |
-
choice = 1
|
324 |
-
elif choice == 'Neither (Space)':
|
325 |
-
embs = embs[:-1]
|
326 |
-
img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
327 |
-
return img, embs, ys, calibrate_prompts
|
328 |
-
else:
|
329 |
-
choice = 0
|
330 |
-
|
331 |
-
# if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
|
332 |
-
# TODO skip allowing rating
|
333 |
-
if img == None:
|
334 |
-
print('NSFW -- choice is disliked')
|
335 |
-
choice = 0
|
336 |
-
|
337 |
-
ys += [choice]*1
|
338 |
-
img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
339 |
-
return img, embs, ys, calibrate_prompts
|
340 |
-
|
341 |
-
css = '''.gradio-container{max-width: 700px !important}
|
342 |
-
#description{text-align: center}
|
343 |
-
#description h1, #description h3{display: block}
|
344 |
-
#description p{margin-top: 0}
|
345 |
-
.fade-in-out {animation: fadeInOut 3s forwards}
|
346 |
-
@keyframes fadeInOut {
|
347 |
-
0% {
|
348 |
-
background: var(--bg-color);
|
349 |
-
}
|
350 |
-
100% {
|
351 |
-
background: var(--button-secondary-background-fill);
|
352 |
-
}
|
353 |
-
}
|
354 |
-
'''
|
355 |
-
js_head = '''
|
356 |
-
<script>
|
357 |
-
document.addEventListener('keydown', function(event) {
|
358 |
-
if (event.key === 'a' || event.key === 'A') {
|
359 |
-
// Trigger click on 'dislike' if 'A' is pressed
|
360 |
-
document.getElementById('dislike').click();
|
361 |
-
} else if (event.key === ' ' || event.keyCode === 32) {
|
362 |
-
// Trigger click on 'neither' if Spacebar is pressed
|
363 |
-
document.getElementById('neither').click();
|
364 |
-
} else if (event.key === 'l' || event.key === 'L') {
|
365 |
-
// Trigger click on 'like' if 'L' is pressed
|
366 |
-
document.getElementById('like').click();
|
367 |
-
}
|
368 |
-
});
|
369 |
-
function fadeInOut(button, color) {
|
370 |
-
button.style.setProperty('--bg-color', color);
|
371 |
-
button.classList.remove('fade-in-out');
|
372 |
-
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
|
373 |
-
|
374 |
-
button.classList.add('fade-in-out');
|
375 |
-
button.addEventListener('animationend', () => {
|
376 |
-
button.classList.remove('fade-in-out'); // Reset the animation state
|
377 |
-
}, {once: true});
|
378 |
-
}
|
379 |
-
document.body.addEventListener('click', function(event) {
|
380 |
-
const target = event.target;
|
381 |
-
if (target.id === 'dislike') {
|
382 |
-
fadeInOut(target, '#ff1717');
|
383 |
-
} else if (target.id === 'like') {
|
384 |
-
fadeInOut(target, '#006500');
|
385 |
-
} else if (target.id === 'neither') {
|
386 |
-
fadeInOut(target, '#cccccc');
|
387 |
-
}
|
388 |
-
});
|
389 |
-
|
390 |
-
</script>
|
391 |
-
'''
|
392 |
-
|
393 |
-
with gr.Blocks(css=css, head=js_head) as demo:
|
394 |
-
gr.Markdown('''### Blue Tigers: Generative Recommenders for Exporation of Video
|
395 |
-
Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
|
396 |
-
''', elem_id="description")
|
397 |
-
embs = gr.State([])
|
398 |
-
ys = gr.State([])
|
399 |
-
calibrate_prompts = gr.State([
|
400 |
-
'the moon is melting into my glass of tea',
|
401 |
-
'a sea slug -- pair of claws scuttling -- jelly fish glowing',
|
402 |
-
'an adorable creature. It may be a goblin or a pig or a slug.',
|
403 |
-
'an animation about a gorgeous nebula',
|
404 |
-
'an octopus writhes',
|
405 |
-
])
|
406 |
-
def l():
|
407 |
-
return None
|
408 |
-
|
409 |
-
with gr.Row(elem_id='output-image'):
|
410 |
-
img = gr.Video(
|
411 |
-
label='Lightning',
|
412 |
-
autoplay=True,
|
413 |
-
interactive=False,
|
414 |
-
height=sidel,
|
415 |
-
width=sidel,
|
416 |
-
include_audio=False,
|
417 |
-
elem_id="video_output"
|
418 |
-
)
|
419 |
-
img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
|
420 |
-
with gr.Row(equal_height=True):
|
421 |
-
b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
|
422 |
-
b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
|
423 |
-
b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
|
424 |
-
b1.click(
|
425 |
-
choose,
|
426 |
-
[img, b1, embs, ys, calibrate_prompts],
|
427 |
-
[img, embs, ys, calibrate_prompts]
|
428 |
-
)
|
429 |
-
b2.click(
|
430 |
-
choose,
|
431 |
-
[img, b2, embs, ys, calibrate_prompts],
|
432 |
-
[img, embs, ys, calibrate_prompts]
|
433 |
-
)
|
434 |
-
b3.click(
|
435 |
-
choose,
|
436 |
-
[img, b3, embs, ys, calibrate_prompts],
|
437 |
-
[img, embs, ys, calibrate_prompts]
|
438 |
-
)
|
439 |
-
with gr.Row():
|
440 |
-
b4 = gr.Button(value='Start')
|
441 |
-
b4.click(start,
|
442 |
-
[b4, embs, ys, calibrate_prompts],
|
443 |
-
[b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
|
444 |
-
with gr.Row():
|
445 |
-
html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </ div><br><br><br>
|
446 |
-
<div style='text-align:center; font-size:14px'>Note that while the AnimateDiff-Lightning model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
|
447 |
-
</ div>
|
448 |
-
<br><br>
|
449 |
-
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
|
450 |
-
</ div>''')
|
451 |
-
|
452 |
-
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ninth.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:519ba2479c0605772adbb8405f267b0543316da7520d26988417104b2ffc176b
|
3 |
+
size 1180848
|
ninth.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7583e812651917c82ad7f8e41a031d1b2568e1369dd5cc63959a6bc5fd32959
|
3 |
+
size 6310
|
requirements.txt
CHANGED
@@ -15,6 +15,4 @@ tensorflow==2.14.0
|
|
15 |
imageio
|
16 |
apscheduler
|
17 |
pandas
|
18 |
-
av
|
19 |
-
torchvision
|
20 |
-
bitsandbytes
|
|
|
15 |
imageio
|
16 |
apscheduler
|
17 |
pandas
|
18 |
+
av
|
|
|
|
second.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c3fd4d35ade16f272d9df5ceb3faf859c20245553a32e41d1f0a7573e247ffde
|
3 |
+
size 1180853
|
second.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:389cce5de6ae401dbef57ec7ef8561f4a871a8d84a9107403d511cf259ff1840
|
3 |
+
size 6315
|
seventh.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8525e6bf8db787722604b13b5d00bb63b0ed20849ecd4c48cb1b64bafb9ba8fa
|
3 |
+
size 1180858
|
seventh.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d36e188adfaaca095979cfb8b899e013151a4604cf85c3f194300632890a64d5
|
3 |
+
size 6320
|
sixth.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:23d6a9c82b0684aec1b5d643bc9613e46d312de82cfae2416d316286bca4d11a
|
3 |
+
size 1180848
|
sixth.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8bb985011ba6fadd956681423b824783a7177f7bf3987527db92c657dbbda0b
|
3 |
+
size 6310
|
tenth.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ba865e86007c31b12074cbd7939fb19491abf80bc6d5f7c16f004f14c70cb2de
|
3 |
+
size 1180848
|
tenth.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ded797adeade52c0b2c1ea28e65963336a2b3572c1b5bf3a3f3f0bdfdf7457b6
|
3 |
+
size 6310
|
third.gemb_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b467e423b27fbc9963c581b3c24b6ed00cc2092d7ee207547f399904007bf67
|
3 |
+
size 1180848
|
third.im_.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bc7db62b595535a68c484d582484d9af1fb1e94673d3e3a9aa7d181c75fe1fec
|
3 |
+
size 6310
|
twitter_prompts.csv
DELETED
@@ -1,72 +0,0 @@
|
|
1 |
-
,0
|
2 |
-
0,a sunset
|
3 |
-
1,a still life in blue
|
4 |
-
2,last day on earth
|
5 |
-
3,the conch shell
|
6 |
-
4,the winds of change
|
7 |
-
5,a surrealist eye
|
8 |
-
6,a surrealist polaroid photo of an apple
|
9 |
-
7,metaphysics
|
10 |
-
8,the sun is setting into my glass of tea
|
11 |
-
9,the moon at 3am
|
12 |
-
10,a memento mori
|
13 |
-
11,quaking aspen tree
|
14 |
-
12,violets and daffodils
|
15 |
-
13,espresso
|
16 |
-
14,sisyphus
|
17 |
-
15,high windows of stained glass
|
18 |
-
16,a green dog
|
19 |
-
17,an adorable companion; it is a pig
|
20 |
-
18,bird of paradise
|
21 |
-
19,a complex intricate machine
|
22 |
-
20,a white clock
|
23 |
-
21,a film featuring the landscape Salt Lake City Utah
|
24 |
-
22,a creature
|
25 |
-
23,a house set aflame.
|
26 |
-
24,a gorgeous landscape by Cy Twombly
|
27 |
-
25,smoke rises from the caterpillar's hookah
|
28 |
-
26,corvid in red
|
29 |
-
27,Monet's pond
|
30 |
-
28,Genesis
|
31 |
-
29,Death is a black camel that kneels down so we can ride
|
32 |
-
30,a cherry tree made of fractals
|
33 |
-
29,the end of the sidewalk
|
34 |
-
30,a polaroid photo of a bustling city of lights and sky scrapers
|
35 |
-
31,The Fig Tree metaphor
|
36 |
-
32,God killed Van Gogh.
|
37 |
-
33,a cosmic entity alien with four eyes.
|
38 |
-
34,a horse with 128 eyes.
|
39 |
-
35,a being with an infinite set of eyes (it is omniscient)
|
40 |
-
36,A sticky-note magnum opus featuring birds
|
41 |
-
37,Moka Pot
|
42 |
-
38,the moon is a sickle cell
|
43 |
-
39,The Penultimate Supper
|
44 |
-
40,Art
|
45 |
-
41,surrealism
|
46 |
-
42,a god made of wires & dust
|
47 |
-
43,a dandelion blown into the universe
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|