# BasicNp's picture
# Upload 1672 files
# e8aa256 verified
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel
from ...utils import logging
# Module-level logger (from the package's own `utils.logging` helper); IFSafetyChecker.forward
# uses it to warn whenever flagged images are replaced with black images.
logger = logging.get_logger(__name__)
class IFSafetyChecker(PreTrainedModel):
    """CLIP-based filter that blanks out images flagged as NSFW or watermarked.

    A CLIP vision tower (with projection) embeds each image. Two single-output
    linear heads then score every embedding: ``p_head`` for NSFW content and
    ``w_head`` for watermarks. Any image whose score exceeds the matching
    threshold is replaced, in place, by an all-zeros (black) image.
    """

    config_class = CLIPConfig

    # Modules that transformers' device-placement utilities should keep whole
    # when the model is split across devices.
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        """Build the vision tower and the two scoring heads from ``config.vision_config``."""
        super().__init__(config)

        self.vision_model = CLIPVisionModelWithProjection(config.vision_config)

        # Each head maps one projected image embedding (size ``projection_dim``)
        # to a single scalar score.
        self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
        self.w_head = nn.Linear(config.vision_config.projection_dim, 1)

    @staticmethod
    def _threshold_head(head, image_embeds, threshold):
        """Score ``image_embeds`` with ``head``; return a list with one bool per image (score > threshold)."""
        # The raw linear output is compared against the threshold directly; no
        # sigmoid is applied, so thresholds are in the head's output units.
        scores = head(image_embeds).flatten()
        return (scores > threshold).tolist()

    @staticmethod
    def _blank_flagged_images(images, flags):
        """Replace, in place, every ``images[idx]`` whose flag is true with zeros of the same shape."""
        for idx, flagged in enumerate(flags):
            if not flagged:
                continue
            image = images[idx]
            # Preserve the dtype of NumPy images, so that e.g. a list of uint8
            # images does not silently gain float64 entries. Other array-likes
            # keep NumPy's default float dtype, as before.
            dtype = image.dtype if isinstance(image, np.ndarray) else None
            images[idx] = np.zeros(image.shape, dtype=dtype)

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        """Detect NSFW / watermarked images and black them out.

        Args:
            clip_input: Pixel values fed to the CLIP vision model. Presumably the
                output of a CLIP image processor for the same images — confirm
                against the caller.
            images: Indexable, item-assignable batch of images aligned with
                ``clip_input``. Flagged entries are overwritten in place with zeros.
            p_threshold: Score threshold for the NSFW head.
            w_threshold: Score threshold for the watermark head.

        Returns:
            A tuple ``(images, nsfw_detected, watermark_detected)``:
            ``images`` is the same object that was passed in (possibly mutated),
            and the other two are lists of bools with one entry per image.
        """
        # First output of the projection model: the projected image embeddings
        # that both heads consume.
        image_embeds = self.vision_model(clip_input)[0]

        nsfw_detected = self._threshold_head(self.p_head, image_embeds, p_threshold)
        if any(nsfw_detected):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )
            self._blank_flagged_images(images, nsfw_detected)

        # Watermark scores use the embeddings computed before any blanking, so
        # blanking NSFW images first does not affect this check.
        watermark_detected = self._threshold_head(self.w_head, image_embeds, w_threshold)
        if any(watermark_detected):
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )
            self._blank_flagged_images(images, watermark_detected)

        return images, nsfw_detected, watermark_detected