"""Gradio Space for FLUX.1-based virtual try-on.

Combines FLUX.1-Depth-dev (structure preservation via a depth map) with
FLUX.1-Redux-dev (style transfer from a clothing image), driven through
preloaded ComfyUI nodes.
"""
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union, Optional

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, login
import spaces

# Log in with the Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable")
login(token=HF_TOKEN)

# Download the model weights
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Redux-dev",
    filename="flux1-redux-dev.safetensors",
    local_dir="models/style_models",
    token=HF_TOKEN,
)
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Depth-dev",
    filename="flux1-depth-dev.safetensors",
    local_dir="models/diffusion_models",
    token=HF_TOKEN,
)
hf_hub_download(
    repo_id="Comfy-Org/sigclip_vision_384",
    filename="sigclip_vision_patch14_384.safetensors",
    local_dir="models/clip_vision",
    token=HF_TOKEN,
)
hf_hub_download(
    repo_id="Kijai/DepthAnythingV2-safetensors",
    filename="depth_anything_v2_vitl_fp32.safetensors",
    local_dir="models/depthanything",
    token=HF_TOKEN,
)
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    filename="ae.safetensors",
    local_dir="models/vae/FLUX1",
    token=HF_TOKEN,
)
hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="clip_l.safetensors",
    local_dir="models/text_encoders",
    token=HF_TOKEN,
)
t5_path = hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="t5xxl_fp16.safetensors",
    local_dir="models/text_encoders/t5",
    token=HF_TOKEN,
)


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return the value at `index`, unwrapping {"result": ...} node outputs."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


def find_path(name: str, path: Optional[str] = None) -> Optional[str]:
    """Walk up from `path` (default: cwd) until `name` is found; None if not."""
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None
    return find_path(name, parent_directory)


def add_comfyui_directory_to_sys_path() -> None:
    """Make the ComfyUI checkout importable."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")


def add_extra_model_paths() -> None:
    """Load extra_model_paths.yaml so ComfyUI can locate the downloaded weights."""
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config
    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")


# Initialize paths
add_comfyui_directory_to_sys_path()
add_extra_model_paths()


def import_custom_nodes() -> None:
    """Start a PromptServer event loop and register ComfyUI's extra nodes."""
    import asyncio

    import execution
    import server
    from nodes import init_extra_nodes

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()


# Import all necessary nodes
from nodes import (
    StyleModelLoader,
    VAEEncode,
    NODE_CLASS_MAPPINGS,
    LoadImage,
    CLIPVisionLoader,
    SaveImage,
    VAELoader,
    CLIPVisionEncode,
    DualCLIPLoader,
    EmptyLatentImage,
    VAEDecode,
    UNETLoader,
    CLIPTextEncode,
)

# Initialize all constant nodes and models in global context
import_custom_nodes()

# Global variables for preloaded models and constants
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
CONST_1024 = intconstant.get_value(value=1024)

# Load CLIP
dualcliploader = DualCLIPLoader()
CLIP_MODEL = dualcliploader.load_clip(
    clip_name1="t5/t5xxl_fp16.safetensors",
    clip_name2="clip_l.safetensors",
    type="flux",
)

# Load VAE
vaeloader = VAELoader()
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
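# Note: ComfyUI node methods return their outputs as tuples (or, for some
# custom nodes, as {"result": ...} mappings); get_value_at_index(X, 0)
# unwraps the first output slot, e.g. the VAE object inside VAE_MODEL above.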
# Load UNET
unetloader = UNETLoader()
UNET_MODEL = unetloader.load_unet(
    unet_name="flux1-depth-dev.safetensors",
    weight_dtype="default",
)

# Load CLIP Vision
clipvisionloader = CLIPVisionLoader()
CLIP_VISION_MODEL = clipvisionloader.load_clip(
    clip_name="sigclip_vision_patch14_384.safetensors"
)

# Load Style Model
stylemodelloader = StyleModelLoader()
STYLE_MODEL = stylemodelloader.load_style_model(
    style_model_name="flux1-redux-dev.safetensors"
)

# Initialize samplers
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

# Initialize the CLIP input switch and depth model
cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS[
    "DownloadAndLoadDepthAnythingV2Model"
]()
DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
    model="depth_anything_v2_vitl_fp32.safetensors"
)

# Initialize other nodes
cliptextencode = CLIPTextEncode()
loadimage = LoadImage()
vaeencode = VAEEncode()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
clipvisionencode = CLIPVisionEncode()
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
emptylatentimage = EmptyLatentImage()
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
vaedecode = VAEDecode()
cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
saveimage = SaveImage()
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()


@spaces.GPU
def generate_image(
    prompt,
    structure_image,
    style_image,
    depth_strength=15,
    style_strength=0.5,
    progress=gr.Progress(track_tqdm=True),
) -> str:
    """Main generation function that processes inputs and returns the path to the generated image."""
    with torch.inference_mode():
        # Set up CLIP
        clip_switch = cr_clip_input_switch.switch(
            Input=1,
            clip1=get_value_at_index(CLIP_MODEL, 0),
            clip2=get_value_at_index(CLIP_MODEL, 0),
        )

        # Encode text
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(clip_switch, 0),
        )
        empty_text = cliptextencode.encode(
            text="",
            clip=get_value_at_index(clip_switch, 0),
        )

        # Process structure image
        structure_img = loadimage.load_image(image=structure_image)

        # Resize image
        resized_img = imageresize.execute(
            width=get_value_at_index(CONST_1024, 0),
            height=get_value_at_index(CONST_1024, 0),
            interpolation="bicubic",
            method="keep proportion",
            condition="always",
            multiple_of=16,
            image=get_value_at_index(structure_img, 0),
        )

        # Get image size
        size_info = getimagesizeandcount.getsize(
            image=get_value_at_index(resized_img, 0)
        )

        # Encode VAE
        vae_encoded = vaeencode.encode(
            pixels=get_value_at_index(size_info, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Process depth
        depth_processed = depthanything_v2.process(
            da_model=get_value_at_index(DEPTH_MODEL, 0),
            images=get_value_at_index(size_info, 0),
        )

        # Apply Flux guidance
        flux_guided = fluxguidance.append(
            guidance=depth_strength,
            conditioning=get_value_at_index(text_encoded, 0),
        )

        # Process style image
        style_img = loadimage.load_image(image=style_image)

        # Encode style with CLIP Vision
        style_encoded = clipvisionencode.encode(
            crop="center",
            clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
            image=get_value_at_index(style_img, 0),
        )
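        # The two conditioning paths meet below: InstructPixToPixConditioning
        # folds the depth map into the positive/negative text conditioning, and
        # StyleModelApplyAdvanced then blends in the Redux style embedding,
        # weighted by style_strength.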
        # Set up conditioning
        conditioning = instructpixtopixconditioning.encode(
            positive=get_value_at_index(flux_guided, 0),
            negative=get_value_at_index(empty_text, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
            pixels=get_value_at_index(depth_processed, 0),
        )

        # Apply style
        style_applied = stylemodelapplyadvanced.apply_stylemodel(
            strength=style_strength,
            conditioning=get_value_at_index(conditioning, 0),
            style_model=get_value_at_index(STYLE_MODEL, 0),
            clip_vision_output=get_value_at_index(style_encoded, 0),
        )

        # Set up empty latent
        empty_latent = emptylatentimage.generate(
            width=get_value_at_index(resized_img, 1),
            height=get_value_at_index(resized_img, 2),
            batch_size=1,
        )

        # Set up guidance
        guided = basicguider.get_guider(
            model=get_value_at_index(UNET_MODEL, 0),
            conditioning=get_value_at_index(style_applied, 0),
        )

        # Set up scheduler
        schedule = basicscheduler.get_sigmas(
            scheduler="simple",
            steps=28,
            denoise=1,
            model=get_value_at_index(UNET_MODEL, 0),
        )

        # Generate random noise (seeds are 64-bit, so stay within 2**64 - 1)
        noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64 - 1))

        # Sample
        sampled = samplercustomadvanced.sample(
            noise=get_value_at_index(noise, 0),
            guider=get_value_at_index(guided, 0),
            sampler=get_value_at_index(SAMPLER, 0),
            sigmas=get_value_at_index(schedule, 0),
            latent_image=get_value_at_index(empty_latent, 0),
        )

        # Decode VAE
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Save image
        prefix = cr_text.text_multiline(text="Virtual_TryOn")
        saved = saveimage.save_images(
            filename_prefix=get_value_at_index(prefix, 0),
            images=get_value_at_index(decoded, 0),
        )
        saved_path = f"output/{saved['ui']['images'][0]['filename']}"
        return saved_path


# Example inputs for the interface
examples = [
    ["person wearing fashionable clothing", "f1.webp", "f11.webp", 15, 0.6],
    ["person wearing elegant dress", "f2.webp", "f21.webp", 15, 0.5],
    ["person wearing casual outfit", "f3.webp", "f31.webp", 15, 0.5],
]

# Create the Gradio interface
demo = gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange")

with demo:
    gr.Markdown("# 🎭 StyleGen : Fashion AI Studio")
    gr.Markdown("Generate fashion images and try on virtual clothing using AI")

    with gr.Tabs():
        # Virtual Try-On tab
        with gr.TabItem("👔 Virtual Try-On"):
            with gr.Row():
                with gr.Column():
                    prompt_input = gr.Textbox(
                        label="Style Description",
                        placeholder="Describe the desired style (e.g., 'person wearing elegant dress')",
                    )
                    with gr.Row():
                        with gr.Group():
                            structure_image = gr.Image(
                                label="Your Photo (Full-body)",
                                type="filepath",
                            )
                            gr.Markdown("*Upload a clear, well-lit full-body photo*")
                            depth_strength = gr.Slider(
                                minimum=0,
                                maximum=50,
                                value=15,
                                label="Fitting Strength",
                            )
                        with gr.Group():
                            style_image = gr.Image(
                                label="Clothing Item",
                                type="filepath",
                            )
                            gr.Markdown("*Upload the clothing item you want to try on*")
                            style_strength = gr.Slider(
                                minimum=0,
                                maximum=1,
                                value=0.5,
                                label="Style Transfer Strength",
                            )
                with gr.Column():
                    output_image = gr.Image(label="Virtual Try-On Result")
                    generate_button = gr.Button("Generate Try-On")

            gr.Examples(
                examples=examples,
                inputs=[
                    prompt_input,
                    structure_image,
                    style_image,
                    depth_strength,
                    style_strength,
                ],
                outputs=output_image,
                fn=generate_image,
                cache_examples=False,
            )

            # Connect the button to the generation function
            generate_button.click(
                fn=generate_image,
                inputs=[
                    prompt_input,
                    structure_image,
                    style_image,
                    depth_strength,
                    style_strength,
                ],
                outputs=output_image,
            )

if __name__ == "__main__":
    demo.launch(share=True)
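# Running locally (a sketch; the filename app.py is an assumption): export an
# HF_TOKEN that can download the FLUX.1 weights and launch the script from a
# directory tree containing a ComfyUI checkout, e.g.:
#   HF_TOKEN=hf_... python app.py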