File size: 12,591 Bytes
db6a3b7
3057b36
7d475c1
db6a3b7
690b53e
db6a3b7
9880f3d
7d475c1
db6a3b7
 
9880f3d
db6a3b7
 
9880f3d
db6a3b7
f4648fc
 
db6a3b7
ee210e2
 
 
 
f4648fc
 
 
868eab9
5201a38
 
 
 
 
 
868eab9
5201a38
 
868eab9
5201a38
868eab9
 
 
5201a38
868eab9
 
5201a38
868eab9
 
5201a38
868eab9
 
5201a38
 
 
 
868eab9
 
5201a38
 
 
d7b1815
ee210e2
 
 
 
 
bd46f72
a898014
 
db894f7
a898014
 
db6a3b7
a898014
9880f3d
 
 
 
 
 
 
 
 
 
 
 
 
a898014
9880f3d
ee210e2
 
9880f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a898014
9880f3d
3057b36
5201a38
868eab9
 
5201a38
 
868eab9
 
 
 
 
5201a38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868eab9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5201a38
868eab9
db6a3b7
ee210e2
 
8fb8605
 
 
 
ee210e2
 
8fb8605
 
 
ee210e2
 
8fb8605
ee210e2
 
 
 
 
 
 
db6a3b7
3057b36
9880f3d
a898014
690b53e
a898014
db6a3b7
 
 
 
 
 
 
 
 
 
868eab9
 
 
 
 
 
 
 
7d475c1
868eab9
7d475c1
 
ee210e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a898014
2e78ab8
db6a3b7
ee210e2
db6a3b7
 
 
 
 
 
 
2e7f188
a898014
db6a3b7
 
 
 
ee210e2
db6a3b7
 
 
a898014
 
ee210e2
a898014
 
 
db6a3b7
 
 
 
a898014
2e78ab8
db6a3b7
 
 
 
 
 
 
 
 
 
 
 
2e78ab8
db6a3b7
 
 
 
 
 
 
 
 
 
ee210e2
 
 
 
 
 
 
db6a3b7
 
5201a38
 
 
868eab9
5201a38
 
 
 
868eab9
c666caf
5201a38
 
 
 
 
868eab9
 
5201a38
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D
import os
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
import uuid
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
from transformers import pipeline as translation_pipeline
from diffusers import FluxPipeline

# Upper bound for seeds: largest int32. Used as the seed slider's maximum and as
# the exclusive upper bound when a random seed is drawn.
MAX_SEED = np.iinfo(np.int32).max
# Scratch directory for preprocessed input images (.png), rendered preview
# videos (.mp4) and exported GLB files; created at import time.
TMP_DIR = "/tmp/Trellis-demo"
os.makedirs(TMP_DIR, exist_ok=True)

def initialize_models():
    """Load every model the app uses into module-level globals.

    Sets three globals:
      * ``pipeline``   - the TRELLIS image-to-3D pipeline, moved to the device.
      * ``translator`` - the ``Helsinki-NLP/opus-mt-ko-en`` translation pipeline
                         used to translate Korean prompts.
      * ``flux_pipe``  - the FLUX.1-dev text-to-image pipeline.

    Returns:
        True if all models loaded. False if any step raised; the error is
        printed and the CUDA cache is emptied.
    """
    global pipeline, translator, flux_pipe

    try:
        # Start from an empty CUDA allocator cache.
        torch.cuda.empty_cache()

        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"

        pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
        pipeline.to(device)

        # transformers convention: device index 0 = first GPU, -1 = CPU.
        translator = translation_pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device=0 if use_cuda else -1,
        )

        flux_dtype = torch.float16 if use_cuda else torch.float32
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=flux_dtype,
        )
        if use_cuda:
            # Let diffusers move sub-models to the GPU only while they run.
            flux_pipe.enable_model_cpu_offload()
    except Exception as e:
        print(f"Model initialization error: {str(e)}")
        torch.cuda.empty_cache()
        return False

    return True

def translate_if_korean(text):
    """Return *text* in English if it contains Hangul syllables, else unchanged.

    Only the precomposed Hangul syllable block ('가' U+AC00 .. '힣' U+D7A3)
    triggers translation; bare jamo such as 'ㅋ' do not. Translation goes
    through the module-level ``translator`` pipeline.
    """
    contains_hangul = any('가' <= ch <= '힣' for ch in text)
    if not contains_hangul:
        return text
    return translator(text)[0]['translation_text']

def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
    """Preprocess *image* with the TRELLIS pipeline and cache it on disk.

    Returns:
        ``(trial_id, processed_image)``. ``trial_id`` is a fresh UUID4 string,
        and the processed image is saved as ``{TMP_DIR}/{trial_id}.png`` so
        later steps can reload it by id.
    """
    new_id = str(uuid.uuid4())
    result = pipeline.preprocess_image(image)
    result.save(f"{TMP_DIR}/{new_id}.png")
    return new_id, result

def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
    """Serialize a Gaussian and its mesh into a plain dict of NumPy arrays.

    The dict is what ``image_to_3d`` returns into the ``gr.State`` buffer, and
    ``unpack_state`` rebuilds the tensors from it. Every tensor is copied to
    host memory.

    Returns:
        ``{'gaussian': {...init_params, '_xyz', '_features_dc', '_scaling',
        '_rotation', '_opacity'}, 'mesh': {'vertices', 'faces'},
        'trial_id': trial_id}``
    """
    def to_np(tensor):
        return tensor.cpu().numpy()

    # Constructor params first, then the learned per-point tensors.
    gaussian_fields = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_fields[attr] = to_np(getattr(gs, attr))

    return {
        'gaussian': gaussian_fields,
        'mesh': {
            'vertices': to_np(mesh.vertices),
            'faces': to_np(mesh.faces),
        },
        'trial_id': trial_id,
    }


def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
    """Rebuild CUDA tensors from a dict produced by ``pack_state``.

    Returns:
        ``(gaussian, mesh, trial_id)``:
        - ``gaussian``: a ``Gaussian`` rebuilt from the stored constructor
          params, with its parameter tensors on ``'cuda'``;
        - ``mesh``: an ``edict`` holding ``vertices``/``faces`` tensors on
          ``'cuda'``;
        - ``trial_id``: the stored trial id.
    """
    g = state['gaussian']
    gs = Gaussian(
        aabb=g['aabb'],
        sh_degree=g['sh_degree'],
        # NOTE(review): spelling presumably matches the Gaussian constructor
        # keyword - do not "fix" without checking the TRELLIS API.
        mininum_kernel_size=g['mininum_kernel_size'],
        scaling_bias=g['scaling_bias'],
        opacity_bias=g['opacity_bias'],
        scaling_activation=g['scaling_activation'],
    )
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(g[attr], device='cuda'))

    m = state['mesh']
    mesh = edict(
        vertices=torch.tensor(m['vertices'], device='cuda'),
        faces=torch.tensor(m['faces'], device='cuda'),
    )

    return gs, mesh, state['trial_id']

@spaces.GPU
def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
                ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int):
    """Generate a 3D asset from the preprocessed image stored under *trial_id*.

    Args:
        trial_id: id of ``{TMP_DIR}/{trial_id}.png`` written by ``preprocess_image``.
        seed: sampling seed; replaced by a random value in ``[0, MAX_SEED)``
            when *randomize_seed* is true.
        randomize_seed: whether to ignore *seed* and draw a random one.
        ss_guidance_strength, ss_sampling_steps: cfg strength / steps for the
            sparse-structure sampler (stage 1).
        slat_guidance_strength, slat_sampling_steps: cfg strength / steps for
            the structured-latent sampler (stage 2).

    Returns:
        ``(state, video_path)``:
        - ``state``: the ``pack_state`` dict for later GLB extraction, keyed by
          a fresh trial id;
        - ``video_path``: an MP4 of 120 frames at 15 fps showing the Gaussian
          colour render next to the mesh normal render.

    Raises:
        gr.Error: if no image has been uploaded (empty *trial_id*).
        Exception: any pipeline/rendering error is printed and re-raised.
    """
    if not trial_id:
        # The clear handler resets trial_id to ''; fail with a readable
        # message instead of a FileNotFoundError for "<TMP_DIR>/.png".
        raise gr.Error("Please upload an image first.")

    try:
        torch.cuda.empty_cache()

        if randomize_seed:
            seed = np.random.randint(0, MAX_SEED)

        input_image = Image.open(f"{TMP_DIR}/{trial_id}.png")
        # Read pixel data now; Pillow then closes the file it opened itself.
        input_image.load()

        with torch.autocast("cuda", enabled=torch.cuda.is_available()):
            with torch.no_grad():
                outputs = pipeline.run(
                    input_image,
                    seed=seed,
                    formats=["gaussian", "mesh"],
                    preprocess_image=False,  # already done in preprocess_image()
                    sparse_structure_sampler_params={
                        "steps": ss_sampling_steps,
                        "cfg_strength": ss_guidance_strength,
                    },
                    slat_sampler_params={
                        "steps": slat_sampling_steps,
                        "cfg_strength": slat_guidance_strength,
                    },
                )

        # Preview video: colour render of the Gaussian beside the mesh normals
        # (concatenated on axis 1 - presumably width for HxWxC frames).
        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
        video = [np.concatenate([color, normal], axis=1) for color, normal in zip(video, video_geo)]

        # Fresh id so the exported GLB does not collide with earlier runs.
        trial_id = str(uuid.uuid4())
        video_path = f"{TMP_DIR}/{trial_id}.mp4"
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        imageio.mimsave(video_path, video, fps=15)

        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return state, video_path

    except Exception as e:
        print(f"Error in image_to_3d: {str(e)}")
        torch.cuda.empty_cache()
        raise

@spaces.GPU
def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
    """Generate an image with FLUX from a (possibly Korean) text prompt.

    Korean prompts are translated to English first. The fixed style suffix
    ``"wbgmsst, 3D, white background"`` is then appended.

    Args:
        prompt: user prompt text.
        height, width: output size in pixels.
        guidance_scale: guidance scale passed to FLUX.
        num_steps: number of inference steps.

    Numeric arguments come straight from Gradio sliders. A slider without an
    explicit ``step`` (``num_steps``) may deliver non-integer values, and the
    pipeline needs ints for size and step count, so they are rounded here.

    Returns:
        The first image of ``flux_pipe(...).images``.
    """
    base_prompt = "wbgmsst, 3D, white background"
    translated_prompt = translate_if_korean(prompt)
    final_prompt = f"{translated_prompt}, {base_prompt}"

    with torch.inference_mode():
        result = flux_pipe(
            prompt=[final_prompt],
            height=int(round(height)),
            width=int(round(width)),
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(round(num_steps)),
        )

    return result.images[0]

@spaces.GPU
def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
    """Export the asset held in *state* to ``{TMP_DIR}/{trial_id}.glb``.

    Returns the same path twice: one copy feeds the 3D viewer and the other
    the download button.
    """
    gaussian, mesh, run_id = unpack_state(state)
    out_path = f"{TMP_DIR}/{run_id}.glb"
    glb = postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    )
    glb.export(out_path)
    return (out_path,) * 2

def _button_with_interactivity(enabled: bool) -> gr.Button:
    """Build a Button update that only sets ``interactive``."""
    return gr.Button(interactive=enabled)

def activate_button() -> gr.Button:
    """Event handler: make the target button clickable."""
    return _button_with_interactivity(True)

def deactivate_button() -> gr.Button:
    """Event handler: disable the target button."""
    return _button_with_interactivity(False)


# Extra CSS passed to gr.Blocks: hides the page's <footer> element.
css = """
footer {
    visibility: hidden;
}
"""


with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    gr.Markdown("""
    # Craft3D : 3D Asset Creation & Text-to-Image Generation
    """)

    with gr.Tabs():
        # Tab 1: upload an image -> generate a 3D asset -> extract a GLB.
        with gr.TabItem("Image to 3D"):
            with gr.Row():
                with gr.Column():
                    image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)

                    with gr.Accordion(label="Generation Settings", open=False):
                        seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                        gr.Markdown("Stage 1: Sparse Structure Generation")
                        with gr.Row():
                            ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                            ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                        gr.Markdown("Stage 2: Structured Latent Generation")
                        with gr.Row():
                            slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                            slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)

                    generate_btn = gr.Button("Generate")

                    with gr.Accordion(label="GLB Extraction Settings", open=False):
                        mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                        texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

                    # Disabled until a generation has succeeded.
                    extract_glb_btn = gr.Button("Extract GLB", interactive=False)

                with gr.Column():
                    video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
                    model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
                    # Disabled until a GLB extraction has succeeded.
                    download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

        # Tab 2: text prompt -> FLUX image.
        with gr.TabItem("Text to Image"):
            with gr.Row():
                with gr.Column():
                    text_prompt = gr.Textbox(
                        label="Text Prompt",
                        placeholder="Enter your image description...",
                        lines=3
                    )

                    with gr.Row():
                        txt2img_height = gr.Slider(256, 1024, value=512, step=64, label="Height")
                        txt2img_width = gr.Slider(256, 1024, value=512, step=64, label="Width")

                    with gr.Row():
                        guidance_scale = gr.Slider(1.0, 20.0, value=7.5, label="Guidance Scale")
                        num_steps = gr.Slider(1, 50, value=20, label="Number of Steps")

                    generate_txt2img_btn = gr.Button("Generate Image")

                with gr.Column():
                    txt2img_output = gr.Image(label="Generated Image")

    # Hidden id of the preprocessed image on disk; '' when no image is loaded.
    trial_id = gr.Textbox(visible=False)
    # Holds the pack_state() dict between generation and GLB extraction.
    output_buf = gr.State()

    # Example images, sorted so their order does not depend on the filesystem.
    with gr.Row():
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in sorted(os.listdir("assets/example_image"))
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[trial_id, image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )

    # Handlers
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[trial_id, image_prompt],
    )

    image_prompt.clear(
        lambda: '',
        outputs=[trial_id],
    )

    # .success (not .then) so a failed generation does not enable
    # "Extract GLB" against a missing or stale state.
    generate_btn.click(
        image_to_3d,
        inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).success(
        activate_button,
        outputs=[extract_glb_btn],
    )

    video_output.clear(
        deactivate_button,
        outputs=[extract_glb_btn],
    )

    # Same reasoning: only enable the download after a successful export.
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).success(
        activate_button,
        outputs=[download_glb],
    )

    model_output.clear(
        deactivate_button,
        outputs=[download_glb],
    )

    # Text-to-image handler
    generate_txt2img_btn.click(
        generate_image_from_text,
        inputs=[text_prompt, txt2img_height, txt2img_width, guidance_scale, num_steps],
        outputs=[txt2img_output]
    )

if __name__ == "__main__":
    # Release any cached GPU memory before loading the models.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # The event handlers read the globals set by initialize_models(), so the
    # UI is not started if loading failed.
    if not initialize_models():
        print("Failed to initialize models")
        # SystemExit rather than the site-provided exit() builtin, which is
        # not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1)

    try:
        # Run preprocessing once on a blank image to preload rembg before the
        # first user request.
        test_image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
        pipeline.preprocess_image(test_image)
    except Exception as e:
        # Warm-up failure is non-fatal; startup continues.
        print(f"Warning: Failed to preload rembg: {str(e)}")

    # Gradio 4 removed queue(concurrency_count=...) and launch(enable_queue=...)
    # (both raise at startup). default_concurrency_limit=1 keeps the original
    # one-job-at-a-time behavior; the queue is enabled by calling queue().
    demo.queue(default_concurrency_limit=1).launch(
        share=True,
        max_threads=1,
    )