# NOTE: the scraped page header read "Spaces: Runtime error / Runtime error"; it is not part of the program.
# adapted from https://huggingface.co/spaces/HumanAIGC/OutfitAnyone/blob/main/app.py
import os | |
from os.path import join as opj | |
# --- Allocator / OpenMP environment and dependency bootstrap ----------------
# Exported before the heavy imports that follow. The dynamic loader reads
# LD_PRELOAD when a process starts, so the value written here reaches child
# processes (such as the pip call below) rather than this interpreter.
# NOTE(review): confirm the Space also exports it at container level if the
# preload is meant to apply to this process.
_PRELOAD_LIBS = (
    "/usr/lib/x86_64-linux-gnu/libjemalloc.so",
    "/usr/lib/x86_64-linux-gnu/libiomp5.so",
)

# Start from whatever is already preloaded. An unset variable contributes no
# entry (formatting None into the value used to add a literal "None" path).
_preload_entries = [
    entry for entry in (os.getenv("LD_PRELOAD") or "").split(":") if entry
]
for _lib in _PRELOAD_LIBS:
    # Skip libraries that are already listed; jemalloc used to be added twice.
    if _lib not in _preload_entries:
        _preload_entries.append(_lib)
LD_PRELOAD = ":".join(_preload_entries)
os.environ["LD_PRELOAD"] = LD_PRELOAD

# jemalloc tuning string, kept exactly as it was.
# NOTE(review): there is a space after "dirty_decay_ms:" -- confirm jemalloc
# accepts it, otherwise that option may be ignored.
os.environ["MALLOC_CONF"] = "oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms: 60000,muzzy_decay_ms:60000"

# Must be an assignment of a *string*: the previous `== 32` comparison set
# nothing and raised KeyError whenever the variable was not already defined.
os.environ["OMP_NUM_THREADS"] = "32"

# Install the StableGarment package at start-up. The access token is only
# embedded in the URL when one is configured, so an unset ACCESS_TOKEN no
# longer yields a "None@github.com" URL.
token = os.getenv("ACCESS_TOKEN")
_credentials = f"{token}@" if token else ""
os.system(f"python -m pip install git+https://{_credentials}github.com/logn-2024/StableGarment.git")
import torch | |
import spaces | |
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
from torchvision import transforms | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from diffusers import UniPCMultistepScheduler | |
from diffusers import AutoencoderKL | |
from diffusers import StableDiffusionPipeline | |
from diffusers.loaders import LoraLoaderMixin | |
import intel_extension_for_pytorch as ipex | |
from stablegarment.models import GarmentEncoderModel,ControlNetModel | |
from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline | |
# --- Device, resolution and model loading -----------------------------------
# Pick the device first; the dtype follows it (float16 on CUDA, bfloat16 on CPU).
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.bfloat16

# Output resolution used by every pipeline call in this file.
height, width = 512, 384

base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"

# The base checkpoint id ends in "noVAE", so a VAE is loaded separately and
# handed to the pipeline below.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae = vae.to(dtype=torch_dtype, device=device)

scheduler = UniPCMultistepScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="scheduler",
)

# Garment encoder: loaded from its own repository, then moved to the device.
pretrained_garment_encoder_path = "loooooong/StableGarment_text2img"
garment_encoder = GarmentEncoderModel.from_pretrained(
    pretrained_garment_encoder_path,
    torch_dtype=torch_dtype,
    subfolder="garment_encoder",
)
garment_encoder = garment_encoder.to(device=device, dtype=torch_dtype)

# Text-to-image pipeline built on the base model with the external VAE.
# (variant="fp16" is intentionally not passed.)
pipeline_t2i = StableGarmentPipeline.from_pretrained(
    base_model_path,
    vae=vae,
    torch_dtype=torch_dtype,
)
pipeline_t2i = pipeline_t2i.to(device=device)
# pipeline = StableDiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V4.0_noVAE", vae=vae, torch_dtype=torch_dtype).to(device=device)

# Replace the pipeline's default scheduler with the UniPC one loaded above.
pipeline_t2i.scheduler = scheduler
if False:  # device=="cpu":
    # CPU speed-up path. The constant condition keeps it switched off; the
    # commented condition above is the one it would use if re-enabled.

    # 1) Move the sub-modules to channels-last memory format.
    pipeline_t2i.unet = pipeline_t2i.unet.to(memory_format=torch.channels_last)
    pipeline_t2i.vae = pipeline_t2i.vae.to(memory_format=torch.channels_last)
    pipeline_t2i.text_encoder = pipeline_t2i.text_encoder.to(memory_format=torch.channels_last)
    # pipeline_t2i.safety_checker = pipeline_t2i.safety_checker.to(memory_format=torch.channels_last)

    # 2) Random example inputs for the UNet, in call order, to enable JIT
    #    compilation: sample, timestep, encoder hidden states.
    input_example = (
        torch.randn(2, 4, 64, 64).type(torch_dtype),
        torch.rand(1) * 999,
        torch.randn(2, 77, 768).type(torch_dtype),
    )

    # 3) Optimize each sub-module with IPEX in bfloat16; only the UNet is
    #    given a sample input.
    pipeline_t2i.unet = ipex.optimize(
        pipeline_t2i.unet.eval(),
        dtype=torch.bfloat16,
        inplace=True,
        sample_input=input_example,
    )
    pipeline_t2i.vae = ipex.optimize(
        pipeline_t2i.vae.eval(),
        dtype=torch.bfloat16,
        inplace=True,
    )
    pipeline_t2i.text_encoder = ipex.optimize(
        pipeline_t2i.text_encoder.eval(),
        dtype=torch.bfloat16,
        inplace=True,
    )
    # pipeline_t2i.safety_checker = ipex.optimize(pipeline_t2i.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
# The try-on pipeline is not wired up yet, so its handle stays None. The bare
# string literal below is never evaluated as code; it only records the
# intended construction for when the ControlNet weights are ready.
pipeline_tryon = None
'''
# not ready
pretrained_model_path = "part_module_controlnet_imp2"
controlnet = ControlNetModel.from_pretrained(pretrained_model_path,subfolder="controlnet")
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder='text_encoder')
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder='tokenizer')
pipeline_tryon = StableGarmentControlNetPipeline(
    vae,
    text_encoder,
    tokenizer,
    pipeline_t2i.unet,
    controlnet,
    scheduler,
).to(device=device,dtype=torch_dtype)
'''
def prepare_controlnet_inputs(agn_mask_list, densepose_list):
    """Build the batched conditioning tensor from mask and densepose images.

    Each mask image is converted to single-channel ("L"), thresholded at 128
    and inverted, so pixels below 128 become 1.0 and all others 0.0. Each
    densepose image is converted to an array and divided by 255. The mask is
    concatenated in front of the pose along the last axis, and the result is
    permuted to channel-first.

    Args:
        agn_mask_list: One image per sample, each exposing ``convert("L")``
            (e.g. a PIL image). The list is left untouched; earlier versions
            overwrote its entries with arrays.
        densepose_list: One array-convertible image per sample. It is assumed
            to be H x W x C with the same H and W as the matching mask, which
            the concatenation requires.

    Returns:
        A tensor stacked over the samples, shaped (N, 1 + C, H, W).
    """
    controlnet_inputs = []
    for agn_mask_img, densepose_img in zip(agn_mask_list, densepose_list):
        # H x W -> H x W x 1 so the mask can sit next to the pose channels.
        mask = np.expand_dims(np.array(agn_mask_img.convert("L")), axis=-1)
        # Binarize (0 or 1) and invert.
        inverted_mask = 1. - (mask >= 128).astype(np.float32)
        pose = np.array(densepose_img) / 255.
        merged = np.concatenate([inverted_mask, pose], axis=-1)
        # Channel-last -> channel-first.
        controlnet_inputs.append(torch.tensor(merged).permute(2, 0, 1))
    return torch.stack(controlnet_inputs)
def tryon(prompt, init_image, garment_top, garment_down,):
    """Run the try-on pipeline for one model image and one top garment.

    The model image's file name (without extension) selects three
    pre-computed files in ``parse_dir``: ``<name>_agn.jpg``,
    ``<name>_mask.png`` and ``<name>_densepose.png``. All images are resized
    to ``(width, height)`` before being passed to ``pipeline_tryon``.

    Args:
        prompt: Text prompt for generation.
        init_image: File path of the model image.
        garment_top: File path of the top garment image.
        garment_down: Accepted for interface parity; not used here.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If the try-on pipeline has not been loaded, or if either
            required image path is missing. Previously these cases surfaced
            as an opaque ``TypeError``.
    """
    if pipeline_tryon is None:
        raise gr.Error("Try-on mode is not available: the try-on pipeline is not loaded.")
    if not init_image or not garment_top:
        raise gr.Error("Try-on needs both a model image and a top garment image.")

    basename = os.path.splitext(os.path.basename(init_image))[0]
    size = (width, height)
    image_agn = Image.open(opj(parse_dir, basename + "_agn.jpg")).resize(size)
    image_agn_mask = Image.open(opj(parse_dir, basename + "_mask.png")).resize(size)
    densepose_image = Image.open(opj(parse_dir, basename + "_densepose.png")).resize(size)
    garment_image = Image.open(garment_top).resize(size)

    controlnet_condition = prepare_controlnet_inputs([image_agn_mask], [densepose_image])
    images = pipeline_tryon(
        [prompt],
        negative_prompt="",
        cloth_prompt=[""],  # negative_cloth_prompt = n_prompt,
        height=height,
        width=width,
        num_inference_steps=25,
        guidance_scale=1.5,
        eta=0.0,
        controlnet_condition=controlnet_condition,
        reference_image=[garment_image],
        garment_encoder=garment_encoder,
        condition_extra=image_agn,
        generator=None,
    ).images
    return images[0]
def text2image(prompt, init_image, garment_top, garment_down, style_fidelity=1.):
    """Generate an image from a text prompt conditioned on a top garment.

    The garment file is opened and resized to ``(width, height)``, then
    resized again with ``max(height, width)`` and center-cropped to
    ``(height, width)`` before being passed to ``pipeline_t2i``.

    Args:
        prompt: Text prompt for generation.
        init_image: Accepted for interface parity; not used here.
        garment_top: File path of the top garment image.
        garment_down: Accepted for interface parity; not used here.
        style_fidelity: Forwarded unchanged to the pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    # Garment preprocessing: plain resize, then torchvision resize + crop.
    garment = Image.open(garment_top).resize((width, height))
    upscale = transforms.Resize(max(height, width))
    crop = transforms.CenterCrop((height, width))
    garment = crop(upscale(garment))

    # always enable classifier-free-guidance as it is related to garment
    guidance = 4  # if prompt else 0

    result = pipeline_t2i(
        [prompt],
        negative_prompt=["nsfw, unsaturated, abnormal, unnatural, artifact"],
        cloth_prompt=[""],
        height=height,
        width=width,
        num_inference_steps=30,
        guidance_scale=guidance,
        num_images_per_prompt=1,
        style_fidelity=style_fidelity,
        garment_encoder=garment_encoder,
        garment_image=[garment],
    )
    return result.images[0]
# def text2image(prompt,init_image,garment_top,garment_down,): | |
# return pipeline(prompt).images[0] | |
def infer(prompt, init_image, garment_top, garment_down, t2i_only, style_fidelity):
    """Dispatch a run request to the text-to-image or the try-on path.

    Args:
        prompt, init_image, garment_top, garment_down: Forwarded unchanged.
        t2i_only: Truthy selects ``text2image``; falsy selects ``tryon``.
        style_fidelity: Forwarded to ``text2image`` only.

    Returns:
        Whatever the selected function returns.
    """
    if not t2i_only:
        return tryon(prompt, init_image, garment_top, garment_down)
    return text2image(prompt, init_image, garment_top, garment_down, style_fidelity)
# Module-level UI state shared by set_mode and example_fn:
#   init_state     - last model image path seen, restored when leaving t2i mode
#   prompt_state   - last non-empty prompt seen, restored when entering t2i mode
#   t2i_only_state - current value of the "t2i with garment" switch
init_state, prompt_state = None, ""
t2i_only_state = True


def set_mode(t2i_only, person_condition, prompt):
    """Update the model image and prompt widgets after the mode switch changes.

    Args:
        t2i_only: New value of the mode checkbox.
        person_condition: Current value of the model image component.
        prompt: Current text of the prompt box.

    Returns:
        A two-item list ``[gr.Image, gr.Textbox]``. In t2i mode the image is
        cleared and the remembered prompt is shown in an editable box;
        otherwise the remembered image is shown and the prompt box is emptied
        and locked.
    """
    global init_state, prompt_state, t2i_only_state
    # Follow the checkbox value instead of toggling blindly, so the flag
    # cannot drift out of step with the widget.
    t2i_only_state = bool(t2i_only)
    # Remember the newest non-empty values. The prompt used to keep its first
    # non-empty value forever, discarding later edits.
    init_state = person_condition or init_state
    prompt_state = prompt or prompt_state
    if t2i_only:
        return [
            gr.Image(sources='clipboard', type="filepath", label="model", value=None, interactive=False),
            gr.Textbox(placeholder="", label="prompt(for t2i)", value=prompt_state, interactive=True),
        ]
    return [
        gr.Image(sources='clipboard', type="filepath", label="model", value=init_state, interactive=False),
        gr.Textbox(placeholder="", label="prompt(for t2i)", value="", interactive=False),
    ]
def example_fn(inputs,):
    """Return the model image component for a clicked example.

    Args:
        inputs: The value of the clicked example.

    Returns:
        A ``gr.Image`` whose value is ``None`` while ``t2i_only_state`` is
        truthy and ``inputs`` otherwise.
    """
    shown = None if t2i_only_state else inputs
    return gr.Image(sources='clipboard', type="filepath", label="model", value=shown, interactive=False)
# --- Asset locations ---------------------------------------------------------
gr.set_static_paths(paths=["assets/images/model"])

# Asset directories are resolved relative to this file.
_app_dir = os.path.dirname(__file__)
model_dir = opj(_app_dir, "assets/images/model")
garment_dir = opj(_app_dir, "assets/images/garment")
parse_dir = opj(_app_dir, "assets/images/image_parse")

model = opj(model_dir, "13987_00.jpg")

# Every .jpg in the model directory is offered as an example.
all_person = []
for fname in os.listdir(model_dir):
    if fname.endswith(".jpg"):
        all_person.append(opj(model_dir, fname))
# --- UI layout ---------------------------------------------------------------
# NOTE(review): the nesting of the `with` blocks below was reconstructed from
# statement order because the original indentation was lost -- confirm the
# layout against the running app.
with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px !important} ", ) as gradio_app:
    gr.Markdown("# StableGarment")
    with gr.Row():
        # Column 1: model image plus clickable examples.
        with gr.Column():
            init_image = gr.Image(
                sources='clipboard',
                type="filepath",
                label="model",
                value=None,
                interactive=False,
            )
            # Examples feed a hidden image; example_fn decides what is shown.
            example = gr.Examples(
                inputs=gr.Image(visible=False),  # init_image,
                examples_per_page=4,
                examples=all_person,
                run_on_click=True,
                outputs=init_image,
                fn=example_fn,
            )
        # Column 2: garments, prompt, mode switch and options.
        with gr.Column():
            with gr.Row():
                images_top = [opj(garment_dir, fname) for fname in os.listdir(garment_dir) if fname.endswith(".jpg")]
                garment_top = gr.Image(
                    sources='upload',
                    type="filepath",
                    label="top garment",
                    value=images_top[0],
                )  # ,interactive=False
                example_top = gr.Examples(
                    inputs=garment_top,
                    examples_per_page=4,
                    examples=images_top,
                )
                # The lower garment is hidden and has no examples.
                images_down = []
                garment_down = gr.Image(
                    sources='upload',
                    type="filepath",
                    label="lower garment",
                    interactive=False,
                    visible=False,
                )
                example_down = gr.Examples(
                    inputs=garment_down,
                    examples_per_page=4,
                    examples=images_down,
                )
            prompt = gr.Textbox(placeholder="", label="prompt(for t2i)",)  # interactive=False
            with gr.Row():
                t2i_only = gr.Checkbox(
                    label="t2i with garment",
                    info="Only text and garment.",
                    elem_id="t2i_switch",
                    value=True,
                    interactive=False,
                )
                run_button = gr.Button(value="Run")
                t2i_only.change(
                    fn=set_mode,
                    inputs=[t2i_only, init_image, prompt],
                    outputs=[init_image, prompt],
                )
            with gr.Accordion("advance options", open=False):
                gr.Markdown("Garment fidelity control(Tune down it to reduce white edge).")
                style_fidelity = gr.Slider(0, 1, value=1, label="fidelity(only for t2i)")  # , info=""
        # Column 3: generated image; the Run button is wired here.
        with gr.Column():
            gallery = gr.Image()
            run_button.click(
                fn=infer,
                inputs=[
                    prompt,
                    init_image,
                    garment_top,
                    garment_down,
                    t2i_only,
                    style_fidelity,
                ],
                outputs=[gallery],
            )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    gradio_app.launch()