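"""Stable Fashion: redress the person in a photo according to a text prompt.

A cloth-segmentation U2-Net first masks the chosen garment region (upper or
lower body); a Stable Diffusion inpainting pipeline then redraws the masked
region to match the prompt.
"""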
import argparse
import os
import warnings
from enum import IntEnum

# Silence noisy FutureWarning/DeprecationWarning chatter from the heavy
# framework imports below.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from rembg import remove

from data.base_dataset import Normalize_image
from networks import U2NET
from utils.saving_utils import load_checkpoint_mgpu

class Parts(IntEnum):
    """Class indices emitted by the cloth-segmentation U2-Net
    (0 is background; 3, when present, is full-body wear)."""
    UPPER = 1
    LOWER = 2

def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Stable Fashion API, allows you to picture yourself "
                    "in any cloth your imagination can think of!"
    )
    parser.add_argument('--image', type=str, required=True, help='path to the input image')
    parser.add_argument('--part', choices=['upper', 'lower'], default='upper', type=str,
                        help='which garment region to replace')
    parser.add_argument('--resolution', choices=[256, 512, 1024, 2048], default=256, type=int,
                        help='square resolution the image is resized to')
    parser.add_argument('--prompt', type=str, default="A pink cloth",
                        help='text prompt describing the desired clothing')
    parser.add_argument('--num_steps', type=int, default=5,
                        help='number of diffusion denoising steps')
    parser.add_argument('--guidance_scale', type=float, default=7.5,
                        help='classifier-free guidance scale')
    parser.add_argument('--rembg', action='store_true',
                        help='strip the background before segmentation')
    parser.add_argument('--output', default='output.jpg', type=str,
                        help='path to save the result image')
    args, _ = parser.parse_known_args()
    return args
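
# Example invocation (file names here are illustrative):
#   python main.py --image person.jpg --part upper --rembg \
#       --prompt "a red leather jacket" --num_steps 50 --output result.jpg
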
def load_u2net():
    """Load the pretrained cloth-segmentation U2-Net (3-channel RGB in,
    4 class maps out) from the local checkpoint."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
    net = U2NET(in_ch=3, out_ch=4)
    net = load_checkpoint_mgpu(net, checkpoint_path)
    net = net.to(device)
    net = net.eval()
    return net

def change_bg_color(rgba_image, color):
    """Composite an RGBA image onto a solid background color and return RGB."""
    new_image = Image.new("RGBA", rgba_image.size, color)
    new_image.paste(rgba_image, (0, 0), rgba_image)
    return new_image.convert("RGB")

def load_inpainting_pipeline():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(device)
    return inpainting_pipeline
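
# Note: recent diffusers releases deprecate selecting half-precision weights
# via `revision="fp16"` in favor of `variant="fp16"`. If loading fails on a
# newer version, a likely (untested here) fix is:
#   StableDiffusionInpaintPipeline.from_pretrained(
#       "runwayml/stable-diffusion-inpainting",
#       variant="fp16", torch_dtype=torch.float16)
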
def process_image(args, inpainting_pipeline, net):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Normalize to [-1, 1], matching the segmentation model's preprocessing.
    transform_rgb = transforms.Compose([
        transforms.ToTensor(),
        Normalize_image(0.5, 0.5),
    ])

    img = Image.open(args.image).convert("RGB")
    img = img.resize((args.resolution, args.resolution))

    if args.rembg:
        # Swap the original background for solid green so the segmentation
        # model sees a clean subject.
        img_with_green_bg = remove(img)
        img_with_green_bg = change_bg_color(img_with_green_bg, color="GREEN")
    else:
        img_with_green_bg = img

    # Segment the garments; the first U2-Net output is the fused prediction.
    image_tensor = transform_rgb(img_with_green_bg).unsqueeze(0)
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_arr = output_tensor.squeeze(0).squeeze(0).cpu().numpy()

    # Binary inpainting mask: white wherever the selected part was detected.
    mask_code = getattr(Parts, args.part.upper())
    mask_arr = (output_arr == mask_code).astype("uint8") * 255
    mask_PIL = Image.fromarray(mask_arr, mode="L")

    clothed_image_from_pipeline = inpainting_pipeline(
        prompt=args.prompt,
        image=img_with_green_bg,
        mask_image=mask_PIL,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]

    # Drop whatever background the diffusion model painted and place the
    # subject on white.
    clothed_image_from_pipeline = remove(clothed_image_from_pipeline)
    return change_bg_color(clothed_image_from_pipeline, "WHITE")

if __name__ == '__main__':
    args = parse_arguments()
    net = load_u2net()
    inpainting_pipeline = load_inpainting_pipeline()
    result_image = process_image(args, inpainting_pipeline, net)
    result_image.save(args.output)