Spaces:
Running
Running
# Labels for the rows of the safety checker's `concept_embeds`. A score at
# column i of the concept similarity matrix is reported under concepts[i], so
# the order of this list must not change.
concepts = [
    "sexual",
    "nude",
    "sex",
    "18+",
    "naked",
    "nsfw",
    "porn",
    "dick",
    "vagina",
    "naked person (approximation)",
    "explicit content",
    "uncensored",
    "fuck",
    "nipples",
    "nipples (approximation)",
    "naked breasts",
    "areola",
]
# Labels for the rows of the safety checker's `special_care_embeds`, in order.
special_concepts = [
    "small girl (approximation)",
    "young child",
    "young girl",
]
import dbimutils | |
import torch | |
def init_nsfw_pipe(min_confidence=0.4):
    """Load Stable Diffusion v1.4 and patch its safety checker to report tags.

    The pipeline's ``safety_checker.forward`` is replaced with a version that
    reports, for each image, which ``special_concepts`` and ``concepts`` it
    matched and by how much. The stock version returns only a per-image flag.

    Args:
        min_confidence: A match is reported only when its confidence is
            strictly greater than this value. Confidence is the cosine
            similarity minus the checker's per-concept threshold, plus the
            special-care adjustment, rounded to 3 decimals. Defaults to 0.4,
            the cut-off previously hard-coded here.

    Returns:
        The ``StableDiffusionPipeline``, moved to CUDA when it is available.
        Its ``safety_checker.forward(clip_input, images)`` returns
        ``(images, result)``. ``images`` is passed through unchanged.
        ``result`` holds one ``{"special_tags": [...], "bad_tags": [...]}``
        dict per image, and each tag is ``{"tag": str, "confidence": float}``.
    """
    import types

    from diffusers import StableDiffusionPipeline
    from torch import nn

    use_cuda = torch.cuda.is_available()
    # Half-precision kernels are largely unimplemented on CPU, so the fp16
    # weights are only kept in fp16 when the model will run on the GPU.
    # make sure you're logged in with `huggingface-cli login`
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="fp16",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    if use_cuda:
        pipe = pipe.to('cuda')

    def cosine_distance(image_embeds, text_embeds):
        """Return the cosine similarity between every row pair of the two inputs."""
        normalized_image_embeds = nn.functional.normalize(image_embeds)
        normalized_text_embeds = nn.functional.normalize(text_embeds)
        return torch.mm(normalized_image_embeds, normalized_text_embeds.t())

    # Inference only. Without no_grad, the tensors below would require grad
    # and the .numpy() calls would raise.
    @torch.no_grad()
    def forward_ours(self, clip_input, images):
        """Score each image against the checker's special-care and concept embeddings."""
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
        # The per-concept thresholds are identical for every image, so read them once.
        special_thresholds = self.special_care_embeds_weights.tolist()
        concept_thresholds = self.concept_embeds_weights.tolist()

        result = []
        # One row of each similarity matrix per image in the batch.
        for special_row, concept_row in zip(special_cos_dist, cos_dist):
            special_care = []
            bad_concepts = []
            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0
            for idx, threshold in enumerate(special_thresholds):
                # float() yields plain Python floats regardless of the model dtype.
                score = round(float(special_row[idx]) - threshold + adjustment, 3)
                if score > 0:
                    special_care.append({"tag": special_concepts[idx], "confidence": score})
                    # A special-care match lowers the bar for every later concept check.
                    adjustment = 0.01
                    print("Special concept matched:", special_concepts[idx])
            for idx, threshold in enumerate(concept_thresholds):
                score = round(float(concept_row[idx]) - threshold + adjustment, 3)
                if score > 0:
                    bad_concepts.append({"tag": concepts[idx], "confidence": score})
                    print("NSFW concept found:", concepts[idx])
            result.append({
                "special_tags": [t for t in special_care if t["confidence"] > min_confidence],
                "bad_tags": [t for t in bad_concepts if t["confidence"] > min_confidence],
            })
        return images, result

    # MethodType binds the checker as `self`, so both keyword and positional
    # calls work. The previous partial(self=...) rejected positional calls.
    pipe.safety_checker.forward = types.MethodType(forward_ours, pipe.safety_checker)
    return pipe
def check_nsfw(img, pipe):
    """Run ``pipe``'s (patched) safety checker on a single image.

    Args:
        img: An image accepted by ``pipe.feature_extractor``, or a string.
            A string is first passed to ``dbimutils.read_img_from_url``
            (presumably a URL).
        pipe: A pipeline whose ``safety_checker.forward(clip_input=..., images=...)``
            returns a 2-tuple, such as the one built by ``init_nsfw_pipe``.

    Returns:
        The second element returned by ``pipe.safety_checker.forward``.
    """
    import contextlib

    if isinstance(img, str):
        img = dbimutils.read_img_from_url(img)
    safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        safety_checker_input = safety_checker_input.to('cuda')
    # Match the checker's weight dtype so an fp16 checker also works where
    # autocast is not active (e.g. on CPU).
    clip_input = safety_checker_input.pixel_values.to(pipe.safety_checker.dtype)
    # torch.cuda.amp.autocast is deprecated. CUDA autocast has no effect without
    # a GPU, so a null context is used there instead.
    amp = torch.autocast(device_type="cuda") if use_cuda else contextlib.nullcontext()
    # Inference only: don't build an autograd graph. The checker converts
    # its results to numpy, which fails on tensors that require grad.
    with torch.no_grad(), amp:
        _, nsfw_tags = pipe.safety_checker.forward(clip_input=clip_input, images=img)
    return nsfw_tags