aiqtech committed on
Commit a135ad5 • 1 Parent(s): 3c5364f

Update app.py

Files changed (1)
  1. app.py +99 -90
app.py CHANGED
@@ -2,8 +2,6 @@ import gradio as gr
 import spaces
 from gradio_litmodel3d import LitModel3D
 import os
-os.environ['SPCONV_ALGO'] = 'native'
-from typing import *
 import torch
 import numpy as np
 import imageio
@@ -15,6 +13,11 @@ from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils
 from transformers import pipeline as translation_pipeline
 from diffusers import FluxPipeline
+from typing import *
+
+# Set environment variables
+os.environ['SPCONV_ALGO'] = 'native'
+os.environ['WARP_USE_CPU'] = '1'  # Force Warp into CPU mode
 
 MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = "/tmp/Trellis-demo"
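
Note: settings like SPCONV_ALGO are usually read once, when the consuming library is first imported, so where they sit relative to the imports matters. A minimal sketch of the order-sensitive pattern (assuming spconv consults SPCONV_ALGO at import time; the module path is illustrative):

import os

# Set the variables before importing the code that reads them,
# or the values may be silently ignored.
os.environ['SPCONV_ALGO'] = 'native'
os.environ['WARP_USE_CPU'] = '1'

import spconv.pytorch as spconv  # SPCONV_ALGO is consulted here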
@@ -24,37 +27,27 @@ def initialize_models():
     global pipeline, translator, flux_pipe
 
     try:
-        # Clear GPU memory
-        torch.cuda.empty_cache()
-
-        # Check whether a GPU is available
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        # Initialize the Trellis pipeline
+        # Initialize the Trellis pipeline (on CPU)
         pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
-        pipeline.to(device)
 
-        # Initialize the translator
+        # Initialize the translator (on CPU)
         translator = translation_pipeline(
             "translation",
             model="Helsinki-NLP/opus-mt-ko-en",
-            device=0 if device=="cuda" else -1
+            device=-1
         )
 
-        # Initialize the Flux pipeline
+        # Initialize the Flux pipeline (on CPU)
         flux_pipe = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
-            torch_dtype=torch.float16 if device=="cuda" else torch.float32
+            torch_dtype=torch.float32
         )
 
-        if device == "cuda":
-            flux_pipe.enable_model_cpu_offload()
-
+        print("Models initialized successfully")
         return True
 
     except Exception as e:
         print(f"Model initialization error: {str(e)}")
-        torch.cuda.empty_cache()
         return False
 
 def translate_if_korean(text):
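
The hunk above loads all three models on the CPU at startup; the handlers later in the diff move a pipeline onto the GPU for the duration of one request and back again. A minimal sketch of that borrow-and-return pattern for any object with a .to() method (run_on_gpu is a hypothetical helper, not part of app.py):

import torch

def run_on_gpu(pipe, *args, **kwargs):
    # Borrow the GPU for a single call, then always release it.
    try:
        if torch.cuda.is_available():
            pipe.to("cuda")
        return pipe(*args, **kwargs)
    finally:
        pipe.to("cpu")  # runs on success and on error alike
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

A try/finally like this expresses the intent that the handlers below implement with a duplicated pipeline.to("cpu") call in both the success path and the except path.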
@@ -63,11 +56,25 @@ def translate_if_korean(text):
         return translated
     return text
 
+@spaces.GPU
 def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
-    trial_id = str(uuid.uuid4())
-    processed_image = pipeline.preprocess_image(image)
-    processed_image.save(f"{TMP_DIR}/{trial_id}.png")
-    return trial_id, processed_image
+    try:
+        trial_id = str(uuid.uuid4())
+
+        # Resize the image if it is too small
+        min_size = 64
+        if image.size[0] < min_size or image.size[1] < min_size:
+            ratio = min_size / min(image.size)
+            new_size = tuple(int(dim * ratio) for dim in image.size)
+            image = image.resize(new_size, Image.LANCZOS)
+
+        processed_image = pipeline.preprocess_image(image)
+        processed_image.save(f"{TMP_DIR}/{trial_id}.png")
+        return trial_id, processed_image
+
+    except Exception as e:
+        print(f"Error in preprocess_image: {str(e)}")
+        return None, None
 
 def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
     return {
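
The new upscaling branch in preprocess_image scales the shorter side up to min_size while preserving the aspect ratio. A quick standalone check of the arithmetic:

from PIL import Image

image = Image.new("RGB", (40, 100))  # shorter side below the minimum
min_size = 64
ratio = min_size / min(image.size)                        # 64 / 40 = 1.6
new_size = tuple(int(dim * ratio) for dim in image.size)  # (64, 160)
image = image.resize(new_size, Image.LANCZOS)
assert min(image.size) >= min_size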
@@ -86,7 +93,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
         'trial_id': trial_id,
     }
 
-
 def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     gs = Gaussian(
         aabb=state['gaussian']['aabb'],
@@ -113,31 +119,32 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
                 ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int):
     try:
-        torch.cuda.empty_cache()
-
         if randomize_seed:
             seed = np.random.randint(0, MAX_SEED)
 
         input_image = Image.open(f"{TMP_DIR}/{trial_id}.png")
 
-        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
-            with torch.no_grad():
-                outputs = pipeline.run(
-                    input_image,
-                    seed=seed,
-                    formats=["gaussian", "mesh"],
-                    preprocess_image=False,
-                    sparse_structure_sampler_params={
-                        "steps": ss_sampling_steps,
-                        "cfg_strength": ss_guidance_strength,
-                    },
-                    slat_sampler_params={
-                        "steps": slat_sampling_steps,
-                        "cfg_strength": slat_guidance_strength,
-                    }
-                )
+        # GPU setup
+        if torch.cuda.is_available():
+            pipeline.to("cuda")
+            pipeline.to(torch.float16)
+
+        with torch.no_grad():
+            outputs = pipeline.run(
+                input_image,
+                seed=seed,
+                formats=["gaussian", "mesh"],
+                preprocess_image=False,
+                sparse_structure_sampler_params={
+                    "steps": ss_sampling_steps,
+                    "cfg_strength": ss_guidance_strength,
+                },
+                slat_sampler_params={
+                    "steps": slat_sampling_steps,
+                    "cfg_strength": slat_guidance_strength,
+                }
+            )
 
-        # Render the video
         video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
         video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
         video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
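
This hunk trades torch.cuda.amp.autocast for an explicit half-precision cast of the pipeline. The two are not equivalent: autocast keeps fp32 weights and runs selected ops in fp16, while .to(torch.float16) converts the weights themselves (whether the Trellis pipeline's .to() accepts a dtype depends on that class). A minimal sketch of the difference on a plain module, assuming a CUDA device is available:

import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(1, 8, device="cuda")

# Autocast: weights stay fp32; eligible ops run in fp16.
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
    y1 = model(x)

# Weight cast, as image_to_3d now does: parameters become fp16.
model = model.half()
with torch.no_grad():
    y2 = model(x.half())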
@@ -149,37 +156,51 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
 
         state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
+        # Move back to the CPU
+        pipeline.to("cpu")
+
         return state, video_path
 
     except Exception as e:
         print(f"Error in image_to_3d: {str(e)}")
-        torch.cuda.empty_cache()
+        pipeline.to("cpu")
         raise e
 
 @spaces.GPU
 def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
-    # Add the base prompt
-    base_prompt = "wbgmsst, 3D, white background"
-
-    # Translate the user prompt (if it is Korean)
-    translated_prompt = translate_if_korean(prompt)
-
-    # Compose the final prompt
-    final_prompt = f"{translated_prompt}, {base_prompt}"
-
-    with torch.inference_mode():
-        image = flux_pipe(
-            prompt=[final_prompt],
-            height=height,
-            width=width,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_steps
-        ).images[0]
+    try:
+        # GPU setup
+        if torch.cuda.is_available():
+            flux_pipe.to("cuda")
+            flux_pipe.to(torch.float16)
+
+        # Add the base prompt
+        base_prompt = "wbgmsst, 3D, white background"
+
+        # Translate the user prompt (if it is Korean)
+        translated_prompt = translate_if_korean(prompt)
+
+        # Compose the final prompt
+        final_prompt = f"{translated_prompt}, {base_prompt}"
+
+        with torch.inference_mode():
+            image = flux_pipe(
+                prompt=[final_prompt],
+                height=height,
+                width=width,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps
+            ).images[0]
+
+        # Move back to the CPU
+        flux_pipe.to("cpu")
 
         return image
+
+    except Exception as e:
+        print(f"Error in generate_image_from_text: {str(e)}")
+        flux_pipe.to("cpu")
+        raise e
 
 @spaces.GPU
 def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
@@ -195,14 +216,13 @@ def activate_button() -> gr.Button:
 def deactivate_button() -> gr.Button:
     return gr.Button(interactive=False)
 
-
 css = """
 footer {
     visibility: hidden;
 }
 """
 
-
+# Define the Gradio interface
 with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
     gr.Markdown("""
     # Craft3D : 3D Asset Creation & Text-to-Image Generation
@@ -278,7 +298,7 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
         examples_per_page=64,
     )
 
-    # Handlers
+    # Handlers
     image_prompt.upload(
         preprocess_image,
         inputs=[image_prompt],
@@ -292,59 +312,48 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
 
     generate_btn.click(
         image_to_3d,
-        inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps,
+                slat_guidance_strength, slat_sampling_steps],
         outputs=[output_buf, video_output],
+        concurrency_limit=1
     ).then(
         activate_button,
-        outputs=[extract_glb_btn],
-    )
-
-    video_output.clear(
-        deactivate_button,
-        outputs=[extract_glb_btn],
+        outputs=[extract_glb_btn]
     )
 
     extract_glb_btn.click(
         extract_glb,
         inputs=[output_buf, mesh_simplify, texture_size],
         outputs=[model_output, download_glb],
+        concurrency_limit=1
     ).then(
         activate_button,
-        outputs=[download_glb],
+        outputs=[download_glb]
    )
 
-    model_output.clear(
-        deactivate_button,
-        outputs=[download_glb],
-    )
-
-    # Text-to-image handler
     generate_txt2img_btn.click(
         generate_image_from_text,
         inputs=[text_prompt, txt2img_height, txt2img_width, guidance_scale, num_steps],
-        outputs=[txt2img_output]
+        outputs=[txt2img_output],
+        concurrency_limit=1
     )
 
 if __name__ == "__main__":
-    # Initial GPU memory cleanup
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-    # Verify model initialization
+    # Initialize the models
     if not initialize_models():
         print("Failed to initialize models")
         exit(1)
 
     try:
         # Try to preload rembg
-        test_image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
+        test_image = Image.fromarray(np.ones((256, 256, 3), dtype=np.uint8) * 255)
         pipeline.preprocess_image(test_image)
     except Exception as e:
         print(f"Warning: Failed to preload rembg: {str(e)}")
 
     # Launch the Gradio app
-    demo.queue(concurrency_count=1).launch(
+    demo.queue(max_size=20).launch(
         share=True,
-        enable_queue=True,
-        max_threads=1
+        max_threads=4,
+        show_error=True
     )
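
The launch changes track Gradio 4.x, which removed queue(concurrency_count=...) and launch(enable_queue=...): concurrency is now capped per event with concurrency_limit, while queue(max_size=...) bounds how many requests may wait. A minimal self-contained sketch (slow_fn stands in for the real handlers):

import gradio as gr

def slow_fn(text):
    return text  # stand-in for image_to_3d / generate_image_from_text

with gr.Blocks() as demo:
    inp = gr.Textbox()
    out = gr.Textbox()
    btn = gr.Button("Run")
    # At most one concurrent worker for this event:
    btn.click(slow_fn, inputs=inp, outputs=out, concurrency_limit=1)

# At most 20 requests wait in the queue; errors surface in the UI.
demo.queue(max_size=20).launch(show_error=True)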