aiqtech committed
Commit 8de87eb • Parent(s): c589cc5

Update app.py

Files changed (1):
  1. app.py +55 -62
app.py CHANGED
@@ -34,7 +34,10 @@ class GlobalVars:
 
 g = GlobalVars()
 
-
+# Added at the top of the file
+torch.backends.cudnn.benchmark = False  # reduce memory usage
+torch.backends.cudnn.deterministic = True
+torch.cuda.set_per_process_memory_fraction(0.7)  # cap GPU memory usage
 
 def initialize_models(device):
     try:
@@ -85,10 +88,6 @@ def initialize_models(device):
         print(f"Error during model initialization: {str(e)}")
         raise
 
-# CUDA memory management settings
-torch.cuda.empty_cache()
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.benchmark = True
 
 # Environment variable settings
 # Environment variable settings
@@ -104,6 +103,13 @@ os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
 # Prevent CUDA initialization
 torch.set_grad_enabled(False)
 
+def periodic_cleanup():
+    """Memory-cleanup function to run periodically"""
+    clear_gpu_memory()
+    return None
+
+# Add periodic cleanup to the Gradio interface
+demo.load(periodic_cleanup, every=5)  # clean up every 5 seconds
 
 # Hugging Face token setup
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -209,123 +215,110 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     return gs, mesh, state['trial_id']
 
 @spaces.GPU
-def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
-    print(f"Starting image_to_3d with trial_id: {trial_id}")
-
-    if not trial_id or trial_id.strip() == "":
-        print("Error: No trial_id provided")
-        return None, None
-
+def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
+                ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
     try:
-        # Reset CUDA memory
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-            gc.collect()
+        # Initial memory cleanup
+        clear_gpu_memory()
 
-        if randomize_seed:
-            seed = np.random.randint(0, MAX_SEED)
+        if not trial_id or trial_id.strip() == "":
+            return None, None
 
         image_path = f"{TMP_DIR}/{trial_id}.png"
-        print(f"Looking for image at: {image_path}")
-
         if not os.path.exists(image_path):
-            print(f"Error: Image file not found at {image_path}")
             return None, None
 
         image = Image.open(image_path)
-        print(f"Successfully loaded image with size: {image.size}")
 
-        # Limit the image size
-        max_size = 512
+        # Tighter image size limit
+        max_size = 384  # cap to a smaller size
         if max(image.size) > max_size:
             ratio = max_size / max(image.size)
             new_size = tuple(int(dim * ratio) for dim in image.size)
             image = image.resize(new_size, Image.LANCZOS)
-            print(f"Resized image to: {image.size}")
 
-        # Start the GPU work
         with torch.inference_mode():
             try:
-                # Move the model to the GPU
+                # Move the pipeline to the GPU
                 g.trellis_pipeline.to('cuda')
-                torch.cuda.synchronize()
 
-                # 3D generation
+                # Limit the batch size
                 outputs = g.trellis_pipeline.run(
                     image,
                     seed=seed,
                     formats=["gaussian", "mesh"],
                     preprocess_image=False,
                     sparse_structure_sampler_params={
-                        "steps": min(ss_sampling_steps, 12),
+                        "steps": min(ss_sampling_steps, 8),  # cap the step count
                         "cfg_strength": ss_guidance_strength,
+                        "batch_size": 1  # explicitly limit the batch size
                     },
                     slat_sampler_params={
-                        "steps": min(slat_sampling_steps, 12),
+                        "steps": min(slat_sampling_steps, 8),  # cap the step count
                         "cfg_strength": slat_guidance_strength,
+                        "batch_size": 1
                     },
                 )
-                torch.cuda.synchronize()
 
-                # Video rendering
+                # Intermediate memory cleanup
+                clear_gpu_memory()
+
+                # Optimized video rendering
                 video = render_utils.render_video(
                     outputs['gaussian'][0],
-                    num_frames=60,
-                    resolution=512
+                    num_frames=30,  # fewer frames
+                    resolution=384  # lower resolution
                 )['color']
-                torch.cuda.synchronize()
 
                 video_geo = render_utils.render_video(
                     outputs['mesh'][0],
-                    num_frames=60,
-                    resolution=512
+                    num_frames=30,
+                    resolution=384
                 )['normal']
-                torch.cuda.synchronize()
 
-                # Move the data to the CPU
-                video = [v.cpu().numpy() if torch.is_tensor(v) else v for v in video]
-                video_geo = [v.cpu().numpy() if torch.is_tensor(v) else v for v in video_geo]
+                # Move the data to the CPU and clear memory
+                video = [v.cpu().numpy() for v in video]
+                video_geo = [v.cpu().numpy() for v in video_geo]
+                clear_gpu_memory()
 
+                # Remaining processing
                 video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
                 new_trial_id = str(uuid.uuid4())
                 video_path = f"{TMP_DIR}/{new_trial_id}.mp4"
-                os.makedirs(os.path.dirname(video_path), exist_ok=True)
                 imageio.mimsave(video_path, video, fps=15)
 
-                # Save the state
                 state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], new_trial_id)
-
                 return state, video_path
 
             finally:
                 # Cleanup
                 g.trellis_pipeline.to('cpu')
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                    torch.cuda.synchronize()
-                gc.collect()
-
+                clear_gpu_memory()
+
     except Exception as e:
         print(f"Error in image_to_3d: {str(e)}")
-        if hasattr(g.trellis_pipeline, 'to'):
-            g.trellis_pipeline.to('cpu')
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-        gc.collect()
+        g.trellis_pipeline.to('cpu')
+        clear_gpu_memory()
         return None, None
 
 def clear_gpu_memory():
-    """Utility function that clears GPU memory"""
+    """Clear GPU memory more thoroughly"""
    try:
        if torch.cuda.is_available():
-            with torch.cuda.device('cuda'):
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
+            # Clear all GPU caches
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+            # Release unused cached memory on every device
+            for i in range(torch.cuda.device_count()):
+                with torch.cuda.device(i):
+                    torch.cuda.empty_cache()
+                    torch.cuda.ipc_collect()
+
+            # Run the Python garbage collector
            gc.collect()
    except Exception as e:
-        print(f"Error clearing GPU memory: {e}")
+        print(f"Error in clear_gpu_memory: {e}")
 
 def move_to_device(model, device):
     """Safely move a model to the given device"""
 