aiqtech committed
Commit ee210e2
1 Parent(s): c260a18

Update app.py

Files changed (1): app.py (+91 -90)

app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import spaces
 from gradio_litmodel3d import LitModel3D
-
 import os
 os.environ['SPCONV_ALGO'] = 'native'
 from typing import *
@@ -14,12 +13,13 @@ from PIL import Image
 from trellis.pipelines import TrellisImageTo3DPipeline
 from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils
-
-# added below the existing imports
 from transformers import pipeline as translation_pipeline
 from diffusers import FluxPipeline
 
-# added to the initialization section
+MAX_SEED = np.iinfo(np.int32).max
+TMP_DIR = "/tmp/Trellis-demo"
+os.makedirs(TMP_DIR, exist_ok=True)
+
 def initialize_models():
     global pipeline, translator, flux_pipe
 
@@ -35,30 +35,19 @@ def initialize_models():
     flux_pipe.load_lora_weights("gokaygokay/Flux-Game-Assets-LoRA-v2")
     flux_pipe.fuse_lora(lora_scale=1.0)
     flux_pipe.to(device="cuda", dtype=torch.bfloat16)
-
-    MAX_SEED = np.iinfo(np.int32).max
-    TMP_DIR = "/tmp/Trellis-demo"
-
-    os.makedirs(TMP_DIR, exist_ok=True)
 
+def translate_if_korean(text):
+    if any(ord('가') <= ord(char) <= ord('힣') for char in text):
+        translated = translator(text)[0]['translation_text']
+        return translated
+    return text
 
 def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
-    """
-    Preprocess the input image.
-
-    Args:
-        image (Image.Image): The input image.
-
-    Returns:
-        str: uuid of the trial.
-        Image.Image: The preprocessed image.
-    """
     trial_id = str(uuid.uuid4())
     processed_image = pipeline.preprocess_image(image)
     processed_image.save(f"{TMP_DIR}/{trial_id}.png")
     return trial_id, processed_image
 
-
 def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
     return {
         'gaussian': {
@@ -75,8 +64,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
         },
         'trial_id': trial_id,
     }
-
-
+
+
 def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     gs = Gaussian(
         aabb=state['gaussian']['aabb'],
@@ -99,25 +88,8 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 
     return gs, mesh, state['trial_id']
 
-
 @spaces.GPU
 def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
-    """
-    Convert an image to a 3D model.
-
-    Args:
-        trial_id (str): The uuid of the trial.
-        seed (int): The random seed.
-        randomize_seed (bool): Whether to randomize the seed.
-        ss_guidance_strength (float): The guidance strength for sparse structure generation.
-        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
-        slat_guidance_strength (float): The guidance strength for structured latent generation.
-        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
-
-    Returns:
-        dict: The information of the generated 3D model.
-        str: The path to the video of the 3D model.
-    """
     if randomize_seed:
         seed = np.random.randint(0, MAX_SEED)
     outputs = pipeline.run(
@@ -144,75 +116,98 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre
     state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
     return state, video_path
 
+@spaces.GPU
+def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
+    translated_prompt = translate_if_korean(prompt)
+
+    with torch.inference_mode():
+        image = flux_pipe(
+            prompt=[translated_prompt],
+            height=height,
+            width=width,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_steps
+        ).images[0]
+
+    return image
 
 @spaces.GPU
 def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
-    """
-    Extract a GLB file from the 3D model.
-
-    Args:
-        state (dict): The state of the generated 3D model.
-        mesh_simplify (float): The mesh simplification factor.
-        texture_size (int): The texture resolution.
-
-    Returns:
-        str: The path to the extracted GLB file.
-    """
     gs, mesh, trial_id = unpack_state(state)
     glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
     glb_path = f"{TMP_DIR}/{trial_id}.glb"
     glb.export(glb_path)
     return glb_path, glb_path
 
-
 def activate_button() -> gr.Button:
     return gr.Button(interactive=True)
 
-
 def deactivate_button() -> gr.Button:
     return gr.Button(interactive=False)
 
 
 with gr.Blocks() as demo:
     gr.Markdown("""
-    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
-    * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
-    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
+    # 3D Asset Creation & Text-to-Image Generation
     """)
 
-    with gr.Row():
-        with gr.Column():
-            image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
-
-            with gr.Accordion(label="Generation Settings", open=False):
-                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
-                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                gr.Markdown("Stage 1: Sparse Structure Generation")
-                with gr.Row():
-                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
-                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-                gr.Markdown("Stage 2: Structured Latent Generation")
-                with gr.Row():
-                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
-                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-
-            generate_btn = gr.Button("Generate")
-
-            with gr.Accordion(label="GLB Extraction Settings", open=False):
-                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
-                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
-
-            extract_glb_btn = gr.Button("Extract GLB", interactive=False)
-
-        with gr.Column():
-            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
-            model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
-            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
-
+    with gr.Tabs():
+        with gr.TabItem("Image to 3D"):
+            with gr.Row():
+                with gr.Column():
+                    image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
+
+                    with gr.Accordion(label="Generation Settings", open=False):
+                        seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                        gr.Markdown("Stage 1: Sparse Structure Generation")
+                        with gr.Row():
+                            ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                            ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                        gr.Markdown("Stage 2: Structured Latent Generation")
+                        with gr.Row():
+                            slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
+                            slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+
+                    generate_btn = gr.Button("Generate")
+
+                    with gr.Accordion(label="GLB Extraction Settings", open=False):
+                        mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                        texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+
+                    extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+
+                with gr.Column():
+                    video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+                    model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
+                    download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+
+        with gr.TabItem("Text to Image"):
+            with gr.Row():
+                with gr.Column():
+                    text_prompt = gr.Textbox(
+                        label="Text Prompt",
+                        placeholder="Enter your image description...",
+                        lines=3
+                    )
+
+                    with gr.Row():
+                        txt2img_height = gr.Slider(256, 1024, value=512, step=64, label="Height")
+                        txt2img_width = gr.Slider(256, 1024, value=512, step=64, label="Width")
+
+                    with gr.Row():
+                        guidance_scale = gr.Slider(1.0, 20.0, value=7.5, label="Guidance Scale")
+                        num_steps = gr.Slider(1, 50, value=20, label="Number of Steps")
+
+                    generate_txt2img_btn = gr.Button("Generate Image")
+
+                with gr.Column():
+                    txt2img_output = gr.Image(label="Generated Image")
+
     trial_id = gr.Textbox(visible=False)
     output_buf = gr.State()
 
-    # Example images at the bottom of the page
+    # Example images
    with gr.Row():
        examples = gr.Examples(
            examples=[
@@ -226,12 +221,13 @@ with gr.Blocks() as demo:
             examples_per_page=64,
         )
 
-    # Handlers
+    # Handlers
     image_prompt.upload(
         preprocess_image,
         inputs=[image_prompt],
         outputs=[trial_id, image_prompt],
     )
+
     image_prompt.clear(
         lambda: '',
         outputs=[trial_id],
@@ -264,14 +260,19 @@ with gr.Blocks() as demo:
         deactivate_button,
         outputs=[download_glb],
     )
-
+
+    # Text-to-Image handler
+    generate_txt2img_btn.click(
+        generate_image_from_text,
+        inputs=[text_prompt, txt2img_height, txt2img_width, guidance_scale, num_steps],
+        outputs=[txt2img_output]
+    )
 
 # Launch the Gradio app
 if __name__ == "__main__":
-    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
-    pipeline.cuda()
+    initialize_models()  # initialize all models
     try:
         pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))  # Preload rembg
     except:
         pass
-    demo.launch()
+    demo.launch()
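The new translate_if_korean helper gates the translator on a scan for precomposed Hangul syllables. A minimal sketch of that detection logic in isolation, runnable without loading any model (the contains_korean name is ours, for illustration; it is not in the commit):

def contains_korean(text: str) -> bool:
    # Same range check the commit adds: U+AC00 ('가') through U+D7A3 ('힣'),
    # the precomposed Hangul syllables block.
    return any(ord('가') <= ord(char) <= ord('힣') for char in text)

assert contains_korean("파란 용")             # Korean prompt -> would be translated
assert not contains_korean("a blue dragon")  # ASCII prompt -> passed through unchanged

Jamo-only strings (U+1100-U+11FF) fall outside this range, which is acceptable for ordinary Korean prompts.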
 
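The diff never shows the full body of initialize_models(): only the global declaration and the three flux_pipe LoRA lines survive as context, and the TRELLIS setup appears only in the __main__ block this commit removes. A hypothetical reconstruction consistent with those fragments; the translator checkpoint and the FLUX base checkpoint are assumptions not shown anywhere in the commit:

# Hypothetical sketch; relies on the imports already present in app.py.
def initialize_models():
    global pipeline, translator, flux_pipe

    # Shown in the removed __main__ block, presumably relocated here
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()

    # Assumed checkpoint; any ko->en translation model would fit translate_if_korean
    translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

    # Assumed base model; the three lines below appear verbatim in the @@ -35,30 +35,19 @@ hunk
    flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    flux_pipe.load_lora_weights("gokaygokay/Flux-Game-Assets-LoRA-v2")
    flux_pipe.fuse_lora(lora_scale=1.0)
    flux_pipe.to(device="cuda", dtype=torch.bfloat16)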