jadechoghari commited on
Commit
7cbe23c
·
verified ·
1 Parent(s): bef5557

Create app.py

Browse files

TODO: make sure source_camero have the right shape and value
TODO: instead of outputting .obj file -> directly output a 3d model

Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import os
4
+ import numpy as np
5
+ import trimesh
6
+ import mcubes
7
+ from torchvision.utils import save_image
8
+ from PIL import Image
9
+ from transformers import AutoModel, AutoConfig
10
+ from rembg import remove, new_session
11
+ from functools import partial
12
+ from kiui.op import recenter
13
+ import kiui
14
+
15
+
16
+ # we load the pre-trained model from HF
17
+ class LRMGeneratorWrapper:
18
+ def __init__(self):
19
+ self.config = AutoConfig.from_pretrained("jadechoghari/custom-llrm", trust_remote_code=True)
20
+ self.model = AutoModel.from_pretrained("jadechoghari/custom-llrm", trust_remote_code=True)
21
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
+ self.model.to(self.device)
23
+ self.model.eval()
24
+
25
+ def forward(self, image, camera):
26
+ return self.model(image, camera)
27
+
28
+ model_wrapper = LRMGeneratorWrapper()
29
+
30
+
31
+ def preprocess_image(image, source_size):
32
+ session = new_session("isnet-general-use")
33
+ rembg_remove = partial(remove, session=session)
34
+ image = np.array(image)
35
+ image = rembg_remove(image)
36
+ mask = rembg_remove(image, only_mask=True)
37
+ image = recenter(image, mask, border_ratio=0.20)
38
+ image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
39
+ if image.shape[1] == 4:
40
+ image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
41
+ image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True)
42
+ image = torch.clamp(image, 0, 1)
43
+ return image
44
+
45
+ #Ref: https://github.com/jadechoghari/vfusion3d/blob/main/lrm/inferrer.py
46
+ def generate_mesh(image, source_size=512, render_size=384, mesh_size=512, export_mesh=True):
47
+ image = preprocess_image(image, source_size).to(model_wrapper.device)
48
+
49
+ # TODO: make sure source_camero have the right shape and value
50
+ source_camera = torch.tensor([[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]], dtype=torch.float32).to(model_wrapper.device)
51
+
52
+ render_camera = torch.tensor([[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]], dtype=torch.float32).to(model_wrapper.device)
53
+
54
+ with torch.no_grad():
55
+ planes = model_wrapper.forward(image, source_camera)
56
+
57
+ if export_mesh:
58
+ grid_out = model_wrapper.model.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
59
+
60
+ vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
61
+ vtx = vtx / (mesh_size - 1) * 2 - 1
62
+ vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=model_wrapper.device).unsqueeze(0)
63
+ vtx_colors = model_wrapper.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
64
+
65
+ vtx_colors = (vtx_colors * 255).astype(np.uint8)
66
+ mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
67
+
68
+ mesh_path = "awesome_mesh.obj"
69
+ mesh.export(mesh_path, 'obj')
70
+ return mesh_path
71
+
72
+ # TODO: instead of outputting .obj file -> directly output a 3d model
73
+ def gradio_interface(image):
74
+ mesh_file = generate_mesh(image)
75
+ print("Generated Mesh File Path:", mesh_file)
76
+ return mesh_file
77
+
78
+
79
+ gr.Interface(
80
+ fn=gradio_interface,
81
+ inputs=gr.Image(type="pil", label="Input Image"),
82
+ outputs=gr.File(label="Awesome 3D Mesh (.obj)"),
83
+ title="3D Mesh Generator by FacebookAI",
84
+ description="Upload an image and generate a 3D mesh (.obj) file using VFusion3D by FacebookAI"
85
+ ).launch()