Upload folder using huggingface_hub
Browse files- inference.py +11 -16
- inference2.py +9 -12
- internals/pipelines/safety_checker.py +17 -7
inference.py
CHANGED
@@ -14,18 +14,12 @@ from internals.pipelines.prompt_modifier import PromptModifier
|
|
14 |
from internals.pipelines.safety_checker import SafetyChecker
|
15 |
from internals.util.args import apply_style_args
|
16 |
from internals.util.avatar import Avatar
|
17 |
-
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda,
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
)
|
24 |
-
from internals.util.config import (
|
25 |
-
num_return_sequences,
|
26 |
-
set_configs_from_task,
|
27 |
-
set_root_dir,
|
28 |
-
)
|
29 |
from internals.util.failure_hander import FailureHandler
|
30 |
from internals.util.lora_style import LoraStyle
|
31 |
from internals.util.slack import Slack
|
@@ -455,10 +449,6 @@ def model_fn(model_dir):
|
|
455 |
img2img_pipe.create(text2img_pipe)
|
456 |
inpainter.create(text2img_pipe)
|
457 |
|
458 |
-
safety_checker.apply(text2img_pipe)
|
459 |
-
safety_checker.apply(img2img_pipe)
|
460 |
-
safety_checker.apply(controlnet)
|
461 |
-
|
462 |
print("Logs: model loaded ....")
|
463 |
return
|
464 |
|
@@ -474,6 +464,11 @@ def predict_fn(data, pipe):
|
|
474 |
# Set set_environment
|
475 |
set_configs_from_task(task)
|
476 |
|
|
|
|
|
|
|
|
|
|
|
477 |
# Apply arguments
|
478 |
apply_style_args(data)
|
479 |
|
|
|
14 |
from internals.pipelines.safety_checker import SafetyChecker
|
15 |
from internals.util.args import apply_style_args
|
16 |
from internals.util.avatar import Avatar
|
17 |
+
from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
|
18 |
+
clear_cuda_and_gc)
|
19 |
+
from internals.util.commons import (download_image, pickPoses, upload_image,
|
20 |
+
upload_images)
|
21 |
+
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
22 |
+
set_root_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
from internals.util.failure_hander import FailureHandler
|
24 |
from internals.util.lora_style import LoraStyle
|
25 |
from internals.util.slack import Slack
|
|
|
449 |
img2img_pipe.create(text2img_pipe)
|
450 |
inpainter.create(text2img_pipe)
|
451 |
|
|
|
|
|
|
|
|
|
452 |
print("Logs: model loaded ....")
|
453 |
return
|
454 |
|
|
|
464 |
# Set set_environment
|
465 |
set_configs_from_task(task)
|
466 |
|
467 |
+
# Apply safety checkers based on environment
|
468 |
+
safety_checker.apply(text2img_pipe)
|
469 |
+
safety_checker.apply(img2img_pipe)
|
470 |
+
safety_checker.apply(controlnet)
|
471 |
+
|
472 |
# Apply arguments
|
473 |
apply_style_args(data)
|
474 |
|
inference2.py
CHANGED
@@ -7,18 +7,17 @@ from internals.data.task import ModelType, Task, TaskType
|
|
7 |
from internals.pipelines.inpainter import InPainter
|
8 |
from internals.pipelines.object_remove import ObjectRemoval
|
9 |
from internals.pipelines.prompt_modifier import PromptModifier
|
10 |
-
from internals.pipelines.remove_background import RemoveBackground,
|
|
|
11 |
from internals.pipelines.replace_background import ReplaceBackground
|
12 |
from internals.pipelines.safety_checker import SafetyChecker
|
13 |
from internals.pipelines.upscaler import Upscaler
|
14 |
from internals.util.avatar import Avatar
|
15 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
16 |
-
from internals.util.commons import construct_default_s3_url, upload_image,
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
set_root_dir,
|
21 |
-
)
|
22 |
from internals.util.failure_hander import FailureHandler
|
23 |
from internals.util.slack import Slack
|
24 |
|
@@ -173,8 +172,6 @@ def model_fn(model_dir):
|
|
173 |
|
174 |
replace_background.load(upscaler, remove_background_v2)
|
175 |
|
176 |
-
safety_checker.apply(inpainter)
|
177 |
-
|
178 |
print("Logs: model loaded ....")
|
179 |
return
|
180 |
|
@@ -186,13 +183,13 @@ def predict_fn(data, pipe):
|
|
186 |
|
187 |
FailureHandler.handle(task)
|
188 |
|
189 |
-
# Set set_environment
|
190 |
-
set_configs_from_task(task)
|
191 |
-
|
192 |
try:
|
193 |
# Set set_environment
|
194 |
set_configs_from_task(task)
|
195 |
|
|
|
|
|
|
|
196 |
# Fetch avatars
|
197 |
avatar.fetch_from_network(task.get_model_id())
|
198 |
|
|
|
7 |
from internals.pipelines.inpainter import InPainter
|
8 |
from internals.pipelines.object_remove import ObjectRemoval
|
9 |
from internals.pipelines.prompt_modifier import PromptModifier
|
10 |
+
from internals.pipelines.remove_background import (RemoveBackground,
|
11 |
+
RemoveBackgroundV2)
|
12 |
from internals.pipelines.replace_background import ReplaceBackground
|
13 |
from internals.pipelines.safety_checker import SafetyChecker
|
14 |
from internals.pipelines.upscaler import Upscaler
|
15 |
from internals.util.avatar import Avatar
|
16 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
17 |
+
from internals.util.commons import (construct_default_s3_url, upload_image,
|
18 |
+
upload_images)
|
19 |
+
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
20 |
+
set_root_dir)
|
|
|
|
|
21 |
from internals.util.failure_hander import FailureHandler
|
22 |
from internals.util.slack import Slack
|
23 |
|
|
|
172 |
|
173 |
replace_background.load(upscaler, remove_background_v2)
|
174 |
|
|
|
|
|
175 |
print("Logs: model loaded ....")
|
176 |
return
|
177 |
|
|
|
183 |
|
184 |
FailureHandler.handle(task)
|
185 |
|
|
|
|
|
|
|
186 |
try:
|
187 |
# Set set_environment
|
188 |
set_configs_from_task(task)
|
189 |
|
190 |
+
# Apply safety checker based on environment
|
191 |
+
safety_checker.apply(inpainter)
|
192 |
+
|
193 |
# Fetch avatars
|
194 |
avatar.fetch_from_network(task.get_model_id())
|
195 |
|
internals/pipelines/safety_checker.py
CHANGED
@@ -4,6 +4,7 @@ import cv2
|
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torch.nn as nn
|
|
|
7 |
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
|
8 |
|
9 |
from internals.pipelines.commons import AbstractPipeline
|
@@ -23,10 +24,17 @@ class SafetyChecker:
|
|
23 |
).to("cuda")
|
24 |
|
25 |
def apply(self, pipeline: AbstractPipeline):
|
|
|
26 |
if hasattr(pipeline, "pipe"):
|
27 |
-
pipeline.pipe.safety_checker =
|
28 |
if hasattr(pipeline, "pipe2"):
|
29 |
-
pipeline.pipe2.safety_checker =
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
@@ -102,7 +110,7 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
102 |
result_img["concept_scores"][concept_idx] = round(
|
103 |
concept_cos - concept_threshold + adjustment, 3
|
104 |
)
|
105 |
-
if result_img["concept_scores"][concept_idx] >
|
106 |
result_img["bad_concepts"].append(concept_idx)
|
107 |
|
108 |
result.append(result_img)
|
@@ -115,11 +123,12 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
115 |
if any(has_nsfw_concepts) and not get_nsfw_access():
|
116 |
if torch.is_tensor(images) or torch.is_tensor(images[0]):
|
117 |
image = images[idx].cpu().numpy().astype(np.float32)
|
118 |
-
image =
|
|
|
119 |
image = torch.from_numpy(image)
|
120 |
images[idx] = image
|
121 |
else:
|
122 |
-
images[idx] =
|
123 |
|
124 |
if any(has_nsfw_concepts):
|
125 |
print("NSFW")
|
@@ -150,13 +159,14 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
150 |
|
151 |
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
|
152 |
# concept_scores = concept_scores.round(decimals=3)
|
153 |
-
has_nsfw_concepts = torch.any(concept_scores >
|
154 |
|
|
|
155 |
# Blur images based on NSFW score
|
156 |
# -------------------------------
|
157 |
if not get_nsfw_access():
|
158 |
image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
|
159 |
-
image =
|
160 |
image = torch.from_numpy(image)
|
161 |
images[has_nsfw_concepts] = image
|
162 |
|
|
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
+
from scipy.ndimage.filters import gaussian_filter
|
8 |
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
|
9 |
|
10 |
from internals.pipelines.commons import AbstractPipeline
|
|
|
24 |
).to("cuda")
|
25 |
|
26 |
def apply(self, pipeline: AbstractPipeline):
|
27 |
+
model = self.model if not get_nsfw_access() else None
|
28 |
if hasattr(pipeline, "pipe"):
|
29 |
+
pipeline.pipe.safety_checker = model
|
30 |
if hasattr(pipeline, "pipe2"):
|
31 |
+
pipeline.pipe2.safety_checker = model
|
32 |
+
|
33 |
+
|
34 |
+
def cosine_distance(image_embeds, text_embeds):
|
35 |
+
normalized_image_embeds = nn.functional.normalize(image_embeds)
|
36 |
+
normalized_text_embeds = nn.functional.normalize(text_embeds)
|
37 |
+
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
|
38 |
|
39 |
|
40 |
class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
|
110 |
result_img["concept_scores"][concept_idx] = round(
|
111 |
concept_cos - concept_threshold + adjustment, 3
|
112 |
)
|
113 |
+
if result_img["concept_scores"][concept_idx] > 0:
|
114 |
result_img["bad_concepts"].append(concept_idx)
|
115 |
|
116 |
result.append(result_img)
|
|
|
123 |
if any(has_nsfw_concepts) and not get_nsfw_access():
|
124 |
if torch.is_tensor(images) or torch.is_tensor(images[0]):
|
125 |
image = images[idx].cpu().numpy().astype(np.float32)
|
126 |
+
image = gaussian_filter(image, sigma=7)
|
127 |
+
# image = cv2.blur(image, (30, 30))
|
128 |
image = torch.from_numpy(image)
|
129 |
images[idx] = image
|
130 |
else:
|
131 |
+
images[idx] = gaussian_filter(images[idx], sigma=7)
|
132 |
|
133 |
if any(has_nsfw_concepts):
|
134 |
print("NSFW")
|
|
|
159 |
|
160 |
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
|
161 |
# concept_scores = concept_scores.round(decimals=3)
|
162 |
+
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
|
163 |
|
164 |
+
# images[has_nsfw_concepts] = 0.0 # black image
|
165 |
# Blur images based on NSFW score
|
166 |
# -------------------------------
|
167 |
if not get_nsfw_access():
|
168 |
image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
|
169 |
+
image = gaussian_filter(image, sigma=7)
|
170 |
image = torch.from_numpy(image)
|
171 |
images[has_nsfw_concepts] = image
|
172 |
|