Space: ginipick (Running on Zero)

Commit 827731a (1 parent: ae4cba3), committed by ginipick
Create app-backup.py

Files changed (1):
  app-backup.py  +397  -0

app-backup.py ADDED
@@ -0,0 +1,397 @@
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, login
import spaces

# Log in with the Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable")
login(token=HF_TOKEN)

# Then download the model weights
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Redux-dev",
    filename="flux1-redux-dev.safetensors",
    local_dir="models/style_models",
    token=HF_TOKEN
)
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Depth-dev",
    filename="flux1-depth-dev.safetensors",
    local_dir="models/diffusion_models",
    token=HF_TOKEN
)
hf_hub_download(
    repo_id="Comfy-Org/sigclip_vision_384",
    filename="sigclip_vision_patch14_384.safetensors",
    local_dir="models/clip_vision",
    token=HF_TOKEN
)
hf_hub_download(
    repo_id="Kijai/DepthAnythingV2-safetensors",
    filename="depth_anything_v2_vitl_fp32.safetensors",
    local_dir="models/depthanything",
    token=HF_TOKEN
)
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    filename="ae.safetensors",
    local_dir="models/vae/FLUX1",
    token=HF_TOKEN
)
hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="clip_l.safetensors",
    local_dir="models/text_encoders",
    token=HF_TOKEN
)
t5_path = hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="t5xxl_fp16.safetensors",
    local_dir="models/text_encoders/t5",
    token=HF_TOKEN
)

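# The downloads above populate a ComfyUI-style model layout:
#   models/style_models/flux1-redux-dev.safetensors
#   models/diffusion_models/flux1-depth-dev.safetensors
#   models/clip_vision/sigclip_vision_patch14_384.safetensors
#   models/depthanything/depth_anything_v2_vitl_fp32.safetensors
#   models/vae/FLUX1/ae.safetensors
#   models/text_encoders/clip_l.safetensors and models/text_encoders/t5/t5xxl_fp16.safetensors
# ComfyUI resolves these directories through its model paths (see add_extra_model_paths below).
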
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return obj[index], falling back to obj["result"][index] for node outputs
    that wrap their results in a mapping instead of a plain tuple."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


def find_path(name: str, path: str = None) -> str:
    """Walk upward from `path` (default: the current working directory) until an
    entry called `name` is found; return its full path, or None if not found."""
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None
    return find_path(name, parent_directory)


def add_comfyui_directory_to_sys_path() -> None:
    """Locate the ComfyUI checkout and add it to sys.path so its modules can be imported."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")


def add_extra_model_paths() -> None:
    """Load extra_model_paths.yaml so ComfyUI can see the downloaded weights."""
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config
    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")


# Initialize paths
add_comfyui_directory_to_sys_path()
add_extra_model_paths()


def import_custom_nodes() -> None:
    """Spin up a minimal PromptServer/PromptQueue so ComfyUI custom nodes register themselves."""
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()


# Import all necessary nodes
from nodes import (
    StyleModelLoader,
    VAEEncode,
    NODE_CLASS_MAPPINGS,
    LoadImage,
    CLIPVisionLoader,
    SaveImage,
    VAELoader,
    CLIPVisionEncode,
    DualCLIPLoader,
    EmptyLatentImage,
    VAEDecode,
    UNETLoader,
    CLIPTextEncode,
)

# Initialize all constant nodes and models in global context
import_custom_nodes()

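# The node objects and model weights below are instantiated once at module scope,
# so each @spaces.GPU invocation of generate_image reuses them instead of reloading
# weights on every request.
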
# Global variables for preloaded models and constants
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
CONST_1024 = intconstant.get_value(value=1024)

# Load CLIP
dualcliploader = DualCLIPLoader()
CLIP_MODEL = dualcliploader.load_clip(
    clip_name1="t5/t5xxl_fp16.safetensors",
    clip_name2="clip_l.safetensors",
    type="flux",
)

# Load VAE
vaeloader = VAELoader()
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")

# Load UNET
unetloader = UNETLoader()
UNET_MODEL = unetloader.load_unet(
    unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
)

# Load CLIP Vision
clipvisionloader = CLIPVisionLoader()
CLIP_VISION_MODEL = clipvisionloader.load_clip(
    clip_name="sigclip_vision_patch14_384.safetensors"
)

# Load Style Model
stylemodelloader = StyleModelLoader()
STYLE_MODEL = stylemodelloader.load_style_model(
    style_model_name="flux1-redux-dev.safetensors"
)

# Initialize samplers
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

# Initialize depth model
cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
    model="depth_anything_v2_vitl_fp32.safetensors"
)

# Initialize other nodes
cliptextencode = CLIPTextEncode()
loadimage = LoadImage()
vaeencode = VAEEncode()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
clipvisionencode = CLIPVisionEncode()
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
emptylatentimage = EmptyLatentImage()
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
vaedecode = VAEDecode()
cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
saveimage = SaveImage()
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()

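# generate_image wires the preloaded nodes into a single pipeline:
#   structure photo -> resize (max 1024, multiple of 16) -> DepthAnything V2 depth map
#                    -> InstructPixToPix/FLUX depth conditioning
#   clothing image  -> SigCLIP vision encoding -> FLUX.1 Redux style conditioning
#   combined conditioning -> euler sampler (28 steps, "simple" schedule) -> VAE decode -> saved image
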
@spaces.GPU
def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
    """Main generation function that processes inputs and returns the path to the generated image."""
    with torch.inference_mode():
        # Set up CLIP
        clip_switch = cr_clip_input_switch.switch(
            Input=1,
            clip1=get_value_at_index(CLIP_MODEL, 0),
            clip2=get_value_at_index(CLIP_MODEL, 0),
        )

        # Encode text
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(clip_switch, 0),
        )
        empty_text = cliptextencode.encode(
            text="",
            clip=get_value_at_index(clip_switch, 0),
        )

        # Process structure image
        structure_img = loadimage.load_image(image=structure_image)

        # Resize image
        resized_img = imageresize.execute(
            width=get_value_at_index(CONST_1024, 0),
            height=get_value_at_index(CONST_1024, 0),
            interpolation="bicubic",
            method="keep proportion",
            condition="always",
            multiple_of=16,
            image=get_value_at_index(structure_img, 0),
        )

        # Get image size
        size_info = getimagesizeandcount.getsize(
            image=get_value_at_index(resized_img, 0)
        )

        # Encode VAE
        vae_encoded = vaeencode.encode(
            pixels=get_value_at_index(size_info, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Process depth
        depth_processed = depthanything_v2.process(
            da_model=get_value_at_index(DEPTH_MODEL, 0),
            images=get_value_at_index(size_info, 0),
        )

        # Apply Flux guidance
        flux_guided = fluxguidance.append(
            guidance=depth_strength,
            conditioning=get_value_at_index(text_encoded, 0),
        )

        # Process style image
        style_img = loadimage.load_image(image=style_image)

        # Encode style with CLIP Vision
        style_encoded = clipvisionencode.encode(
            crop="center",
            clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
            image=get_value_at_index(style_img, 0),
        )

        # Set up conditioning
        conditioning = instructpixtopixconditioning.encode(
            positive=get_value_at_index(flux_guided, 0),
            negative=get_value_at_index(empty_text, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
            pixels=get_value_at_index(depth_processed, 0),
        )

        # Apply style
        style_applied = stylemodelapplyadvanced.apply_stylemodel(
            strength=style_strength,
            conditioning=get_value_at_index(conditioning, 0),
            style_model=get_value_at_index(STYLE_MODEL, 0),
            clip_vision_output=get_value_at_index(style_encoded, 0),
        )

        # Set up empty latent
        empty_latent = emptylatentimage.generate(
            width=get_value_at_index(resized_img, 1),
            height=get_value_at_index(resized_img, 2),
            batch_size=1,
        )

        # Set up guidance
        guided = basicguider.get_guider(
            model=get_value_at_index(UNET_MODEL, 0),
            conditioning=get_value_at_index(style_applied, 0),
        )

        # Set up scheduler
        schedule = basicscheduler.get_sigmas(
            scheduler="simple",
            steps=28,
            denoise=1,
            model=get_value_at_index(UNET_MODEL, 0),
        )

        # Generate random noise
        noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))

        # Sample
        sampled = samplercustomadvanced.sample(
            noise=get_value_at_index(noise, 0),
            guider=get_value_at_index(guided, 0),
            sampler=get_value_at_index(SAMPLER, 0),
            sigmas=get_value_at_index(schedule, 0),
            latent_image=get_value_at_index(empty_latent, 0),
        )

        # Decode VAE
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Save image
        prefix = cr_text.text_multiline(text="Virtual_TryOn")

        saved = saveimage.save_images(
            filename_prefix=get_value_at_index(prefix, 0),
            images=get_value_at_index(decoded, 0),
        )
        saved_path = f"output/{saved['ui']['images'][0]['filename']}"
        return saved_path

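# Illustrative smoke test (hypothetical; not part of the Space's request path). It assumes
# the example images referenced in `examples` below exist next to this file and the
# weights have been downloaded:
#   print(generate_image("person wearing elegant dress", "f2.webp", "f21.webp", 15, 0.5))
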
# Create Gradio interface
examples = [
    ["person wearing fashionable clothing", "f1.webp", "f11.webp", 15, 0.6],
    ["person wearing elegant dress", "f2.webp", "f21.webp", 15, 0.5],
    ["person wearing casual outfit", "f3.webp", "f31.webp", 15, 0.5],
]

output_image = gr.Image(label="Virtual Try-On Result")

with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as app:
    gr.Markdown("# Style Generator")
    gr.Markdown("Upload your photo and try on different clothing items virtually using AI. The system will generate an image of you wearing the selected clothing while maintaining your pose and appearance.")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Style Description",
                placeholder="Describe the desired style (e.g., 'person wearing elegant dress')"
            )
            with gr.Row():
                with gr.Group():
                    structure_image = gr.Image(
                        label="Your Photo (Full-body)",
                        type="filepath"
                    )
                    gr.Markdown("*Upload a clear, well-lit full-body photo*")
                    depth_strength = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=15,
                        label="Fitting Strength"
                    )
                with gr.Group():
                    style_image = gr.Image(
                        label="Clothing Item",
                        type="filepath"
                    )
                    gr.Markdown("*Upload the clothing item you want to try on*")
                    style_strength = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        label="Style Transfer Strength"
                    )
            generate_btn = gr.Button("Generate Try-On")

            gr.Examples(
                examples=examples,
                inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
                outputs=[output_image],
                fn=generate_image,
                cache_examples=True,
                cache_mode="lazy"
            )

        with gr.Column():
            output_image.render()

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
        outputs=[output_image]
    )

if __name__ == "__main__":
    app.launch(share=True)
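
For reference, a minimal sketch of driving the running app programmatically with gradio_client; the URL, the example file names, and the /generate_image endpoint name are assumptions to verify against client.view_api():

from gradio_client import Client, handle_file

# app.launch() prints the local URL (and a public one when share=True succeeds)
client = Client("http://127.0.0.1:7860")

result = client.predict(
    "person wearing elegant dress",   # prompt
    handle_file("f2.webp"),           # structure image: full-body photo
    handle_file("f21.webp"),          # style image: clothing item
    15,                               # depth_strength
    0.5,                              # style_strength
    api_name="/generate_image",       # assumed default endpoint name
)
print(result)  # local path to the generated image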