abreza committed on
Commit
375ee53
1 Parent(s): baad75d

restyle launch codes

app.py CHANGED
@@ -1,319 +1,16 @@
-import os
-import shutil
-import tempfile
-
 import gradio as gr
-import numpy as np
-import rembg
-import spaces
-import torch
-from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline, EulerDiscreteScheduler
-from einops import rearrange
-from huggingface_hub import hf_hub_download
-from omegaconf import OmegaConf
-from PIL import Image
-from pytorch_lightning import seed_everything
-from torchvision.transforms import v2
-from safetensors.torch import load_file
-
-from src.utils.camera_util import (FOV_to_intrinsics, get_circular_camera_poses,
-                                   get_zero123plus_input_cameras)
-from src.utils.infer_util import (remove_background, resize_foreground)
-from src.utils.mesh_util import save_glb, save_obj
-from src.utils.train_util import instantiate_from_config
-
-
-def find_cuda():
-    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
-    if cuda_home and os.path.exists(cuda_home):
-        return cuda_home
-
-    nvcc_path = shutil.which('nvcc')
-    if nvcc_path:
-        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
-        return cuda_path
-
-    return None
-
-
-def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
-    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
-    if is_flexicubes:
-        cameras = torch.linalg.inv(c2ws)
-        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
-    else:
-        extrinsics = c2ws.flatten(-2)
-        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(
-            0).repeat(M, 1, 1).float().flatten(-2)
-        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
-        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
-    return cameras
-
-
-def check_input_image(input_image):
-    if input_image is None:
-        raise gr.Error("No image selected!")
-
-
-def preprocess(input_image):
-    rembg_session = rembg.new_session()
-
-    input_image = remove_background(input_image, rembg_session)
-    input_image = resize_foreground(input_image, 0.85)
-
-    return input_image
-
-
-def generate_prompt(subject, style, color_scheme, angle, lighting_type, additional_details):
-    prompt = f"A 3D cartoon render of {subject}, featuring the entire body and shape, on a transparent background. The style should be {style}, with {color_scheme} colors, emphasizing the essential features and lines. The pose should clearly showcase the full form of the {subject} from a {angle} perspective. Lighting is {lighting_type}, highlighting the volume and depth of the subject. {additional_details}. Output as a high-resolution PNG with no background."
-    return prompt
-
-
-@spaces.GPU
-def generate_image(subject, style, color_scheme, angle, lighting_type, additional_details):
-    checkpoint = "sdxl_lightning_8step_unet.safetensors"
-    num_inference_steps = 8
-
-    pipe.scheduler = EulerDiscreteScheduler.from_config(
-        pipe.scheduler.config, timestep_spacing="trailing")
-    pipe.unet.load_state_dict(
-        load_file(hf_hub_download(repo, checkpoint), device="cuda"))
-
-    prompt = generate_prompt(subject, style, color_scheme,
-                             angle, lighting_type, additional_details)
-    results = pipe(
-        prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
-    return results.images[0]
-
-
-@spaces.GPU
-def generate_mvs(input_image, sample_steps, sample_seed):
-    seed_everything(sample_seed)
-
-    z123_image = pipeline(
-        input_image, num_inference_steps=sample_steps).images[0]
-
-    show_image = np.asarray(z123_image, dtype=np.uint8)
-    show_image = torch.from_numpy(show_image)
-    show_image = rearrange(
-        show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
-    show_image = rearrange(
-        show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
-    show_image = Image.fromarray(show_image.numpy())
-
-    return z123_image, show_image
-
-
-@spaces.GPU
-def make3d(images):
-    global model
-    if IS_FLEXICUBES:
-        model.init_flexicubes_geometry(device, use_renderer=False)
-    model = model.eval()
-
-    images = np.asarray(images, dtype=np.float32) / 255.0
-    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
-    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
-
-    input_cameras = get_zero123plus_input_cameras(
-        batch_size=1, radius=4.0).to(device)
-    render_cameras = get_render_cameras(
-        batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
-
-    images = images.unsqueeze(0).to(device)
-    images = v2.functional.resize(
-        images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
-
-    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
-    print(mesh_fpath)
-    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
-    mesh_dirname = os.path.dirname(mesh_fpath)
-    mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
-
-    with torch.no_grad():
-        planes = model.forward_planes(images, input_cameras)
-        mesh_out = model.extract_mesh(
-            planes, use_texture_map=False, **infer_config)
-
-        vertices, faces, vertex_colors = mesh_out
-        vertices = vertices[:, [1, 2, 0]]
-
-        save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
-        save_obj(vertices, faces, vertex_colors, mesh_fpath)
-
-        print(f"Mesh saved to {mesh_fpath}")
-
-    return mesh_fpath, mesh_glb_fpath
-
-
-# Configuration
-cuda_path = find_cuda()
-config_path = 'configs/instant-mesh-large.yaml'
-config = OmegaConf.load(config_path)
-config_name = os.path.basename(config_path).replace('.yaml', '')
-model_config = config.model_config
-infer_config = config.infer_config
-
-IS_FLEXICUBES = config_name.startswith('instant-mesh')
-device = torch.device('cuda')
-
-# Load diffusion model
-print('Loading diffusion model ...')
-pipeline = DiffusionPipeline.from_pretrained(
-    "sudo-ai/zero123plus-v1.2",
-    custom_pipeline="zero123plus",
-    torch_dtype=torch.float16,
-)
-pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
-    pipeline.scheduler.config, timestep_spacing='trailing'
-)
-
-unet_ckpt_path = hf_hub_download(
-    repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
-state_dict = torch.load(unet_ckpt_path, map_location='cpu')
-pipeline.unet.load_state_dict(state_dict, strict=True)
-
-pipeline = pipeline.to(device)
-
-# Load reconstruction model
-print('Loading reconstruction model ...')
-model_ckpt_path = hf_hub_download(
-    repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
-model = instantiate_from_config(model_config)
-state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
-state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith(
-    'lrm_generator.') and 'source_camera' not in k}
-model.load_state_dict(state_dict, strict=True)
-
-model = model.to(device)
-
-# Load StableDiffusionXL model
-base = "stabilityai/stable-diffusion-xl-base-1.0"
-repo = "ByteDance/SDXL-Lightning"
-
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    base, torch_dtype=torch.float16, variant="fp16").to("cuda")
-
-print('Loading Finished!')
-
+
+from launch.image_generation import image_generation_ui
+from launch.model_generation import model_generation_ui
+
+
 with gr.Blocks() as demo:
     with gr.Group():
         with gr.Tab("Generate Image and Remove Background"):
-            with gr.Row():
-                subject = gr.Textbox(label='Subject', scale=2)
-                style = gr.Dropdown(
-                    label='Style',
-                    choices=['Pixar-like', 'Disney-esque', 'Anime-inspired'],
-                    value='Pixar-like',
-                    multiselect=False,
-                    scale=2
-                )
-                color_scheme = gr.Dropdown(
-                    label='Color Scheme',
-                    choices=['Vibrant', 'Pastel',
-                             'Monochromatic', 'Black and White'],
-                    value='Vibrant',
-                    multiselect=False,
-                    scale=2
-                )
-                angle = gr.Dropdown(
-                    label='Angle',
-                    choices=['Front', 'Side', 'Three-quarter'],
-                    value='Front',
-                    multiselect=False,
-                    scale=2
-                )
-                lighting_type = gr.Dropdown(
-                    label='Lighting Type',
-                    choices=['Bright and Even',
-                             'Dramatic Shadows', 'Soft and Warm'],
-                    value='Bright and Even',
-                    multiselect=False,
-                    scale=2
-                )
-                additional_details = gr.Textbox(
-                    label='Additional Details', scale=2)
-                submit_prompt = gr.Button(
-                    'Generate Image', scale=1, variant='primary')
-
-            with gr.Row(variant="panel"):
-                with gr.Column():
-                    with gr.Row():
-                        input_image = gr.Image(
-                            label="Input Image",
-                            image_mode="RGBA",
-                            sources="upload",
-                            type="pil",
-                            elem_id="content_image",
-                        )
-                        processed_image = gr.Image(
-                            label="Processed Image",
-                            image_mode="RGBA",
-                            type="pil",
-                            interactive=False
-                        )
-                    with gr.Row():
-                        submit_process = gr.Button(
-                            "Remove Background", elem_id="process", variant="primary")
-                    with gr.Row(variant="panel"):
-                        gr.Examples(
-                            examples=[os.path.join("examples", img_name) for img_name in sorted(
-                                os.listdir("examples"))],
-                            inputs=[input_image],
-                            label="Examples",
-                            cache_examples=False,
-                            examples_per_page=16
-                        )
+            input_image, processed_image = image_generation_ui()
 
         with gr.Tab("Generate 3D Model"):
-            with gr.Column():
-                with gr.Row():
-                    with gr.Column():
-                        mv_show_images = gr.Image(
-                            label="Generated Multi-views",
-                            type="pil",
-                            width=379,
-                            interactive=False
-                        )
-                    with gr.Row():
-                        with gr.Group():
-                            sample_seed = gr.Number(
-                                value=42, label="Seed Value", precision=0)
-                            sample_steps = gr.Slider(
-                                label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
-                with gr.Row():
-                    submit_mesh = gr.Button(
-                        "Generate 3D Model", elem_id="generate", variant="primary")
-                with gr.Row():
-                    with gr.Tab("OBJ"):
-                        output_model_obj = gr.Model3D(
-                            label="Output Model (OBJ Format)",
-                            interactive=False,
-                        )
-                        gr.Markdown(
-                            "Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
-                    with gr.Tab("GLB"):
-                        output_model_glb = gr.Model3D(
-                            label="Output Model (GLB Format)",
-                            interactive=False,
-                        )
-                        gr.Markdown(
-                            "Note: The model shown here has a darker appearance. Download to get correct results.")
-                with gr.Row():
-                    gr.Markdown(
-                        '''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
-
-            mv_images = gr.State()
-
-            submit_prompt.click(fn=generate_image, inputs=[subject, style, color_scheme, angle, lighting_type, additional_details], outputs=input_image).success(
-                fn=preprocess, inputs=[input_image], outputs=[processed_image]
-            )
-            submit_process.click(fn=check_input_image, inputs=[input_image]).success(
-                fn=preprocess, inputs=[input_image], outputs=[processed_image],
-            )
-            submit_mesh.click(fn=generate_mvs, inputs=[processed_image, sample_steps, sample_seed], outputs=[mv_images, mv_show_images]).success(
-                fn=make3d, inputs=[mv_images], outputs=[
-                    output_model_obj, output_model_glb]
-            )
+            output_model_obj, output_model_glb = model_generation_ui(
+                processed_image)
 
 demo.launch()
launch/__init__.py ADDED
File without changes
launch/image_generation.py ADDED
@@ -0,0 +1,127 @@
+import os
+
+import gradio as gr
+import rembg
+import spaces
+import torch
+from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+from src.utils.infer_util import (remove_background, resize_foreground)
+
+
+# Load StableDiffusionXL model
+base = "stabilityai/stable-diffusion-xl-base-1.0"
+repo = "ByteDance/SDXL-Lightning"
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+
+def generate_prompt(subject, style, color_scheme, angle, lighting_type, additional_details):
+    return f"A 3D cartoon render of {subject}, featuring the entire body and shape, on a transparent background. The style should be {style}, with {color_scheme} colors, emphasizing the essential features and lines. The pose should clearly showcase the full form of the {subject} from a {angle} perspective. Lighting is {lighting_type}, highlighting the volume and depth of the subject. {additional_details}. Output as a high-resolution PNG with no background."
+
+
+@spaces.GPU
+def generate_image(subject, style, color_scheme, angle, lighting_type, additional_details):
+    checkpoint = "sdxl_lightning_8step_unet.safetensors"
+    num_inference_steps = 8
+
+    pipe.scheduler = EulerDiscreteScheduler.from_config(
+        pipe.scheduler.config, timestep_spacing="trailing")
+    pipe.unet.load_state_dict(
+        load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+
+    prompt = generate_prompt(subject, style, color_scheme,
+                             angle, lighting_type, additional_details)
+    results = pipe(
+        prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+    return results.images[0]
+
+
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image selected!")
+
+
+def preprocess(input_image):
+    rembg_session = rembg.new_session()
+
+    input_image = remove_background(input_image, rembg_session)
+    input_image = resize_foreground(input_image, 0.85)
+
+    return input_image
+
+
+def image_generation_ui():
+    with gr.Row():
+        subject = gr.Textbox(label='Subject', scale=2)
+        style = gr.Dropdown(
+            label='Style',
+            choices=['Pixar-like', 'Disney-esque', 'Anime-inspired'],
+            value='Pixar-like',
+            multiselect=False,
+            scale=2
+        )
+        color_scheme = gr.Dropdown(
+            label='Color Scheme',
+            choices=['Vibrant', 'Pastel', 'Monochromatic', 'Black and White'],
+            value='Vibrant',
+            multiselect=False,
+            scale=2
+        )
+        angle = gr.Dropdown(
+            label='Angle',
+            choices=['Front', 'Side', 'Three-quarter'],
+            value='Front',
+            multiselect=False,
+            scale=2
+        )
+        lighting_type = gr.Dropdown(
+            label='Lighting Type',
+            choices=['Bright and Even', 'Dramatic Shadows', 'Soft and Warm'],
+            value='Bright and Even',
+            multiselect=False,
+            scale=2
+        )
+        additional_details = gr.Textbox(label='Additional Details', scale=2)
+        submit_prompt = gr.Button('Generate Image', scale=1, variant='primary')
+
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                input_image = gr.Image(
+                    label="Input Image",
+                    image_mode="RGBA",
+                    sources="upload",
+                    type="pil",
+                    elem_id="content_image",
+                )
+                processed_image = gr.Image(
+                    label="Processed Image",
+                    image_mode="RGBA",
+                    type="pil",
+                    interactive=False
+                )
+            with gr.Row():
+                submit_process = gr.Button(
+                    "Remove Background", elem_id="process", variant="primary")
+            with gr.Row(variant="panel"):
+                gr.Examples(
+                    examples=[os.path.join("examples", img_name)
+                              for img_name in sorted(os.listdir("examples"))],
+                    inputs=[input_image],
+                    label="Examples",
+                    cache_examples=False,
+                    examples_per_page=16
+                )
+
+    submit_prompt.click(fn=generate_image, inputs=[subject, style, color_scheme, angle, lighting_type, additional_details], outputs=input_image).success(
+        fn=preprocess, inputs=[input_image], outputs=[processed_image]
+    )
+    submit_process.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=preprocess, inputs=[input_image], outputs=[processed_image],
+    )
+
+    return input_image, processed_image
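
Since generate_image and preprocess are plain module-level functions, they can also be exercised outside the Gradio UI. A minimal sketch, not part of this commit: the prompt values and output file name are hypothetical, and a GPU environment is assumed because generate_image is decorated with @spaces.GPU.

from launch.image_generation import generate_image, preprocess

# Hypothetical inputs matching the dropdown choices defined above.
styled = generate_image("a red fox", "Pixar-like", "Vibrant", "Front",
                        "Bright and Even", "standing on green grass")
clean = preprocess(styled)  # removes the background and resizes the foreground (0.85 ratio)
clean.save("fox_rgba.png")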
launch/model_generation.py ADDED
@@ -0,0 +1,183 @@
+import os
+import tempfile
+
+import gradio as gr
+import numpy as np
+from launch.utils import find_cuda
+import spaces
+import torch
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+from PIL import Image
+from pytorch_lightning import seed_everything
+from torchvision.transforms import v2
+
+from src.utils.camera_util import (FOV_to_intrinsics, get_circular_camera_poses,
+                                   get_zero123plus_input_cameras)
+from src.utils.mesh_util import save_glb, save_obj
+from src.utils.train_util import instantiate_from_config
+
+# Configuration
+cuda_path = find_cuda()
+config_path = 'configs/instant-mesh-large.yaml'
+config = OmegaConf.load(config_path)
+config_name = os.path.basename(config_path).replace('.yaml', '')
+model_config = config.model_config
+infer_config = config.infer_config
+
+IS_FLEXICUBES = config_name.startswith('instant-mesh')
+device = torch.device('cuda')
+
+# Load diffusion model
+print('Loading diffusion model ...')
+pipeline = DiffusionPipeline.from_pretrained(
+    "sudo-ai/zero123plus-v1.2",
+    custom_pipeline="zero123plus",
+    torch_dtype=torch.float16,
+)
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+    pipeline.scheduler.config, timestep_spacing='trailing'
+)
+
+unet_ckpt_path = hf_hub_download(
+    repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+pipeline.unet.load_state_dict(state_dict, strict=True)
+
+pipeline = pipeline.to(device)
+
+# Load reconstruction model
+print('Loading reconstruction model ...')
+model_ckpt_path = hf_hub_download(
+    repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
+model = instantiate_from_config(model_config)
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith(
+    'lrm_generator.') and 'source_camera' not in k}
+model.load_state_dict(state_dict, strict=True)
+
+model = model.to(device)
+
+
+def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
+    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+    if is_flexicubes:
+        cameras = torch.linalg.inv(c2ws)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+    else:
+        extrinsics = c2ws.flatten(-2)
+        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(
+            0).repeat(M, 1, 1).float().flatten(-2)
+        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+    return cameras
+
+
+@spaces.GPU
+def generate_mvs(input_image, sample_steps, sample_seed):
+    seed_everything(sample_seed)
+
+    z123_image = pipeline(
+        input_image, num_inference_steps=sample_steps).images[0]
+
+    show_image = np.asarray(z123_image, dtype=np.uint8)
+    show_image = torch.from_numpy(show_image)
+    show_image = rearrange(
+        show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+    show_image = rearrange(
+        show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
+    show_image = Image.fromarray(show_image.numpy())
+
+    return z123_image, show_image
+
+
+@spaces.GPU
+def make3d(images):
+    global model
+    if IS_FLEXICUBES:
+        model.init_flexicubes_geometry(device, use_renderer=False)
+    model = model.eval()
+
+    images = np.asarray(images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
+
+    input_cameras = get_zero123plus_input_cameras(
+        batch_size=1, radius=4.0).to(device)
+    render_cameras = get_render_cameras(
+        batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
+
+    images = images.unsqueeze(0).to(device)
+    images = v2.functional.resize(
+        images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
+    print(mesh_fpath)
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+    mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
+
+    with torch.no_grad():
+        planes = model.forward_planes(images, input_cameras)
+        mesh_out = model.extract_mesh(
+            planes, use_texture_map=False, **infer_config)
+
+        vertices, faces, vertex_colors = mesh_out
+        vertices = vertices[:, [1, 2, 0]]
+
+        save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
+        save_obj(vertices, faces, vertex_colors, mesh_fpath)
+
+        print(f"Mesh saved to {mesh_fpath}")
+
+    return mesh_fpath, mesh_glb_fpath
+
+
+def model_generation_ui(processed_image):
+    with gr.Column():
+        with gr.Row():
+            with gr.Column():
+                mv_show_images = gr.Image(
+                    label="Generated Multi-views",
+                    type="pil",
+                    width=379,
+                    interactive=False
+                )
+            with gr.Row():
+                with gr.Group():
+                    sample_seed = gr.Number(
+                        value=42, label="Seed Value", precision=0)
+                    sample_steps = gr.Slider(
+                        label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
+        with gr.Row():
+            submit_mesh = gr.Button(
+                "Generate 3D Model", elem_id="generate", variant="primary")
+        with gr.Row():
+            with gr.Tab("OBJ"):
+                output_model_obj = gr.Model3D(
+                    label="Output Model (OBJ Format)",
+                    interactive=False,
+                )
+                gr.Markdown(
+                    "Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
+            with gr.Tab("GLB"):
+                output_model_glb = gr.Model3D(
+                    label="Output Model (GLB Format)",
+                    interactive=False,
+                )
+                gr.Markdown(
+                    "Note: The model shown here has a darker appearance. Download to get correct results.")
+        with gr.Row():
+            gr.Markdown(
+                '''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')

+    mv_images = gr.State()
+
+    submit_mesh.click(fn=generate_mvs, inputs=[processed_image, sample_steps, sample_seed], outputs=[mv_images, mv_show_images]).success(
+        fn=make3d, inputs=[mv_images], outputs=[
+            output_model_obj, output_model_glb]
+    )
+
+    return output_model_obj, output_model_glb
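
The 3D stage can be driven the same way: generate_mvs returns the raw zero123plus multi-view grid plus a rearranged preview, and make3d turns that grid into .obj and .glb files. A minimal sketch, not part of this commit: the input file name is hypothetical, and the module-level models above plus a CUDA device are assumed since both functions are decorated with @spaces.GPU.

from PIL import Image
from launch.model_generation import generate_mvs, make3d

processed = Image.open("fox_rgba.png")  # an RGBA image with the background already removed
z123_image, preview = generate_mvs(processed, sample_steps=75, sample_seed=42)
obj_path, glb_path = make3d(z123_image)  # returns paths to the exported .obj and .glb meshes
print(obj_path, glb_path)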
launch/utils.py ADDED
@@ -0,0 +1,14 @@
+import os
+import shutil
+
+def find_cuda():
+    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+    if cuda_home and os.path.exists(cuda_home):
+        return cuda_home
+
+    nvcc_path = shutil.which('nvcc')
+    if nvcc_path:
+        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
+        return cuda_path
+
+    return None
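
find_cuda only reports the CUDA toolkit root (from CUDA_HOME, CUDA_PATH, or the location of nvcc on PATH); in this commit the value is assigned but not otherwise consumed. A caller could, for example, use it to populate CUDA_HOME before building extensions. A sketch under that assumption, not part of the commit:

import os
from launch.utils import find_cuda

cuda_path = find_cuda()
if cuda_path:
    os.environ.setdefault("CUDA_HOME", cuda_path)  # make the toolkit visible to build tools
else:
    print("CUDA toolkit not found (checked CUDA_HOME, CUDA_PATH, and nvcc on PATH)")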