Native safety checker

#918
by multimodalart - opened
Files changed (1)
  1. app.py +12 -12
app.py CHANGED
@@ -23,7 +23,7 @@ import user_history
 from illusion_style import css
 import os
 from transformers import CLIPImageProcessor
-from safety_checker import StableDiffusionSafetyChecker #-> commenting this out to provoke runtime error
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
 
@@ -49,16 +49,16 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
 ).to("cuda")
 
 # Function to check NSFW images
-def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
-    if SAFETY_CHECKER_ENABLED:
-        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
-        has_nsfw_concepts = safety_checker(
-            images=[images],
-            clip_input=safety_checker_input.pixel_values.to("cuda")
-        )
-        return images, has_nsfw_concepts
-    else:
-        return images, [False] * len(images)
+#def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
+#    if SAFETY_CHECKER_ENABLED:
+#        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+#        has_nsfw_concepts = safety_checker(
+#            images=[images],
+#            clip_input=safety_checker_input.pixel_values.to("cuda")
+#        )
+#        return images, has_nsfw_concepts
+#    else:
+#        return images, [False] * len(images)
 
 #main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
 #main_pipe.unet.to(memory_format=torch.channels_last)
@@ -284,4 +284,4 @@ with gr.Blocks(css=css) as app_with_history:
 app_with_history.queue(max_size=20,api_open=False )
 
 if __name__ == "__main__":
-    app_with_history.launch(max_threads=400)
+    app_with_history.launch(max_threads=400)
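
For context, here is a minimal sketch of what relying on the native diffusers safety checker looks like, rather than this Space's exact code. It assumes the checker weights come from CompVis/stable-diffusion-safety-checker and the feature extractor from openai/clip-vit-base-patch32; the ControlNet ID, prompt, and control image below are illustrative. Once safety_checker and feature_extractor are passed to the pipeline, its output carries nsfw_content_detected per image, which is why the hand-rolled check_nsfw_images helper can be commented out.

# Minimal sketch (assumed setup, not the Space's exact code): wire the native
# diffusers safety checker into the ControlNet pipeline so outputs are screened
# automatically. Repo IDs other than BASE_MODEL are assumptions for illustration.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor

BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

# Assumption: the checker weights are loaded from the standard CompVis repo,
# since the base checkpoint may not bundle its own.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
)
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16  # illustrative ControlNet
)
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")

# With the checker attached, the pipeline flags NSFW results itself: the output
# exposes the images plus a per-image boolean list, so no separate
# check_nsfw_images() helper is needed.
control_image = Image.new("RGB", (512, 512), "white")  # placeholder conditioning image
result = main_pipe("a medieval village, highly detailed", image=control_image, num_inference_steps=30)
images, flagged = result.images, result.nsfw_content_detected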