Spaces:
Paused
Paused
added safety checker
Browse files
app.py
CHANGED
@@ -2,10 +2,12 @@ import gradio as gr
|
|
2 |
import spaces
|
3 |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
|
4 |
import torch
|
|
|
|
|
5 |
|
6 |
# load pipeline
|
7 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
8 |
-
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("cuda")
|
9 |
|
10 |
# load finetuned model
|
11 |
unet_id = "mhdang/dpo-sdxl-text2image-v1"
|
@@ -22,8 +24,11 @@ def infer(prompt):
|
|
22 |
—/n
|
23 |
{prompt}
|
24 |
""")
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
27 |
|
28 |
css = """
|
29 |
#col-container{
|
|
|
2 |
import spaces
|
3 |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
|
4 |
import torch
|
5 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
6 |
+
SAFETY_CHECKER = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)
|
7 |
|
8 |
# load pipeline
|
9 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
10 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True, safety_checker=SAFETY_CHECKER).to("cuda")
|
11 |
|
12 |
# load finetuned model
|
13 |
unet_id = "mhdang/dpo-sdxl-text2image-v1"
|
|
|
24 |
—/n
|
25 |
{prompt}
|
26 |
""")
|
27 |
+
images = pipe(prompt, guidance_scale=7.5).images
|
28 |
+
for i in range(len(images)):
|
29 |
+
if results.nsfw_content_detected[i]:
|
30 |
+
images[i] = Image.open("nsfw.png")
|
31 |
+
return image[0]
|
32 |
|
33 |
css = """
|
34 |
#col-container{
|