elioonpc committed
Commit 8e2da05 · verified · 1 Parent(s): 0b6b43b

Update app.py

Files changed (1): app.py +191 -1
app.py CHANGED
@@ -188,4 +188,194 @@ with gr.Blocks() as demo:
     )
 
 demo.queue(max_size=10)
-demo.launch()
+import logging
+import os
+import shlex
+import subprocess
+import tempfile
+import time
+
+import gradio as gr
+import numpy as np
+import rembg
+import spaces
+import torch
+from PIL import Image
+from functools import partial
+
+subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
+
+from tsr.system import TSR
+from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
+
+
+HEADER = """
+# TripoSR Demo
+<table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
+  <tr style="height:50px;">
+    <td style="text-align: center;">
+      <a href="https://stability.ai">
+        <img src="https://images.squarespace-cdn.com/content/v1/6213c340453c3f502425776e/6c9c4c25-5410-4547-bc26-dc621cdacb25/Stability+AI+logo.png" width="200" height="40" />
+      </a>
+    </td>
+    <td style="text-align: center;">
+      <a href="https://www.tripo3d.ai">
+        <img src="https://www.tripo3d.ai/logo.png" width="170" height="40" />
+      </a>
+    </td>
+  </tr>
+</table>
+<table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
+  <tr style="height:30px;">
+    <td style="text-align: center;">
+      <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange" height="20"></a>
+    </td>
+    <td style="text-align: center;">
+      <a href="https://github.com/VAST-AI-Research/TripoSR"><img src="https://postimage.me/images/2024/03/04/GitHub_Logo_White.png" width="100" height="20"></a>
+    </td>
+    <td style="text-align: center; color: white;">
+      <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/arXiv-2403.02151-b31b1b.svg" height="20"></a>
+    </td>
+  </tr>
+</table>
+
+**TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
+
+**Tips:**
+1. If the result is unsatisfactory, try changing the foreground ratio; it may improve the output.
+2. It's better to disable "Remove Background" for the provided examples, since they have already been preprocessed.
+3. Otherwise, disable the "Remove Background" option only if your input image is RGBA with a transparent background and its content is centered and occupies more than 70% of the image width or height.
+"""
+
+
+if torch.cuda.is_available():
+    device = "cuda:0"
+else:
+    device = "cpu"
+
+model = TSR.from_pretrained(
+    "stabilityai/TripoSR",
+    config_name="config.yaml",
+    weight_name="model.ckpt",
+)
+model.renderer.set_chunk_size(131072)
+model.to(device)
+
+rembg_session = rembg.new_session()
+
+
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image uploaded!")
+
+
+def preprocess(input_image, do_remove_background, foreground_ratio):
+    def fill_background(image):
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+        image = Image.fromarray((image * 255.0).astype(np.uint8))
+        return image
+
+    if do_remove_background:
+        image = input_image.convert("RGB")
+        image = remove_background(image, rembg_session)
+        image = resize_foreground(image, foreground_ratio)
+        image = fill_background(image)
+    else:
+        image = input_image
+        if image.mode == "RGBA":
+            image = fill_background(image)
+    return image
+
+
+@spaces.GPU
+def generate(image, mc_resolution, formats=["obj", "glb"]):
+    scene_codes = model(image, device=device)
+    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
+    mesh = to_gradio_3d_orientation(mesh)
+
+    mesh_path_glb = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
+    mesh.export(mesh_path_glb.name)
+
+    mesh_path_obj = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
+    mesh.apply_scale([-1, 1, 1])  # Otherwise the visualized .obj will be flipped
+    mesh.export(mesh_path_obj.name)
+
+    return mesh_path_obj.name, mesh_path_glb.name
+
+def run_example(image_pil):
+    preprocessed = preprocess(image_pil, False, 0.9)
+    mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
+    return preprocessed, mesh_name_obj, mesh_name_glb
+
+with gr.Blocks() as demo:
+    gr.Markdown(HEADER)
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                input_image = gr.Image(
+                    label="Input Image",
+                    image_mode="RGBA",
+                    sources="upload",
+                    type="pil",
+                    elem_id="content_image",
+                )
+                processed_image = gr.Image(label="Processed Image", interactive=False)
+            with gr.Row():
+                with gr.Group():
+                    do_remove_background = gr.Checkbox(
+                        label="Remove Background", value=True
+                    )
+                    foreground_ratio = gr.Slider(
+                        label="Foreground Ratio",
+                        minimum=0.5,
+                        maximum=1.0,
+                        value=0.85,
+                        step=0.05,
+                    )
+                    mc_resolution = gr.Slider(
+                        label="Marching Cubes Resolution",
+                        minimum=32,
+                        maximum=320,
+                        value=256,
+                        step=32
+                    )
+            with gr.Row():
+                submit = gr.Button("Generate", elem_id="generate", variant="primary")
+        with gr.Column():
+            with gr.Tab("OBJ"):
+                output_model_obj = gr.Model3D(
+                    label="Output Model (OBJ Format)",
+                    interactive=False,
+                )
+                gr.Markdown("Note: The downloaded object will be flipped for .obj exports. Export .glb instead, or flip it manually before use.")
+            with gr.Tab("GLB"):
+                output_model_glb = gr.Model3D(
+                    label="Output Model (GLB Format)",
+                    interactive=False,
+                )
+                gr.Markdown("Note: The model shown here has a darker appearance. Download it to get the correct result.")
+    with gr.Row(variant="panel"):
+        gr.Examples(
+            examples=[
+                os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+            ],
+            inputs=[input_image],
+            outputs=[processed_image, output_model_obj, output_model_glb],
+            cache_examples=True,
+            fn=partial(run_example),
+            label="Examples",
+            examples_per_page=20
+        )
+    submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=preprocess,
+        inputs=[input_image, do_remove_background, foreground_ratio],
+        outputs=[processed_image],
+    ).success(
+        fn=generate,
+        inputs=[processed_image, mc_resolution],
+        outputs=[output_model_obj, output_model_glb],
+    )
+
+demo.queue(max_size=10)
+demo.launch(share=True)
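
For context on how the committed pieces fit together: the UI chains `check_input_image` → `preprocess` → `generate` through `submit.click(...).success(...)`. Below is a minimal headless sketch of the same pipeline. It assumes the functions above can be imported without launching the UI (e.g. the `demo.launch(...)` call is guarded by `if __name__ == "__main__":`, which this commit does not do) and uses a hypothetical example image path.

```python
# Minimal headless sketch of the preprocess -> generate pipeline defined above.
# Assumptions (not part of this commit): app.py is importable without
# auto-launching Gradio, and examples/chair.png exists.
from PIL import Image

from app import preprocess, generate  # hypothetical import of the functions above

input_image = Image.open("examples/chair.png").convert("RGBA")

# Same defaults as the UI: background removal on, foreground ratio 0.85,
# marching-cubes resolution 256.
processed = preprocess(input_image, do_remove_background=True, foreground_ratio=0.85)
obj_path, glb_path = generate(processed, mc_resolution=256)

print("OBJ mesh:", obj_path)
print("GLB mesh:", glb_path)
```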