Spaces: Running on A10G
rynmurdock committed • Commit 5c43323 • 1 Parent(s): 32bd2b2
added safety checker
Files changed:
- app.py (+23, -8)
- nsfweffnetv2-b02-3epochs.h5 (+3, -0)
- safety_checker_improved.py (+45, -0)
app.py CHANGED
@@ -27,6 +27,8 @@ from transformers import CLIPVisionModelWithProjection
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 
+from safety_checker_improved import maybe_nsfw
+
 prompt_list = [p for p in list(set(
     pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
 
@@ -56,6 +58,8 @@ pipe.to(device=DEVICE)
 # TODO put back
 @spaces.GPU
 def compile_em():
+    # TODO Compile
+    return None
     pipe.unet = torch.compile(pipe.unet)
     pipe.vae = torch.compile(pipe.vae, mode='reduce-overhead')
     autoencoder.model.forward = torch.compile(autoencoder.model.forward, backend='inductor', dynamic=True)
@@ -160,6 +164,11 @@ def predict(
     im_emb, _ = pipe.encode_image(
         image, DEVICE, 1, output_hidden_state
     )
+
+    nsfw = maybe_nsfw(image)
+    if nsfw:
+        return None, im_emb.to('cpu')
+
     return image, im_emb.to('cpu')
 
 
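The predict() hunk above is the core of the change: the image embedding is still computed and returned (so the preference model keeps its training signal), but a flagged image itself is replaced with None. A minimal, self-contained sketch of that shape, with the checker and the diffusion call stubbed out (the stubs are hypothetical stand-ins, not the app's code):

    def maybe_nsfw_stub(image):
        # stand-in for safety_checker_improved.maybe_nsfw
        return image == 'unsafe'

    def predict_sketch(prompt):
        # stand-ins for the diffusion call and its image embedding
        image, im_emb = f'image:{prompt}', f'emb:{prompt}'
        if maybe_nsfw_stub(image):
            return None, im_emb   # keep the embedding, drop the image
        return image, im_emb

    assert predict_sketch('a red barn')[0] is not None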
@@ -245,10 +254,10 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
     image, img_emb = predict(prompt, im_emb=img_emb)
     img_embs.append(img_emb)
 
-
-
-
-
+    if len(embs) > 100:
+        embs.pop(0)
+        img_embs.pop(0)
+        ys.pop(0)
     return image, embs, img_embs, ys, calibrate_prompts
 
 
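next_image() now also caps the stored history at 100 entries, popping the oldest embedding, image embedding, and rating together so the three lists stay aligned. The app pops once per round (it appends one item at a time, so a single if suffices); the sketch below uses a while to stay correct for any starting length. A hypothetical helper, not the app's code:

    def cap_history(embs, img_embs, ys, max_len=100):
        # drop the oldest (emb, img_emb, rating) triple once the window is full
        while len(embs) > max_len:
            embs.pop(0)
            img_embs.pop(0)
            ys.pop(0)

    embs, img_embs, ys = list(range(150)), list(range(150)), list(range(150))
    cap_history(embs, img_embs, ys)
    assert len(embs) == len(img_embs) == len(ys) == 100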
@@ -274,7 +283,7 @@ def start(_, embs, img_embs, ys, calibrate_prompts):
     ]
 
 
-def choose(choice, embs, img_embs, ys, calibrate_prompts):
+def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
     if choice == 'Like (L)':
         choice = 1
     elif choice == 'Neither (Space)':
@@ -284,6 +293,12 @@ def choose(choice, embs, img_embs, ys, calibrate_prompts):
         return img, embs, img_embs, ys, calibrate_prompts
     else:
         choice = 0
+
+    print(img, 'img')
+    if img is None:
+        print('NSFW -- choice is disliked')
+        choice = 0
+
     ys.append(choice)
     img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
     return img, embs, img_embs, ys, calibrate_prompts
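Because predict() can now return None, choose() takes the currently displayed img as an extra input and records a blanked image as a dislike before advancing, regardless of which button the user pressed. A hypothetical distillation of that logic:

    def choose_sketch(img, choice):
        # map the UI label to a rating, then override with a dislike
        # if the safety checker blanked the image
        rating = 1 if choice == 'Like (L)' else 0
        if img is None:
            rating = 0
        return rating

    assert choose_sketch(None, 'Like (L)') == 0   # flagged images never count as likes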
@@ -363,17 +378,17 @@ with gr.Blocks(css=css, head=js_head) as demo:
     b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
     b1.click(
         choose,
-        [b1, embs, img_embs, ys, calibrate_prompts],
+        [img, b1, embs, img_embs, ys, calibrate_prompts],
         [img, embs, img_embs, ys, calibrate_prompts]
     )
     b2.click(
         choose,
-        [b2, embs, img_embs, ys, calibrate_prompts],
+        [img, b2, embs, img_embs, ys, calibrate_prompts],
         [img, embs, img_embs, ys, calibrate_prompts]
     )
     b3.click(
         choose,
-        [b3, embs, img_embs, ys, calibrate_prompts],
+        [img, b3, embs, img_embs, ys, calibrate_prompts],
         [img, embs, img_embs, ys, calibrate_prompts]
     )
     with gr.Row():
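The wiring change mirrors the new choose() signature: each button passes the img component as an additional input (a Gradio Button used as an input supplies its label string), and img stays in the outputs so the handler can swap in the next image. A minimal sketch of that pattern as a standalone demo, not the app itself:

    import gradio as gr

    def like(img, label):
        # hypothetical handler with the same input shape as choose()
        return img

    with gr.Blocks() as demo:
        img = gr.Image()
        b1 = gr.Button('Like (L)')
        b1.click(like, [img, b1], [img])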
nsfweffnetv2-b02-3epochs.h5 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91422f388d1632c1af21b3d787b4f6c1a8e6114f600162d392b0bf285ff8a433
+size 71027272
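This is a Git LFS pointer, not the weights themselves: the repo versions this three-line stub while the ~71 MB .h5 file lives in LFS storage. Per the LFS spec, the oid is the SHA-256 of the real file's contents, so a downloaded copy can be verified with a sketch like this (not part of the commit):

    import hashlib

    def sha256_of(path):
        # stream the file so large weights don't need to fit in memory
        h = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                h.update(chunk)
        return h.hexdigest()

    # expected for this file:
    # 91422f388d1632c1af21b3d787b4f6c1a8e6114f600162d392b0bf285ff8a433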
safety_checker_improved.py ADDED
@@ -0,0 +1,45 @@
+
+# TODO required tensorflow==2.14 for me
+# weights from https://github.com/LAION-AI/safety-pipeline/tree/main
+from PIL import Image
+import tensorflow_hub as hub
+import tensorflow
+import numpy as np
+import sys
+sys.path.append('/home/ryn_mote/Misc/generative_recommender/gradio_video/automl/efficientnetv2/')
+import tensorflow as tf
+from tensorflow.keras import mixed_precision
+physical_devices = tf.config.list_physical_devices('GPU')
+
+tf.config.experimental.set_memory_growth(
+    physical_devices[0], True
+)
+
+model = tf.keras.models.load_model('nsfweffnetv2-b02-3epochs.h5', custom_objects={"KerasLayer": hub.KerasLayer})
+# "The image classifier had been trained on 682550 images from the 5 classes "Drawing" (39026), "Hentai" (28134), "Neutral" (369507), "Porn" (207969) & "Sexy" (37914).
+# ... we created a manually inspected test set that consists of 4900 samples, that contains images & their captions."
+
+# Run prediction
+def maybe_nsfw(pil_image):
+    # resize to the 260x260 model input, drop any alpha channel, scale to [0, 1]
+    imm = tensorflow.image.resize(np.array(pil_image)[:, :, :3], (260, 260))
+    imm = (imm / 255)
+    pred = model(tensorflow.expand_dims(imm, 0)).numpy()
+    probs = tensorflow.math.softmax(pred[0]).numpy()
+    print(probs)
+    if all([i < .3 for i in probs[[1, 3, 4]]]):
+        return False
+    return True
+
+# pre-initializing prediction
+maybe_nsfw(Image.new("RGB", (260, 260), 255))
+model.load_weights('nsfweffnetv2-b02-3epochs.h5', by_name=True)
+
+
+
+
+
+
+
+
+
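maybe_nsfw() flags an image unless the softmax probabilities at indices 1, 3, and 4 all stay below 0.3. Under the class order quoted in the comment above (Drawing, Hentai, Neutral, Porn, Sexy), those indices correspond to the Hentai, Porn, and Sexy classes; that mapping is inferred from the comment, not stated in the code. Caller-side usage would look like this (hypothetical snippet, assuming the module and weights are on the path):

    from PIL import Image
    from safety_checker_improved import maybe_nsfw

    # any RGB(A) image works; the checker resizes and strips alpha itself
    img = Image.new('RGB', (512, 512), 'gray')
    if maybe_nsfw(img):
        img = None   # mirror app.py: flagged images are dropped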