# Have SwinIR upsample
# Have BLIP auto caption
# Have CLIPSeg auto mask concept

import glob
import os
from typing import List, Literal, Optional, Tuple, Union

import fire
import numpy as np
import torch
from PIL import Image, ImageFilter
from tqdm import tqdm
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

def swin_ir_sr(
    images: List[Image.Image],
    model_id: Literal[
        "caidas/swin2SR-classical-sr-x2-64", "caidas/swin2SR-classical-sr-x4-48"
    ] = "caidas/swin2SR-classical-sr-x2-64",
    target_size: Optional[Tuple[int, int]] = None,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    **kwargs,
) -> List[Image.Image]:
    """
    Upscales images using SwinIR. Returns a list of PIL images.
    """
    # Swin2SR support only recently landed in transformers, so import it
    # lazily here to keep the rest of the pipeline usable without it.
    from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor

    model = Swin2SRForImageSuperResolution.from_pretrained(model_id).to(device)
    processor = Swin2SRImageProcessor()

    out_images = []

    for image in tqdm(images):
        ori_w, ori_h = image.size
        if target_size is not None:
            # Skip upscaling if the image already meets the target size.
            if ori_w >= target_size[0] and ori_h >= target_size[1]:
                out_images.append(image)
                continue

        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        output = (
            outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        )
        output = np.moveaxis(output, source=0, destination=-1)  # CHW -> HWC
        output = (output * 255.0).round().astype(np.uint8)
        out_images.append(Image.fromarray(output))

    return out_images
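
# Example usage (a minimal sketch; paths are hypothetical):
#
#     imgs = [Image.open(p).convert("RGB") for p in ["a.jpg", "b.jpg"]]
#     upscaled = swin_ir_sr(imgs, target_size=(512, 512))
#     upscaled[0].save("a.upscaled.png")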

def clipseg_mask_generator(
    images: List[Image.Image],
    target_prompts: Union[List[str], str],
    model_id: Literal[
        "CIDAS/clipseg-rd64-refined", "CIDAS/clipseg-rd16"
    ] = "CIDAS/clipseg-rd64-refined",
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    bias: float = 0.01,
    temp: float = 1.0,
    **kwargs,
) -> List[Image.Image]:
    """
    Returns a greyscale mask for each image, where the mask encodes the
    probability of the target prompt being present at each pixel.
    """
    if isinstance(target_prompts, str):
        print(
            f'Warning: only one target prompt "{target_prompts}" was given, '
            f"so it will be used for all images"
        )
        target_prompts = [target_prompts] * len(images)

    processor = CLIPSegProcessor.from_pretrained(model_id)
    model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(device)

    masks = []

    for image, prompt in tqdm(zip(images, target_prompts), total=len(images)):
        original_size = image.size

        # Score the image against the target prompt and an empty prompt, so
        # the softmax below contrasts "prompt present" with a neutral baseline.
        inputs = processor(
            text=[prompt, ""],
            images=[image] * 2,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)

        # Run inference without grad tracking; calling .numpy() on a tensor
        # that requires grad would raise a RuntimeError.
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
        probs = (probs + bias).clamp_(0, 1)
        probs = 255 * probs / probs.max()

        # make the mask greyscale
        mask = Image.fromarray(probs.cpu().numpy().astype(np.uint8)).convert("L")

        # resize the mask back to the original image size
        mask = mask.resize(original_size)

        masks.append(mask)

    return masks
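
# Example usage (hypothetical prompt; `bias` lifts the floor of the mask so
# the background is never zeroed out entirely, and `temp` sharpens or softens
# the softmax contrast between the prompt and the empty baseline):
#
#     masks = clipseg_mask_generator(imgs, target_prompts="a photo of a dog")
#     masks[0].save("dog.mask.png")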

def blip_captioning_dataset(
    images: List[Image.Image],
    text: Optional[str] = None,
    model_id: Literal[
        "Salesforce/blip-image-captioning-large",
        "Salesforce/blip-image-captioning-base",
    ] = "Salesforce/blip-image-captioning-large",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    **kwargs,
) -> List[str]:
    """
    Returns a list of captions for the given images.
    """
    from transformers import BlipForConditionalGeneration, BlipProcessor

    processor = BlipProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
    captions = []

    for image in tqdm(images):
        # Use the configured device rather than hard-coding "cuda", so the
        # function also runs on CPU-only machines.
        inputs = processor(image, text=text, return_tensors="pt").to(device)
        out = model.generate(
            **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
        )
        caption = processor.decode(out[0], skip_special_tokens=True)
        captions.append(caption)

    return captions
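
# Example usage (hypothetical prefix; when `text` is given, BLIP treats it as
# the start of the caption and completes it):
#
#     captions = blip_captioning_dataset(imgs, text="a photo of")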

def face_mask_google_mediapipe(
    images: List[Image.Image], blur_amount: float = 80.0, bias: float = 0.05
) -> List[Image.Image]:
    """
    Returns a list of greyscale masks that are white over detected faces and
    black elsewhere. Note: blur_amount and bias are currently unused; they are
    kept in the signature for API compatibility.
    """
    import mediapipe as mp

    mp_face_detection = mp.solutions.face_detection

    face_detection = mp_face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    )

    masks = []
    for image in tqdm(images):
        image = np.array(image)

        results = face_detection.process(image)

        # Start from an all-black mask and paint each detected face white.
        black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        if results.detections:
            for detection in results.detections:
                # The bounding box is given relative to the image size, so
                # scale it back to pixel coordinates.
                x_min = int(
                    detection.location_data.relative_bounding_box.xmin * image.shape[1]
                )
                y_min = int(
                    detection.location_data.relative_bounding_box.ymin * image.shape[0]
                )
                width = int(
                    detection.location_data.relative_bounding_box.width * image.shape[1]
                )
                height = int(
                    detection.location_data.relative_bounding_box.height
                    * image.shape[0]
                )

                # paint a white rectangle over the detected face
                black_image[y_min : y_min + height, x_min : x_min + width] = 255

        black_image = Image.fromarray(black_image)
        masks.append(black_image)

    return masks
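
# Example usage (a sketch; detection thresholds are the mediapipe defaults
# set above):
#
#     face_masks = face_mask_google_mediapipe(imgs)
#     face_masks[0].save("face.mask.png")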

def _crop_to_square(
    image: Image.Image, com: Tuple[float, float], resize_to: Optional[int] = None
):
    cx, cy = com
    width, height = image.size

    if width > height:
        # Landscape: slide a height-sized window horizontally, centered on the
        # center of mass but clamped so the crop stays inside the image.
        left_possible = max(cx - height / 2, 0)
        left = min(left_possible, width - height)
        right = left + height
        top = 0
        bottom = height
    else:
        # Portrait (or square): slide a width-sized window vertically.
        left = 0
        right = width
        top_possible = max(cy - width / 2, 0)
        top = min(top_possible, height - width)
        bottom = top + width

    image = image.crop((left, top, right, bottom))

    if resize_to:
        image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)

    return image
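
# Worked example of the clamping above: for a 300x200 (width x height) image
# with center of mass at cx = 280, left_possible = 280 - 100 = 180, which is
# clamped to width - height = 100, so the crop box is (100, 0, 300, 200) and
# stays fully inside the image.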

def _center_of_mass(mask: Image.Image):
    """
    Returns the (x, y) center of mass of the mask, i.e. the average pixel
    position weighted by mask intensity.
    """
    x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))
    mask_np = np.array(mask)

    x_ = x * mask_np
    y_ = y * mask_np

    x = np.sum(x_) / np.sum(mask_np)
    y = np.sum(y_) / np.sum(mask_np)

    return x, y
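
# Worked example: for a 100x100 mask that is 255 on the right half (x >= 50)
# and 0 on the left half, the center of mass comes out to (74.5, 49.5),
# the centroid of the white region.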

def load_and_save_masks_and_captions(
    files: Union[str, List[str]],
    output_dir: str,
    caption_text: Optional[str] = None,
    target_prompts: Optional[Union[List[str], str]] = None,
    target_size: int = 512,
    crop_based_on_salience: bool = True,
    use_face_detection_instead: bool = False,
    temp: float = 1.0,
    n_length: int = -1,
):
    """
    Loads images from the given files, generates captions and masks for them,
    and saves the captions, masks, and upscaled images to output_dir.
    """
    os.makedirs(output_dir, exist_ok=True)

    # load images
    if isinstance(files, str):
        # check if it is a directory
        if os.path.isdir(files):
            image_dir = files
            # get all the .png and .jpg files in the directory
            files = glob.glob(os.path.join(image_dir, "*.png")) + glob.glob(
                os.path.join(image_dir, "*.jpg")
            )
            if len(files) == 0:
                raise Exception(
                    f"No files found in {image_dir}. Either {image_dir} is not "
                    f"a directory or it does not contain any .png or .jpg files."
                )
        else:
            # a single file path was given
            files = [files]

    if n_length == -1:
        n_length = len(files)
    files = sorted(files)[:n_length]

    # Convert to RGB up front: BLIP/CLIPSeg expect 3-channel input, and saving
    # an RGBA image as JPEG below would fail.
    images = [Image.open(file).convert("RGB") for file in files]

    # captions
    print(f"Generating {len(images)} captions...")
    captions = blip_captioning_dataset(images, text=caption_text)

    if target_prompts is None:
        target_prompts = captions

    print(f"Generating {len(images)} masks...")
    if not use_face_detection_instead:
        seg_masks = clipseg_mask_generator(
            images=images, target_prompts=target_prompts, temp=temp
        )
    else:
        seg_masks = face_mask_google_mediapipe(images=images)

    # find the center of mass of each mask
    if crop_based_on_salience:
        coms = [_center_of_mass(mask) for mask in seg_masks]
    else:
        coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]

    # based on the center of mass, crop each image to a square
    images = [
        _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
    ]

    print(f"Upscaling {len(images)} images...")
    # upscale images that are below the target size, then resize everything
    images = swin_ir_sr(images, target_size=(target_size, target_size))
    images = [
        image.resize((target_size, target_size), Image.Resampling.LANCZOS)
        for image in images
    ]

    seg_masks = [
        _crop_to_square(mask, com, resize_to=target_size)
        for mask, com in zip(seg_masks, coms)
    ]

    with open(os.path.join(output_dir, "caption.txt"), "w") as f:
        # save images and masks
        for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
            image.save(os.path.join(output_dir, f"{idx}.src.jpg"), quality=99)
            mask.save(os.path.join(output_dir, f"{idx}.mask.png"))

            f.write(caption + "\n")
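
# Example CLI usage via fire (hypothetical paths; assumes this file is saved
# as preprocess.py):
#
#     python preprocess.py --files ./raw_images --output_dir ./processed \
#         --caption_text "a photo of a person" --target_size 512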

def main():
    fire.Fire(load_and_save_masks_and_captions)
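

# Standard guard so the module can be run directly as a script (the original
# file may instead be wired up as a console entry point).
if __name__ == "__main__":
    main()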