aiqtech commited on
Commit
b209823
·
verified ·
1 Parent(s): a135ad5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -37
app.py CHANGED
@@ -15,6 +15,15 @@ from transformers import pipeline as translation_pipeline
15
  from diffusers import FluxPipeline
16
  from typing import *
17
 
 
 
 
 
 
 
 
 
 
18
  # 환경 변수 설정
19
  os.environ['SPCONV_ALGO'] = 'native'
20
  os.environ['WARP_USE_CPU'] = '1' # Warp를 CPU 모드로 강제
@@ -27,27 +36,43 @@ def initialize_models():
27
  global pipeline, translator, flux_pipe
28
 
29
  try:
30
- # Trellis 파이프라인 초기화 (CPU 모드로)
31
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
 
 
 
 
 
32
 
33
- # 번역기 초기화 (CPU 모드로)
34
  translator = translation_pipeline(
35
  "translation",
36
  model="Helsinki-NLP/opus-mt-ko-en",
37
- device=-1
 
 
 
 
38
  )
39
 
40
- # Flux 파이프라인 초기화 (CPU 모드로)
41
  flux_pipe = FluxPipeline.from_pretrained(
42
  "black-forest-labs/FLUX.1-dev",
43
- torch_dtype=torch.float32
 
 
 
44
  )
45
 
 
 
 
46
  print("Models initialized successfully")
47
  return True
48
 
49
  except Exception as e:
50
  print(f"Model initialization error: {str(e)}")
 
51
  return False
52
 
53
  def translate_if_korean(text):
@@ -119,34 +144,46 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
119
  def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
120
  ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int):
121
  try:
 
 
122
  if randomize_seed:
123
  seed = np.random.randint(0, MAX_SEED)
124
 
125
  input_image = Image.open(f"{TMP_DIR}/{trial_id}.png")
126
 
127
- # GPU 설정
128
- if torch.cuda.is_available():
129
- pipeline.to("cuda")
130
- pipeline.to(torch.float16)
131
 
132
- with torch.no_grad():
133
- outputs = pipeline.run(
134
- input_image,
135
- seed=seed,
136
- formats=["gaussian", "mesh"],
137
- preprocess_image=False,
138
- sparse_structure_sampler_params={
139
- "steps": ss_sampling_steps,
140
- "cfg_strength": ss_guidance_strength,
141
- },
142
- slat_sampler_params={
143
- "steps": slat_sampling_steps,
144
- "cfg_strength": slat_guidance_strength,
145
- }
146
  )
147
-
148
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
149
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
151
 
152
  trial_id = str(uuid.uuid4())
@@ -156,14 +193,12 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre
156
 
157
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
158
 
159
- # CPU 모드로 돌아가기
160
- pipeline.to("cpu")
161
-
162
  return state, video_path
163
 
164
  except Exception as e:
165
  print(f"Error in image_to_3d: {str(e)}")
166
- pipeline.to("cpu")
167
  raise e
168
 
169
  @spaces.GPU
@@ -221,7 +256,23 @@ footer {
221
  visibility: hidden;
222
  }
223
  """
224
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  # Gradio 인터페이스 정의
226
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
227
  gr.Markdown("""
@@ -339,21 +390,27 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
339
  )
340
 
341
  if __name__ == "__main__":
 
 
342
  # 모델 초기화
343
  if not initialize_models():
344
  print("Failed to initialize models")
345
  exit(1)
346
 
347
  try:
348
- # rembg 사전 로드 시도
349
- test_image = Image.fromarray(np.ones((256, 256, 3), dtype=np.uint8) * 255)
350
  pipeline.preprocess_image(test_image)
351
  except Exception as e:
352
  print(f"Warning: Failed to preload rembg: {str(e)}")
353
 
354
  # Gradio 앱 실행
355
- demo.queue(max_size=20).launch(
356
  share=True,
357
- max_threads=4,
358
- show_error=True
 
 
 
 
359
  )
 
15
  from diffusers import FluxPipeline
16
  from typing import *
17
 
18
+
19
+ # 메모리 관련 환경 변수
20
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
21
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
22
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
23
+ os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
24
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
25
+ os.environ['HF_HOME'] = '/tmp/huggingface'
26
+
27
  # 환경 변수 설정
28
  os.environ['SPCONV_ALGO'] = 'native'
29
  os.environ['WARP_USE_CPU'] = '1' # Warp를 CPU 모드로 강제
 
36
  global pipeline, translator, flux_pipe
37
 
38
  try:
39
+ # Trellis 파이프라인 초기화 ( 강화된 메모리 최적화)
40
+ pipeline = TrellisImageTo3DPipeline.from_pretrained(
41
+ "JeffreyXiang/TRELLIS-image-large",
42
+ device_map="auto",
43
+ low_cpu_mem_usage=True,
44
+ torch_dtype=torch.float16 # 반정밀도 사용
45
+ )
46
 
47
+ # 번역기 초기화 ( 작은 모델 사용)
48
  translator = translation_pipeline(
49
  "translation",
50
  model="Helsinki-NLP/opus-mt-ko-en",
51
+ device="cpu",
52
+ model_kwargs={
53
+ "low_cpu_mem_usage": True,
54
+ "torch_dtype": torch.float16
55
+ }
56
  )
57
 
58
+ # Flux 파이프라인 초기화 (메모리 최적화)
59
  flux_pipe = FluxPipeline.from_pretrained(
60
  "black-forest-labs/FLUX.1-dev",
61
+ device_map="auto",
62
+ low_cpu_mem_usage=True,
63
+ torch_dtype=torch.float16,
64
+ variant="fp16"
65
  )
66
 
67
+ # 불필요한 캐시 정리
68
+ free_memory()
69
+
70
  print("Models initialized successfully")
71
  return True
72
 
73
  except Exception as e:
74
  print(f"Model initialization error: {str(e)}")
75
+ free_memory()
76
  return False
77
 
78
  def translate_if_korean(text):
 
144
  def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
145
  ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int):
146
  try:
147
+ free_memory()
148
+
149
  if randomize_seed:
150
  seed = np.random.randint(0, MAX_SEED)
151
 
152
  input_image = Image.open(f"{TMP_DIR}/{trial_id}.png")
153
 
154
+ # GPU 메모리 사용량 제한
155
+ torch.cuda.set_per_process_memory_fraction(0.6)
 
 
156
 
157
+ # 더 작은 이미지 크기 사용
158
+ max_size = 512
159
+ if max(input_image.size) > max_size:
160
+ ratio = max_size / max(input_image.size)
161
+ input_image = input_image.resize(
162
+ (int(input_image.size[0] * ratio),
163
+ int(input_image.size[1] * ratio)),
164
+ Image.LANCZOS
 
 
 
 
 
 
165
  )
166
+
167
+ with torch.cuda.amp.autocast():
168
+ with torch.no_grad():
169
+ outputs = pipeline.run(
170
+ input_image,
171
+ seed=seed,
172
+ formats=["gaussian", "mesh"],
173
+ preprocess_image=False,
174
+ sparse_structure_sampler_params={
175
+ "steps": min(ss_sampling_steps, 15),
176
+ "cfg_strength": ss_guidance_strength,
177
+ },
178
+ slat_sampler_params={
179
+ "steps": min(slat_sampling_steps, 15),
180
+ "cfg_strength": slat_guidance_strength,
181
+ }
182
+ )
183
+
184
+ # 더 적은 프레임으로 비디오 생성
185
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=30)['color']
186
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=30)['normal']
187
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
188
 
189
  trial_id = str(uuid.uuid4())
 
193
 
194
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
195
 
196
+ free_memory()
 
 
197
  return state, video_path
198
 
199
  except Exception as e:
200
  print(f"Error in image_to_3d: {str(e)}")
201
+ free_memory()
202
  raise e
203
 
204
  @spaces.GPU
 
256
  visibility: hidden;
257
  }
258
  """
259
+ def free_memory():
260
+ """메모리를 정리하는 강화된 유틸리티 함수"""
261
+ import gc
262
+ import psutil
263
+
264
+ # Python 가비지 컬렉션 강제 실행
265
+ gc.collect()
266
+
267
+ # CUDA 메모리 정리
268
+ if torch.cuda.is_available():
269
+ torch.cuda.empty_cache()
270
+ torch.cuda.synchronize()
271
+
272
+ # RAM 캐시 정리 시도
273
+ if psutil.POSIX:
274
+ import os
275
+ os.system('sync')
276
  # Gradio 인터페이스 정의
277
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
278
  gr.Markdown("""
 
390
  )
391
 
392
  if __name__ == "__main__":
393
+ free_memory()
394
+
395
  # 모델 초기화
396
  if not initialize_models():
397
  print("Failed to initialize models")
398
  exit(1)
399
 
400
  try:
401
+ # 최소 크기 이미지로 rembg 테스트
402
+ test_image = Image.fromarray(np.ones((64, 64, 3), dtype=np.uint8) * 255)
403
  pipeline.preprocess_image(test_image)
404
  except Exception as e:
405
  print(f"Warning: Failed to preload rembg: {str(e)}")
406
 
407
  # Gradio 앱 실행
408
+ demo.queue(max_size=5).launch(
409
  share=True,
410
+ max_threads=2,
411
+ show_error=True,
412
+ cache_examples=False,
413
+ enable_queue=True,
414
+ server_port=7860,
415
+ server_name="0.0.0.0"
416
  )