Huiwenshi commited on
Commit
0694d37
·
verified ·
1 Parent(s): d7aa774

Delete folder infer/.ipynb_checkpoints with huggingface_hub

Browse files
infer/.ipynb_checkpoints/__init__-checkpoint.py DELETED
@@ -1,32 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- from .removebg import Removebg
26
- from .text_to_image import Text2Image
27
- from .image_to_views import Image2Views, save_gif
28
- from .views_to_mesh import Views2Mesh
29
- from .gif_render import GifRenderer
30
-
31
- from .utils import seed_everything, auto_amp_inference
32
- from .utils import get_parameter_number, set_parameter_grad_false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/.ipynb_checkpoints/gif_render-checkpoint.py DELETED
@@ -1,79 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- import os, sys
26
- sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
27
-
28
- from svrm.ldm.vis_util import render
29
- from infer.utils import seed_everything, timing_decorator
30
-
31
- class GifRenderer():
32
- '''
33
- render frame(s) of mesh using pytorch3d
34
- '''
35
- def __init__(self, device="cuda:0"):
36
- self.device = device
37
-
38
- @timing_decorator("gif render")
39
- def __call__(
40
- self,
41
- obj_filename,
42
- elev=0,
43
- azim=0,
44
- resolution=512,
45
- gif_dst_path='',
46
- n_views=120,
47
- fps=30,
48
- rgb=True
49
- ):
50
- render(
51
- obj_filename,
52
- elev=elev,
53
- azim=azim,
54
- resolution=resolution,
55
- gif_dst_path=gif_dst_path,
56
- n_views=n_views,
57
- fps=fps,
58
- device=self.device,
59
- rgb=rgb
60
- )
61
-
62
- if __name__ == "__main__":
63
- import argparse
64
-
65
- def get_args():
66
- parser = argparse.ArgumentParser()
67
- parser.add_argument("--mesh_path", type=str, required=True)
68
- parser.add_argument("--output_gif_path", type=str, required=True)
69
- parser.add_argument("--device", default="cuda:0", type=str)
70
- return parser.parse_args()
71
-
72
- args = get_args()
73
-
74
- gif_renderer = GifRenderer(device=args.device)
75
-
76
- gif_renderer(
77
- args.mesh_path,
78
- gif_dst_path = args.output_gif_path
79
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/.ipynb_checkpoints/image_to_views-checkpoint.py DELETED
@@ -1,126 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- import os, sys
26
- sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
27
-
28
- import time
29
- import torch
30
- import random
31
- import numpy as np
32
- from PIL import Image
33
- from einops import rearrange
34
- from PIL import Image, ImageSequence
35
-
36
- from infer.utils import seed_everything, timing_decorator, auto_amp_inference
37
- from infer.utils import get_parameter_number, set_parameter_grad_false, str_to_bool
38
- from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
39
- from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline
40
-
41
-
42
- def save_gif(pils, save_path, df=False):
43
- # save a list of PIL.Image to gif
44
- spf = 4000 / len(pils)
45
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
46
- pils[0].save(save_path, format="GIF", save_all=True, append_images=pils[1:], duration=spf, loop=0)
47
- return save_path
48
-
49
-
50
- class Image2Views():
51
- def __init__(self, device="cuda:0", use_lite=False, save_memory=False):
52
- self.device = device
53
- if use_lite:
54
- self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
55
- "./weights/mvd_lite",
56
- torch_dtype = torch.float16,
57
- use_safetensors = True,
58
- )
59
- else:
60
- self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
61
- "./weights/mvd_std",
62
- torch_dtype = torch.float16,
63
- use_safetensors = True,
64
- )
65
- self.pipe = self.pipe.to(device)
66
- self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
67
- self.save_memory = save_memory
68
- set_parameter_grad_false(self.pipe.unet)
69
- print('image2views unet model', get_parameter_number(self.pipe.unet))
70
-
71
- @torch.no_grad()
72
- @timing_decorator("image to views")
73
- @auto_amp_inference
74
- def __call__(self, *args, **kwargs):
75
- if self.save_memory:
76
- self.pipe = self.pipe.to(self.device)
77
- torch.cuda.empty_cache()
78
- res = self.call(*args, **kwargs)
79
- self.pipe = self.pipe.to("cpu")
80
- else:
81
- res = self.call(*args, **kwargs)
82
- torch.cuda.empty_cache()
83
- return res
84
-
85
- def call(self, pil_img, seed=0, steps=50, guidance_scale=2.0):
86
- seed_everything(seed)
87
- generator = torch.Generator(device=self.device)
88
- res_img = self.pipe(pil_img,
89
- num_inference_steps=steps,
90
- guidance_scale=guidance_scale,
91
- generat=generator).images
92
- show_image = rearrange(np.asarray(res_img[0], dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
93
- pils = [res_img[1]]+[Image.fromarray(show_image[idx]) for idx in self.order]
94
- torch.cuda.empty_cache()
95
- return res_img, pils
96
-
97
-
98
- if __name__ == "__main__":
99
- import argparse
100
-
101
- def get_args():
102
- parser = argparse.ArgumentParser()
103
- parser.add_argument("--rgba_path", type=str, required=True)
104
- parser.add_argument("--output_views_path", type=str, required=True)
105
- parser.add_argument("--output_cond_path", type=str, required=True)
106
- parser.add_argument("--seed", default=0, type=int)
107
- parser.add_argument("--steps", default=50, type=int)
108
- parser.add_argument("--device", default="cuda:0", type=str)
109
- parser.add_argument("--use_lite", default='false', type=str)
110
- return parser.parse_args()
111
-
112
- args = get_args()
113
-
114
- args.use_lite = str_to_bool(args.use_lite)
115
-
116
- rgba_pil = Image.open(args.rgba_path)
117
-
118
- assert rgba_pil.mode == "RGBA", "rgba_pil must be RGBA mode"
119
-
120
- model = Image2Views(device=args.device, use_lite=args.use_lite)
121
-
122
- (views_pil, cond), _ = model(rgba_pil, seed=args.seed, steps=args.steps)
123
-
124
- views_pil.save(args.output_views_path)
125
- cond.save(args.output_cond_path)
126
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/.ipynb_checkpoints/removebg-checkpoint.py DELETED
@@ -1,101 +0,0 @@
1
- import os, sys
2
- sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
3
-
4
- import numpy as np
5
- from PIL import Image
6
- from rembg import remove, new_session
7
- from infer.utils import timing_decorator
8
-
9
- class Removebg():
10
- def __init__(self, name="u2net"):
11
- self.session = new_session(name)
12
-
13
- @timing_decorator("remove background")
14
- def __call__(self, rgb_maybe, force=True):
15
- '''
16
- args:
17
- rgb_maybe: PIL.Image, with RGB mode or RGBA mode
18
- force: bool, if input is RGBA mode, covert to RGB then remove bg
19
- return:
20
- rgba_img: PIL.Image, with RGBA mode
21
- '''
22
- if rgb_maybe.mode == "RGBA":
23
- if force:
24
- rgb_maybe = rgb_maybe.convert("RGB")
25
- rgba_img = remove(rgb_maybe, session=self.session)
26
- else:
27
- rgba_img = rgb_maybe
28
- else:
29
- rgba_img = remove(rgb_maybe, session=self.session)
30
-
31
- rgba_img = white_out_background(rgba_img)
32
-
33
- rgba_img = preprocess(rgba_img)
34
-
35
- return rgba_img
36
-
37
-
38
- def white_out_background(pil_img):
39
- data = pil_img.getdata()
40
- new_data = []
41
- for r, g, b, a in data:
42
- if a < 16: # background
43
- new_data.append((255, 255, 255, 0)) # full white color
44
- else:
45
- is_white = (r>235) and (g>235) and (b>235)
46
- new_r = 235 if is_white else r
47
- new_g = 235 if is_white else g
48
- new_b = 235 if is_white else b
49
- new_data.append((new_r, new_g, new_b, a))
50
- pil_img.putdata(new_data)
51
- return pil_img
52
-
53
- def preprocess(rgba_img, size=(512,512), ratio=1.15):
54
- image = np.asarray(rgba_img)
55
- rgb, alpha = image[:,:,:3] / 255., image[:,:,3:] / 255.
56
-
57
- # crop
58
- coords = np.nonzero(alpha > 0.1)
59
- x_min, x_max = coords[0].min(), coords[0].max()
60
- y_min, y_max = coords[1].min(), coords[1].max()
61
- rgb = (rgb[x_min:x_max, y_min:y_max, :] * 255).astype("uint8")
62
- alpha = (alpha[x_min:x_max, y_min:y_max, 0] * 255).astype("uint8")
63
-
64
- # padding
65
- h, w = rgb.shape[:2]
66
- resize_side = int(max(h, w) * ratio)
67
- pad_h, pad_w = resize_side - h, resize_side - w
68
- start_h, start_w = pad_h // 2, pad_w // 2
69
- new_rgb = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255
70
- new_alpha = np.zeros((resize_side, resize_side), dtype=np.uint8)
71
- new_rgb[start_h:start_h + h, start_w:start_w + w] = rgb
72
- new_alpha[start_h:start_h + h, start_w:start_w + w] = alpha
73
- rgba_array = np.concatenate((new_rgb, new_alpha[:,:,None]), axis=-1)
74
-
75
- rgba_image = Image.fromarray(rgba_array, 'RGBA')
76
- rgba_image = rgba_image.resize(size)
77
- return rgba_image
78
-
79
-
80
- if __name__ == "__main__":
81
-
82
- import argparse
83
-
84
- def get_args():
85
- parser = argparse.ArgumentParser()
86
- parser.add_argument("--rgb_path", type=str, required=True)
87
- parser.add_argument("--output_rgba_path", type=str, required=True)
88
- parser.add_argument("--force", default=False, action="store_true")
89
- return parser.parse_args()
90
-
91
- args = get_args()
92
-
93
- rgb_maybe = Image.open(args.rgb_path)
94
-
95
- model = Removebg()
96
-
97
- rgba_pil = model(rgb_maybe, args.force)
98
-
99
- rgba_pil.save(args.output_rgba_path)
100
-
101
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/.ipynb_checkpoints/text_to_image-checkpoint.py DELETED
@@ -1,105 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
- import os , sys
25
- sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
26
-
27
- import torch
28
- from diffusers import HunyuanDiTPipeline, AutoPipelineForText2Image
29
-
30
- from infer.utils import seed_everything, timing_decorator, auto_amp_inference
31
- from infer.utils import get_parameter_number, set_parameter_grad_false
32
-
33
-
34
- class Text2Image():
35
- def __init__(self, pretrain="weights/hunyuanDiT", device="cuda:0", save_memory=None):
36
- '''
37
- save_memory: if GPU memory is low, can set it
38
- '''
39
- self.save_memory = save_memory
40
- self.device = device
41
- self.pipe = AutoPipelineForText2Image.from_pretrained(
42
- pretrain,
43
- torch_dtype = torch.float16,
44
- enable_pag = True,
45
- pag_applied_layers = ["blocks.(16|17|18|19)"]
46
- )
47
- set_parameter_grad_false(self.pipe.transformer)
48
- print('text2image transformer model', get_parameter_number(self.pipe.transformer))
49
- if not save_memory:
50
- self.pipe = self.pipe.to(device)
51
- self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
52
- "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
53
- "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
54
-
55
- @torch.no_grad()
56
- @timing_decorator('text to image')
57
- @auto_amp_inference
58
- def __call__(self, *args, **kwargs):
59
- if self.save_memory:
60
- self.pipe = self.pipe.to(self.device)
61
- torch.cuda.empty_cache()
62
- res = self.call(*args, **kwargs)
63
- self.pipe = self.pipe.to("cpu")
64
- else:
65
- res = self.call(*args, **kwargs)
66
- torch.cuda.empty_cache()
67
- return res
68
-
69
- def call(self, prompt, seed=0, steps=25):
70
- '''
71
- args:
72
- prompr: str
73
- seed: int
74
- steps: int
75
- return:
76
- rgb: PIL.Image
77
- '''
78
- print("prompt is:", prompt)
79
- prompt = prompt + ",白色背景,3D风格,最佳质量"
80
- seed_everything(seed)
81
- generator = torch.Generator(device=self.device)
82
- if seed is not None: generator = generator.manual_seed(int(seed))
83
- rgb = self.pipe(prompt=prompt, negative_prompt=self.neg_txt, num_inference_steps=steps,
84
- pag_scale=1.3, width=1024, height=1024, generator=generator, return_dict=False)[0][0]
85
- torch.cuda.empty_cache()
86
- return rgb
87
-
88
- if __name__ == "__main__":
89
- import argparse
90
-
91
- def get_args():
92
- parser = argparse.ArgumentParser()
93
- parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str)
94
- parser.add_argument("--text_prompt", default="", type=str)
95
- parser.add_argument("--output_img_path", default="./outputs/test/img.jpg", type=str)
96
- parser.add_argument("--device", default="cuda:0", type=str)
97
- parser.add_argument("--seed", default=0, type=int)
98
- parser.add_argument("--steps", default=25, type=int)
99
- return parser.parse_args()
100
- args = get_args()
101
-
102
- text2image_model = Text2Image(device=args.device)
103
- rgb_img = text2image_model(args.text_prompt, seed=args.seed, steps=args.steps)
104
- rgb_img.save(args.output_img_path)
105
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/.ipynb_checkpoints/utils-checkpoint.py DELETED
@@ -1,87 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- import os
26
- import time
27
- import random
28
- import numpy as np
29
- import torch
30
- from torch.cuda.amp import autocast, GradScaler
31
- from functools import wraps
32
-
33
- def seed_everything(seed):
34
- '''
35
- seed everthing
36
- '''
37
- random.seed(seed)
38
- np.random.seed(seed)
39
- torch.manual_seed(seed)
40
- os.environ["PL_GLOBAL_SEED"] = str(seed)
41
-
42
- def timing_decorator(category: str):
43
- '''
44
- timing_decorator: record time
45
- '''
46
- def decorator(func):
47
- func.call_count = 0
48
- @wraps(func)
49
- def wrapper(*args, **kwargs):
50
- start_time = time.time()
51
- result = func(*args, **kwargs)
52
- end_time = time.time()
53
- elapsed_time = end_time - start_time
54
- func.call_count += 1
55
- print(f"[HunYuan3D]-[{category}], cost time: {elapsed_time:.4f}s") # huiwen
56
- return result
57
- return wrapper
58
- return decorator
59
-
60
- def auto_amp_inference(func):
61
- '''
62
- with torch.cuda.amp.autocast()"
63
- xxx
64
- '''
65
- @wraps(func)
66
- def wrapper(*args, **kwargs):
67
- with autocast():
68
- output = func(*args, **kwargs)
69
- return output
70
- return wrapper
71
-
72
- def get_parameter_number(model):
73
- total_num = sum(p.numel() for p in model.parameters())
74
- trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
75
- return {'Total': total_num, 'Trainable': trainable_num}
76
-
77
- def set_parameter_grad_false(model):
78
- for p in model.parameters():
79
- p.requires_grad = False
80
-
81
- def str_to_bool(s):
82
- if s.lower() in ['true', 't', 'yes', 'y', '1']:
83
- return True
84
- elif s.lower() in ['false', 'f', 'no', 'n', '0']:
85
- return False
86
- else:
87
- raise f"bool arg must one of ['true', 't', 'yes', 'y', '1', 'false', 'f', 'no', 'n', '0']"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/.ipynb_checkpoints/views_to_mesh-checkpoint.py DELETED
@@ -1,154 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- import os, sys
26
- sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
27
-
28
- import time
29
- import torch
30
- import random
31
- import numpy as np
32
- from PIL import Image
33
- from einops import rearrange
34
- from PIL import Image, ImageSequence
35
-
36
- from infer.utils import seed_everything, timing_decorator, auto_amp_inference
37
- from infer.utils import get_parameter_number, set_parameter_grad_false, str_to_bool
38
- from svrm.predictor import MV23DPredictor
39
-
40
-
41
- class Views2Mesh():
42
- def __init__(self, mv23d_cfg_path, mv23d_ckt_path,
43
- device="cuda:0", use_lite=False, save_memory=False):
44
- '''
45
- mv23d_cfg_path: config yaml file
46
- mv23d_ckt_path: path to ckpt
47
- use_lite: lite version
48
- save_memory: cpu auto
49
- '''
50
- self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)
51
- self.mv23d_predictor.model.eval()
52
- self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
53
- self.device = device
54
- self.save_memory = save_memory
55
- set_parameter_grad_false(self.mv23d_predictor.model)
56
- print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
57
-
58
- @torch.no_grad()
59
- @timing_decorator("views to mesh")
60
- @auto_amp_inference
61
- def __call__(self, *args, **kwargs):
62
- if self.save_memory:
63
- self.mv23d_predictor.model = self.mv23d_predictor.model.to(self.device)
64
- torch.cuda.empty_cache()
65
- res = self.call(*args, **kwargs)
66
- self.mv23d_predictor.model = self.mv23d_predictor.model.to("cpu")
67
- else:
68
- res = self.call(*args, **kwargs)
69
- torch.cuda.empty_cache()
70
- return res
71
-
72
- def call(
73
- self,
74
- views_pil=None,
75
- cond_pil=None,
76
- gif_pil=None,
77
- seed=0,
78
- target_face_count = 10000,
79
- do_texture_mapping = True,
80
- save_folder='./outputs/test'
81
- ):
82
- '''
83
- can set views_pil, cond_pil simutaously or set gif_pil only
84
- seed: int
85
- target_face_count: int
86
- save_folder: path to save mesh files
87
- '''
88
- save_dir = save_folder
89
- os.makedirs(save_dir, exist_ok=True)
90
-
91
- if views_pil is not None and cond_pil is not None:
92
- show_image = rearrange(np.asarray(views_pil, dtype=np.uint8),
93
- '(n h) (m w) c -> (n m) h w c', n=3, m=2)
94
- views = [Image.fromarray(show_image[idx]) for idx in self.order]
95
- image_list = [cond_pil]+ views
96
- image_list = [img.convert('RGB') for img in image_list]
97
- elif gif_pil is not None:
98
- image_list = [img.convert('RGB') for img in ImageSequence.Iterator(gif_pil)]
99
-
100
- image_input = image_list[0]
101
- image_list = image_list[1:] + image_list[:1]
102
-
103
- seed_everything(seed)
104
- self.mv23d_predictor.predict(
105
- image_list,
106
- save_dir = save_dir,
107
- image_input = image_input,
108
- target_face_count = target_face_count,
109
- do_texture_mapping = do_texture_mapping
110
- )
111
- torch.cuda.empty_cache()
112
- return save_dir
113
-
114
-
115
- if __name__ == "__main__":
116
-
117
- import argparse
118
-
119
- def get_args():
120
- parser = argparse.ArgumentParser()
121
- parser.add_argument("--views_path", type=str, required=True)
122
- parser.add_argument("--cond_path", type=str, required=True)
123
- parser.add_argument("--save_folder", default="./outputs/test/", type=str)
124
- parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str)
125
- parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str)
126
- parser.add_argument("--max_faces_num", default=90000, type=int,
127
- help="max num of face, suggest 90000 for effect, 10000 for speed")
128
- parser.add_argument("--device", default="cuda:0", type=str)
129
- parser.add_argument("--use_lite", default='false', type=str)
130
- parser.add_argument("--do_texture_mapping", default='false', type=str)
131
-
132
- return parser.parse_args()
133
-
134
- args = get_args()
135
- args.use_lite = str_to_bool(args.use_lite)
136
- args.do_texture_mapping = str_to_bool(args.do_texture_mapping)
137
-
138
- views = Image.open(args.views_path)
139
- cond = Image.open(args.cond_path)
140
-
141
- views_to_mesh_model = Views2Mesh(
142
- args.mv23d_cfg_path,
143
- args.mv23d_ckt_path,
144
- device = args.device,
145
- use_lite = args.use_lite
146
- )
147
-
148
- views_to_mesh_model(
149
- views, cond, 0,
150
- target_face_count = args.max_faces_num,
151
- save_folder = args.save_folder,
152
- do_texture_mapping = args.do_texture_mapping
153
- )
154
-