jayparmr commited on
Commit
ae524a9
·
1 Parent(s): f256b62

Upload folder using huggingface_hub

Browse files
inference.py CHANGED
@@ -14,18 +14,12 @@ from internals.pipelines.prompt_modifier import PromptModifier
14
  from internals.pipelines.safety_checker import SafetyChecker
15
  from internals.util.args import apply_style_args
16
  from internals.util.avatar import Avatar
17
- from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
18
- from internals.util.commons import (
19
- download_image,
20
- pickPoses,
21
- upload_image,
22
- upload_images,
23
- )
24
- from internals.util.config import (
25
- num_return_sequences,
26
- set_configs_from_task,
27
- set_root_dir,
28
- )
29
  from internals.util.failure_hander import FailureHandler
30
  from internals.util.lora_style import LoraStyle
31
  from internals.util.slack import Slack
@@ -455,10 +449,6 @@ def model_fn(model_dir):
455
  img2img_pipe.create(text2img_pipe)
456
  inpainter.create(text2img_pipe)
457
 
458
- safety_checker.apply(text2img_pipe)
459
- safety_checker.apply(img2img_pipe)
460
- safety_checker.apply(controlnet)
461
-
462
  print("Logs: model loaded ....")
463
  return
464
 
@@ -474,6 +464,11 @@ def predict_fn(data, pipe):
474
  # Set set_environment
475
  set_configs_from_task(task)
476
 
 
 
 
 
 
477
  # Apply arguments
478
  apply_style_args(data)
479
 
 
14
  from internals.pipelines.safety_checker import SafetyChecker
15
  from internals.util.args import apply_style_args
16
  from internals.util.avatar import Avatar
17
+ from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
18
+ clear_cuda_and_gc)
19
+ from internals.util.commons import (download_image, pickPoses, upload_image,
20
+ upload_images)
21
+ from internals.util.config import (num_return_sequences, set_configs_from_task,
22
+ set_root_dir)
 
 
 
 
 
 
23
  from internals.util.failure_hander import FailureHandler
24
  from internals.util.lora_style import LoraStyle
25
  from internals.util.slack import Slack
 
449
  img2img_pipe.create(text2img_pipe)
450
  inpainter.create(text2img_pipe)
451
 
 
 
 
 
452
  print("Logs: model loaded ....")
453
  return
454
 
 
464
  # Set set_environment
465
  set_configs_from_task(task)
466
 
467
+ # Apply safety checkers based on environment
468
+ safety_checker.apply(text2img_pipe)
469
+ safety_checker.apply(img2img_pipe)
470
+ safety_checker.apply(controlnet)
471
+
472
  # Apply arguments
473
  apply_style_args(data)
474
 
inference2.py CHANGED
@@ -7,18 +7,17 @@ from internals.data.task import ModelType, Task, TaskType
7
  from internals.pipelines.inpainter import InPainter
8
  from internals.pipelines.object_remove import ObjectRemoval
9
  from internals.pipelines.prompt_modifier import PromptModifier
10
- from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
 
11
  from internals.pipelines.replace_background import ReplaceBackground
12
  from internals.pipelines.safety_checker import SafetyChecker
13
  from internals.pipelines.upscaler import Upscaler
14
  from internals.util.avatar import Avatar
15
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
16
- from internals.util.commons import construct_default_s3_url, upload_image, upload_images
17
- from internals.util.config import (
18
- num_return_sequences,
19
- set_configs_from_task,
20
- set_root_dir,
21
- )
22
  from internals.util.failure_hander import FailureHandler
23
  from internals.util.slack import Slack
24
 
@@ -173,8 +172,6 @@ def model_fn(model_dir):
173
 
174
  replace_background.load(upscaler, remove_background_v2)
175
 
176
- safety_checker.apply(inpainter)
177
-
178
  print("Logs: model loaded ....")
179
  return
180
 
@@ -186,13 +183,13 @@ def predict_fn(data, pipe):
186
 
187
  FailureHandler.handle(task)
188
 
189
- # Set set_environment
190
- set_configs_from_task(task)
191
-
192
  try:
193
  # Set set_environment
194
  set_configs_from_task(task)
195
 
 
 
 
196
  # Fetch avatars
197
  avatar.fetch_from_network(task.get_model_id())
198
 
 
7
  from internals.pipelines.inpainter import InPainter
8
  from internals.pipelines.object_remove import ObjectRemoval
9
  from internals.pipelines.prompt_modifier import PromptModifier
10
+ from internals.pipelines.remove_background import (RemoveBackground,
11
+ RemoveBackgroundV2)
12
  from internals.pipelines.replace_background import ReplaceBackground
13
  from internals.pipelines.safety_checker import SafetyChecker
14
  from internals.pipelines.upscaler import Upscaler
15
  from internals.util.avatar import Avatar
16
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
17
+ from internals.util.commons import (construct_default_s3_url, upload_image,
18
+ upload_images)
19
+ from internals.util.config import (num_return_sequences, set_configs_from_task,
20
+ set_root_dir)
 
 
21
  from internals.util.failure_hander import FailureHandler
22
  from internals.util.slack import Slack
23
 
 
172
 
173
  replace_background.load(upscaler, remove_background_v2)
174
 
 
 
175
  print("Logs: model loaded ....")
176
  return
177
 
 
183
 
184
  FailureHandler.handle(task)
185
 
 
 
 
186
  try:
187
  # Set set_environment
188
  set_configs_from_task(task)
189
 
190
+ # Apply safety checker based on environment
191
+ safety_checker.apply(inpainter)
192
+
193
  # Fetch avatars
194
  avatar.fetch_from_network(task.get_model_id())
195
 
internals/pipelines/safety_checker.py CHANGED
@@ -4,6 +4,7 @@ import cv2
4
  import numpy as np
5
  import torch
6
  import torch.nn as nn
 
7
  from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
8
 
9
  from internals.pipelines.commons import AbstractPipeline
@@ -23,10 +24,17 @@ class SafetyChecker:
23
  ).to("cuda")
24
 
25
  def apply(self, pipeline: AbstractPipeline):
 
26
  if hasattr(pipeline, "pipe"):
27
- pipeline.pipe.safety_checker = self.model
28
  if hasattr(pipeline, "pipe2"):
29
- pipeline.pipe2.safety_checker = self.model
 
 
 
 
 
 
30
 
31
 
32
  class StableDiffusionSafetyCheckerV2(PreTrainedModel):
@@ -102,7 +110,7 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
102
  result_img["concept_scores"][concept_idx] = round(
103
  concept_cos - concept_threshold + adjustment, 3
104
  )
105
- if result_img["concept_scores"][concept_idx] > get_nsfw_threshold():
106
  result_img["bad_concepts"].append(concept_idx)
107
 
108
  result.append(result_img)
@@ -115,11 +123,12 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
115
  if any(has_nsfw_concepts) and not get_nsfw_access():
116
  if torch.is_tensor(images) or torch.is_tensor(images[0]):
117
  image = images[idx].cpu().numpy().astype(np.float32)
118
- image = cv2.blur(image, (30, 30))
 
119
  image = torch.from_numpy(image)
120
  images[idx] = image
121
  else:
122
- images[idx] = cv2.blur(images[idx], (30, 30))
123
 
124
  if any(has_nsfw_concepts):
125
  print("NSFW")
@@ -150,13 +159,14 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
150
 
151
  concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
152
  # concept_scores = concept_scores.round(decimals=3)
153
- has_nsfw_concepts = torch.any(concept_scores > get_nsfw_threshold(), dim=1)
154
 
 
155
  # Blur images based on NSFW score
156
  # -------------------------------
157
  if not get_nsfw_access():
158
  image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
159
- image = cv2.blur(image, (30, 30))
160
  image = torch.from_numpy(image)
161
  images[has_nsfw_concepts] = image
162
 
 
4
  import numpy as np
5
  import torch
6
  import torch.nn as nn
7
+ from scipy.ndimage.filters import gaussian_filter
8
  from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
9
 
10
  from internals.pipelines.commons import AbstractPipeline
 
24
  ).to("cuda")
25
 
26
  def apply(self, pipeline: AbstractPipeline):
27
+ model = self.model if not get_nsfw_access() else None
28
  if hasattr(pipeline, "pipe"):
29
+ pipeline.pipe.safety_checker = model
30
  if hasattr(pipeline, "pipe2"):
31
+ pipeline.pipe2.safety_checker = model
32
+
33
+
34
+ def cosine_distance(image_embeds, text_embeds):
35
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
36
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
37
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
38
 
39
 
40
  class StableDiffusionSafetyCheckerV2(PreTrainedModel):
 
110
  result_img["concept_scores"][concept_idx] = round(
111
  concept_cos - concept_threshold + adjustment, 3
112
  )
113
+ if result_img["concept_scores"][concept_idx] > 0:
114
  result_img["bad_concepts"].append(concept_idx)
115
 
116
  result.append(result_img)
 
123
  if any(has_nsfw_concepts) and not get_nsfw_access():
124
  if torch.is_tensor(images) or torch.is_tensor(images[0]):
125
  image = images[idx].cpu().numpy().astype(np.float32)
126
+ image = gaussian_filter(image, sigma=7)
127
+ # image = cv2.blur(image, (30, 30))
128
  image = torch.from_numpy(image)
129
  images[idx] = image
130
  else:
131
+ images[idx] = gaussian_filter(images[idx], sigma=7)
132
 
133
  if any(has_nsfw_concepts):
134
  print("NSFW")
 
159
 
160
  concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
161
  # concept_scores = concept_scores.round(decimals=3)
162
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
163
 
164
+ # images[has_nsfw_concepts] = 0.0 # black image
165
  # Blur images based on NSFW score
166
  # -------------------------------
167
  if not get_nsfw_access():
168
  image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
169
+ image = gaussian_filter(image, sigma=7)
170
  image = torch.from_numpy(image)
171
  images[has_nsfw_concepts] = image
172