fffiloni commited on
Commit
4b15ff4
·
1 Parent(s): 784addb

added safety checker

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -2,10 +2,12 @@ import gradio as gr
2
  import spaces
3
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
4
  import torch
 
 
5
 
6
  # load pipeline
7
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
8
- pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("cuda")
9
 
10
  # load finetuned model
11
  unet_id = "mhdang/dpo-sdxl-text2image-v1"
@@ -22,8 +24,11 @@ def infer(prompt):
22
  —/n
23
  {prompt}
24
  """)
25
- image = pipe(prompt, guidance_scale=7.5).images[0]
26
- return image
 
 
 
27
 
28
  css = """
29
  #col-container{
 
2
  import spaces
3
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
4
  import torch
5
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
6
+ SAFETY_CHECKER = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)
7
 
8
  # load pipeline
9
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
10
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True, safety_checker=SAFETY_CHECKER).to("cuda")
11
 
12
  # load finetuned model
13
  unet_id = "mhdang/dpo-sdxl-text2image-v1"
 
24
  —/n
25
  {prompt}
26
  """)
27
+ images = pipe(prompt, guidance_scale=7.5).images
28
+ for i in range(len(images)):
29
+ if results.nsfw_content_detected[i]:
30
+ images[i] = Image.open("nsfw.png")
31
+ return image[0]
32
 
33
  css = """
34
  #col-container{