fffiloni committed on
Commit
bccea8e
1 Parent(s): 6303d5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -7
app.py CHANGED
@@ -1,13 +1,15 @@
1
  import gradio as gr
 
 
 
2
  import spaces
3
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
4
  import torch
5
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
6
- SAFETY_CHECKER = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)
7
 
8
  # load pipeline
9
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
10
- pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True, safety_checker=SAFETY_CHECKER).to("cuda")
11
 
12
  # load finetuned model
13
  unet_id = "mhdang/dpo-sdxl-text2image-v1"
@@ -18,17 +20,32 @@ pipe = pipe.to("cuda")
18
  pipe.enable_model_cpu_offload()
19
  pipe.enable_vae_slicing()
20
 
 
 
 
 
 
 
 
 
21
  @spaces.GPU(enable_queue=True)
22
  def infer(prompt):
23
  print(f"""
24
  —/n
25
  {prompt}
26
  """)
27
- results = pipe(prompt, guidance_scale=7.5)
28
- #for i in range(len(results.images)):
29
- # if results.nsfw_content_detected[i]:
30
- # results.images[i] = Image.open("nsfw.png")
31
- return results.images[0]
 
 
 
 
 
 
 
32
 
33
  css = """
34
  #col-container{
 
1
  import gradio as gr
2
+ from gradio_client import Client
3
+ client = Client("https://fffiloni-safety-checker-bot.hf.space/")
4
+ import re
5
  import spaces
6
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
7
  import torch
8
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
9
 
10
  # load pipeline
11
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
12
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("cuda")
13
 
14
  # load finetuned model
15
  unet_id = "mhdang/dpo-sdxl-text2image-v1"
 
20
  pipe.enable_model_cpu_offload()
21
  pipe.enable_vae_slicing()
22
 
23
def safety_check(user_prompt):
    """Ask the remote safety-checker Space whether a prompt is allowed.

    Parameters
    ----------
    user_prompt : str
        The raw user prompt to screen.

    Returns
    -------
    str
        The Space's reply; its first whitespace-separated token is the
        string "True" when the prompt is flagged as unsafe (see infer()).
    """
    # Reuse the module-level Client created at import time instead of
    # constructing a fresh one per call: each Client() re-fetches the
    # Space config over the network, and the previously hardcoded
    # ".../--replicas/dvzmf/" URL breaks whenever the Space's replicas
    # are rotated.
    response = client.predict(
        user_prompt,  # str in 'User sent this' Textbox component
        api_name="/infer",
    )
    return response
31
@spaces.GPU(enable_queue=True)
def infer(prompt):
    """Generate one image for *prompt* with the DPO-finetuned SDXL pipe.

    Parameters
    ----------
    prompt : str
        Text prompt forwarded to the remote safety checker and then, if
        allowed, to the diffusion pipeline.

    Returns
    -------
    PIL.Image.Image
        The first image produced by the pipeline.

    Raises
    ------
    gr.Error
        When the remote safety checker flags the prompt.
    """
    print(f"""
    —/n
    {prompt}
    """)
    # The checker replies with a string whose first token is "True" when
    # the prompt is unsafe — TODO confirm the exact reply format against
    # the safety-checker Space.
    is_safe = safety_check(prompt)
    status = is_safe.split()[0]
    if status == "True":
        # Must be *raised*: merely constructing gr.Error (as the previous
        # code did) is a no-op, so flagged prompts silently returned None
        # instead of showing an error message in the Gradio UI.
        raise gr.Error("Don't.")
    results = pipe(prompt, guidance_scale=7.5)
    return results.images[0]
49
 
50
  css = """
51
  #col-container{