radames commited on
Commit
04a8586
1 Parent(s): de1f644

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPFeatureExtractor
2
+ from safety_checker import StableDiffusionSafetyChecker
3
+ import torch
4
+ from PIL import Image
5
+ import gradio as gr
6
+ from pathlib import Path
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
10
+ "CompVis/stable-diffusion-safety-checker"
11
+ ).to(device)
12
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
13
+
14
+
15
+ import gradio as gr
16
+
17
+
18
+ def image_classifier(files):
19
+ images = [Image.open(file).convert("RGB").resize((512, 512)) for file in files]
20
+
21
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
22
+ has_nsfw_concepts = safety_checker(
23
+ images=[images], clip_input=safety_checker_input.pixel_values.to(torch.float16)
24
+ )
25
+ results = [
26
+ {"has_nsfw": nsfw, "file": Path(file).name}
27
+ for (nsfw, file) in zip(has_nsfw_concepts, files)
28
+ ]
29
+ return {"results": results}
30
+
31
+
32
+ demo = gr.Interface(
33
+ fn=image_classifier,
34
+ inputs=gr.File(file_count="multiple", file_types=["image"]),
35
+ outputs="json",
36
+ api_name="classify",
37
+ )
38
+ demo.launch()