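# Stable Fashion: a Streamlit demo that "tries on" clothes described by a text
# prompt. A U2NET cloth-segmentation model masks the chosen body region, and a
# Stable Diffusion inpainting pipeline repaints that region to match the prompt.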
import streamlit as st
from diffusers import StableDiffusionInpaintPipeline
import os

from PIL import Image
import numpy as np
import warnings
from huggingface_hub import hf_hub_download

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET
from enum import IntEnum
from typing import Optional
from rembg import remove
from dataclasses import dataclass

@dataclass
class StableFashionCLIArgs:
    # Mirrors the CLI arguments of the original stable-fashion script.
    image: Optional[object] = None  # file path or uploaded file accepted by PIL
    part: Optional[str] = None  # "Upper" or "Lower"
    resolution: Optional[int] = None
    prompt: Optional[str] = None
    num_steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    rembg: Optional[bool] = None

class Parts(IntEnum):
    # Label codes emitted by the U2NET cloth-segmentation model.
    UPPER = 1
    LOWER = 2

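# Cache the segmentation model across Streamlit reruns so it loads only once.
# (st.cache(allow_output_mutation=True) was the caching API when this was
# written; newer Streamlit releases provide st.cache_resource instead.)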
@st.cache(allow_output_mutation=True)
def load_u2net():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = hf_hub_download(repo_id="maiti/cloth-segmentation", filename="cloth_segm_u2net_latest.pth")
    net = U2NET(in_ch=3, out_ch=4)
    net = load_checkpoint_mgpu(net, checkpoint_path)
    net = net.to(device)
    net = net.eval()
    return net

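# Flatten an RGBA cutout onto a solid background colour and return RGB, since
# both the segmentation net and the inpainting pipeline expect RGB input.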
def change_bg_color(rgba_image, color):
    new_image = Image.new("RGBA", rgba_image.size, color)
    new_image.paste(rgba_image, (0, 0), rgba_image)
    return new_image.convert("RGB")

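# Downloading the inpainting weights expects a Hugging Face token in the
# hf_auth_token environment variable.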
@st.cache(allow_output_mutation=True)
def load_inpainting_pipeline():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision="fp16",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            use_auth_token=os.environ["hf_auth_token"]
        ).to(device)
    return inpainting_pipeline


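# End-to-end try-on: (1) optionally replace the background, (2) segment the
# chosen garment region into a binary mask, (3) inpaint that region with the
# prompt, (4) clean up the background of the result.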
def process_image(args, inpainting_pipeline, net):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_path = args.image
    transform_rgb = transforms.Compose([transforms.ToTensor(), Normalize_image(0.5, 0.5)])
    img = Image.open(image_path)
    img = img.convert("RGB")
    img = img.resize((args.resolution, args.resolution))
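    # Optionally strip the original background; a flat green backdrop tends to
    # help the segmentation model isolate the clothing.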
    if args.rembg:
        img_with_green_bg = remove(img)
        img_with_green_bg = change_bg_color(img_with_green_bg, color="GREEN")
        img_with_green_bg = img_with_green_bg.convert("RGB")
    else:
        img_with_green_bg = img
    image_tensor = transform_rgb(img_with_green_bg)
    image_tensor = image_tensor.unsqueeze(0)
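    # Segment the image under autocast, then take the per-pixel argmax over the
    # class channels to get a label map.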
    with torch.autocast(device_type=device):
        output_tensor = net(image_tensor.to(device))
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()
    # Keep only the pixels labelled with the requested part and binarise them
    # into a 0/255 inpainting mask.
    mask_code = getattr(Parts, args.part.upper())
    mask = output_arr == mask_code
    output_arr[mask] = 1
    output_arr[~mask] = 0
    output_arr *= 255
    mask_PIL = Image.fromarray(output_arr.astype("uint8"), mode="L")
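    # Repaint only the masked region, guided by the text prompt.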
    clothed_image_from_pipeline = inpainting_pipeline(prompt=args.prompt,
                                                    image=img_with_green_bg,
                                                    mask_image=mask_PIL,
                                                    width=args.resolution,
                                                    height=args.resolution,
                                                    guidance_scale=args.guidance_scale,
                                                    num_inference_steps=args.num_steps).images[0]
    clothed_image_from_pipeline = remove(clothed_image_from_pipeline)
    clothed_image_from_pipeline = change_bg_color(clothed_image_from_pipeline, "WHITE")
    return clothed_image_from_pipeline.convert("RGB"), mask_PIL

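# Load both models once at startup; Streamlit's cache makes reruns cheap.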
net = load_u2net()
inpainting_pipeline = load_inpainting_pipeline()
st.markdown(
    """
    <p style='text-align: center'>
    <a href='https://github.com/ovshake' target='_blank'>ovshake Github</a> | <a href='https://github.com/ovshake/stable-fashion' target='_blank'>Stable Fashion Github</a> | <a href='https://huggingface.co/spaces/maiti/stable-fashion' target='_blank'>Stable Fashion Demo</a>
    <br />
    Follow me for more! <a href='https://twitter.com/o_v_shake' target='_blank'> <img src="https://img.icons8.com/color/48/000000/twitter--v1.png" height="30"></a><a href='https://github.com/ovshake' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/github.png" height="27"></a><a href='https://www.linkedin.com/in/ovshake/' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>
    </p>
    """,
    unsafe_allow_html=True,
)
st.title("Stable Fashion Huggingface Spaces")
file_name = st.file_uploader("Upload a clear, full-length picture of yourself, preferably against an uncluttered background")
stable_fashion_args = StableFashionCLIArgs()
stable_fashion_args.image = file_name
body_part = st.radio("Would you like to try clothes on your upper body (shirts, kurtas, etc.) or your lower body (jeans, pants, etc.)?", ('Upper', 'Lower'))
stable_fashion_args.part = body_part
resolution = st.radio("Which resolution would you like the resulting picture in? (Keep in mind: the higher the resolution, the longer the queue time)", (128, 256, 512), index=2)
stable_fashion_args.resolution = resolution
rembg_status = st.radio("Would you like to remove the background of your image before putting new clothes on you? (This sometimes gives better results)", ("Yes", "No"), index=0)
stable_fashion_args.rembg = (rembg_status == "Yes")
guidance_scale = st.slider("Select a guidance scale. 7.5 usually gives the best results.", 1.0, 15.0, value=7.5)
stable_fashion_args.guidance_scale = guidance_scale
prompt = st.text_input('Describe the clothing you want to try on', 'a bright yellow t shirt')
stable_fashion_args.prompt = prompt

num_steps = st.slider("Number of inference steps for the diffusion process", 5, 50, value=25)
stable_fashion_args.num_steps = num_steps


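# Run the pipeline once a picture has been uploaded; otherwise show a sample result.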
if file_name is not None:
    result_image, mask_PIL = process_image(stable_fashion_args, inpainting_pipeline, net)
    st.image(result_image, caption='Result')
    st.image(mask_PIL, caption='Mask')
else:
    stock_image = Image.open('assets/abhishek_yellow.jpg')
    st.image(stock_image, caption='Result')