YulianSa committed
Commit def0065 · 1 Parent(s): 8d53de2
Files changed (2)
  1. infer_api.py +81 -75
  2. infer_api_bk.py +889 -0
infer_api.py CHANGED
@@ -367,13 +367,13 @@ class InferAPI:
367
  continue
368
  hf_hub_download(repo_id, file, local_dir="./ckpt")
369
 
370
- self.canonical_infer = InferCanonicalAPI(self.canonical_configs)
371
  # self.multiview_infer = InferMultiviewAPI(self.multiview_configs)
372
  # self.slrm_infer = InferSlrmAPI(self.slrm_configs)
373
  # self.refine_infer = InferRefineAPI(self.refine_configs)
374
 
375
  def genStage1(self, img, seed):
376
- return self.canonical_infer.gen(img, seed)
377
 
378
  def genStage2(self, img, seed, num_levels):
379
  return self.multiview_infer.gen(img, seed, num_levels)
@@ -811,79 +811,85 @@ class InferMultiviewAPI:
811
  return results
812
 
813
 
814
- class InferCanonicalAPI:
815
- def __init__(self, config):
816
- self.config = config
817
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
818
-
819
- self.config_path = config['config_path']
820
- self.loaded_config = OmegaConf.load(self.config_path)
821
-
822
- self.setup(**self.loaded_config)
823
-
824
- def setup(self,
825
- validation: Dict,
826
- pretrained_model_path: str,
827
- local_crossattn: bool = True,
828
- unet_from_pretrained_kwargs=None,
829
- unet_condition_type=None,
830
- use_noise=True,
831
- noise_d=256,
832
- timestep: int = 40,
833
- width_input: int = 640,
834
- height_input: int = 1024,
835
  ):
836
- self.width_input = width_input
837
- self.height_input = height_input
838
- self.timestep = timestep
839
- self.use_noise = use_noise
840
- self.noise_d = noise_d
841
- self.validation = validation
842
- self.unet_condition_type = unet_condition_type
843
- self.pretrained_model_path = pretrained_model_path
844
-
845
- self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
846
- self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
847
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
848
- self.feature_extractor = CLIPImageProcessor()
849
- self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
850
- self.unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
851
- self.ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
852
-
853
- self.text_encoder.to(device, dtype=weight_dtype)
854
- self.image_encoder.to(device, dtype=weight_dtype)
855
- self.vae.to(device, dtype=weight_dtype)
856
- self.ref_unet.to(device, dtype=weight_dtype)
857
- self.unet.to(device, dtype=weight_dtype)
858
-
859
- self.vae.requires_grad_(False)
860
- self.ref_unet.requires_grad_(False)
861
- self.unet.requires_grad_(False)
862
-
863
- self.noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr")
864
- self.validation_pipeline = CanonicalizationPipeline(
865
- vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, ref_unet=self.ref_unet,feature_extractor=self.feature_extractor,image_encoder=self.image_encoder,
866
- scheduler=self.noise_scheduler
867
- )
868
- self.validation_pipeline.set_progress_bar_config(disable=True)
869
 
870
- def canonicalize(self, image, seed):
871
- return inference(
872
- self.validation_pipeline, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
873
- self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
874
- use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
875
- )
876
 
877
- def gen(self, img_input, seed=0):
878
- if np.array(img_input).shape[-1] == 4 and np.array(img_input)[..., 3].min() == 255:
879
- # convert to RGB
880
- img_input = img_input.convert("RGB")
881
- img_output = self.canonicalize(img_input, seed)
882
-
883
- max_dim = max(img_output.width, img_output.height)
884
- new_image = Image.new("RGBA", (max_dim, max_dim))
885
- left = (max_dim - img_output.width) // 2
886
- top = (max_dim - img_output.height) // 2
887
- new_image.paste(img_output, (left, top))
888
-
889
- return new_image
 
367
  continue
368
  hf_hub_download(repo_id, file, local_dir="./ckpt")
369
 
370
+ # self.canonical_infer = InferCanonicalAPI(self.canonical_configs)
371
  # self.multiview_infer = InferMultiviewAPI(self.multiview_configs)
372
  # self.slrm_infer = InferSlrmAPI(self.slrm_configs)
373
  # self.refine_infer = InferRefineAPI(self.refine_configs)
374
 
375
  def genStage1(self, img, seed):
376
+ return infer_canonicalize_gen(img, seed)
377
 
378
  def genStage2(self, img, seed, num_levels):
379
  return self.multiview_infer.gen(img, seed, num_levels)
 
811
  return results
812
 
813
 
814
+ infer_canonicalize_config = {
815
+ 'config_path': './configs/canonicalization-infer.yaml',
816
+ }
817
+ infer_canonicalize_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
818
+ # print device stderr
819
+ import sys
820
+ print(f"Using device!!!!!!!!!!!!: {infer_canonicalize_device}", file=sys.stderr)
821
+
822
+ infer_canonicalize_config_path = infer_canonicalize_config['config_path']
823
+ infer_canonicalize_loaded_config = OmegaConf.load(infer_canonicalize_config_path)
824
+
825
+ # infer_canonicalize_setup(**infer_canonicalize_loaded_config)
826
+
827
+ def infer_canonicalize_setup(
828
+ validation: Dict,
829
+ pretrained_model_path: str,
830
+ local_crossattn: bool = True,
831
+ unet_from_pretrained_kwargs=None,
832
+ unet_condition_type=None,
833
+ use_noise=True,
834
+ noise_d=256,
835
+ timestep: int = 40,
836
+ width_input: int = 640,
837
+ height_input: int = 1024,
838
  ):
839
+ infer_canonicalize_width_input = width_input
840
+ infer_canonicalize_height_input = height_input
841
+ infer_canonicalize_timestep = timestep
842
+ infer_canonicalize_use_noise = use_noise
843
+ infer_canonicalize_noise_d = noise_d
844
+ infer_canonicalize_validation = validation
845
+ infer_canonicalize_unet_condition_type = unet_condition_type
846
+ infer_canonicalize_pretrained_model_path = pretrained_model_path
847
+ infer_canonicalize_local_crossattn = local_crossattn
848
+ infer_canonicalize_unet_from_pretrained_kwargs = unet_from_pretrained_kwargs
849
+ return infer_canonicalize_width_input, infer_canonicalize_height_input, infer_canonicalize_timestep, infer_canonicalize_use_noise, infer_canonicalize_noise_d, infer_canonicalize_validation, infer_canonicalize_unet_condition_type, infer_canonicalize_pretrained_model_path, infer_canonicalize_local_crossattn, infer_canonicalize_unet_from_pretrained_kwargs
850
+
851
+ infer_canonicalize_width_input, infer_canonicalize_height_input, infer_canonicalize_timestep, infer_canonicalize_use_noise, infer_canonicalize_noise_d, infer_canonicalize_validation, infer_canonicalize_unet_condition_type, infer_canonicalize_pretrained_model_path, infer_canonicalize_local_crossattn, infer_canonicalize_unet_from_pretrained_kwargs = infer_canonicalize_setup(**infer_canonicalize_loaded_config)
852
+
853
+ infer_canonicalize_tokenizer = CLIPTokenizer.from_pretrained(infer_canonicalize_pretrained_model_path, subfolder="tokenizer")
854
+ infer_canonicalize_text_encoder = CLIPTextModel.from_pretrained(infer_canonicalize_pretrained_model_path, subfolder="text_encoder")
855
+ infer_canonicalize_image_encoder = CLIPVisionModelWithProjection.from_pretrained(infer_canonicalize_pretrained_model_path, subfolder="image_encoder")
856
+ infer_canonicalize_feature_extractor = CLIPImageProcessor()
857
+ infer_canonicalize_vae = AutoencoderKL.from_pretrained(infer_canonicalize_pretrained_model_path, subfolder="vae")
858
+ infer_canonicalize_unet = UNetMV2DConditionModel.from_pretrained_2d(infer_canonicalize_pretrained_model_path, subfolder="unet", local_crossattn=infer_canonicalize_local_crossattn, **infer_canonicalize_unet_from_pretrained_kwargs)
859
+ infer_canonicalize_ref_unet = UNetMV2DRefModel.from_pretrained_2d(infer_canonicalize_pretrained_model_path, subfolder="ref_unet", local_crossattn=infer_canonicalize_local_crossattn, **infer_canonicalize_unet_from_pretrained_kwargs)
860
+
861
+ infer_canonicalize_text_encoder.to(device, dtype=weight_dtype)
862
+ infer_canonicalize_image_encoder.to(device, dtype=weight_dtype)
863
+ infer_canonicalize_vae.to(device, dtype=weight_dtype)
864
+ infer_canonicalize_ref_unet.to(device, dtype=weight_dtype)
865
+ infer_canonicalize_unet.to(device, dtype=weight_dtype)
866
+
867
+ infer_canonicalize_vae.requires_grad_(False)
868
+ infer_canonicalize_ref_unet.requires_grad_(False)
869
+ infer_canonicalize_unet.requires_grad_(False)
870
+
871
+ infer_canonicalize_noise_scheduler = DDIMScheduler.from_pretrained(infer_canonicalize_pretrained_model_path, subfolder="scheduler-zerosnr")
872
+ infer_canonicalize_validation_pipeline = CanonicalizationPipeline(
873
+ vae=infer_canonicalize_vae, text_encoder=infer_canonicalize_text_encoder, tokenizer=infer_canonicalize_tokenizer, unet=infer_canonicalize_unet, ref_unet=infer_canonicalize_ref_unet,feature_extractor=infer_canonicalize_feature_extractor,image_encoder=infer_canonicalize_image_encoder,
874
+ scheduler=infer_canonicalize_noise_scheduler
875
+ )
876
+ infer_canonicalize_validation_pipeline.set_progress_bar_config(disable=True)
877
 
878
 
879
+ def infer_canonicalize_gen(img_input, seed=0):
880
+ if np.array(img_input).shape[-1] == 4 and np.array(img_input)[..., 3].min() == 255:
881
+ # convert to RGB
882
+ img_input = img_input.convert("RGB")
883
+ img_output = inference(
884
+ infer_canonicalize_validation_pipeline, img_input, infer_canonicalize_vae, infer_canonicalize_feature_extractor, infer_canonicalize_image_encoder, infer_canonicalize_unet, infer_canonicalize_ref_unet, infer_canonicalize_tokenizer, infer_canonicalize_text_encoder,
885
+ infer_canonicalize_pretrained_model_path, infer_canonicalize_validation, infer_canonicalize_width_input, infer_canonicalize_height_input, infer_canonicalize_unet_condition_type,
886
+ use_noise=infer_canonicalize_use_noise, noise_d=infer_canonicalize_noise_d, crop=True, seed=seed, timestep=infer_canonicalize_timestep
887
+ )
888
+
889
+ max_dim = max(img_output.width, img_output.height)
890
+ new_image = Image.new("RGBA", (max_dim, max_dim))
891
+ left = (max_dim - img_output.width) // 2
892
+ top = (max_dim - img_output.height) // 2
893
+ new_image.paste(img_output, (left, top))
894
+
895
+ return new_image
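
After this change the canonicalization weights are loaded once at module import (the infer_canonicalize_* globals above) instead of inside an InferCanonicalAPI instance, and InferAPI.genStage1 simply forwards to the free function infer_canonicalize_gen. A minimal sketch of the new call path, assuming the checkpoints under ./ckpt are in place; the config dicts passed to InferAPI are placeholders, since the canonicalization config path is now hard-coded at module level:

# Minimal usage sketch (placeholder configs). Importing infer_api now builds the
# canonicalization pipeline at import time; InferAPI.__init__ still downloads the
# StdGEN checkpoints from the Hub if they are missing.
from PIL import Image
from infer_api import InferAPI

api = InferAPI(canonical_configs={'config_path': './configs/canonicalization-infer.yaml'},
               multiview_configs={}, slrm_configs={}, refine_configs={})
img = Image.open("input.png")
canonical = api.genStage1(img, seed=42)   # forwards to infer_canonicalize_gen(img, 42)
canonical.save("canonical.png")

Note that in this revision only stage 1 is wired up: the multiview, SLRM and refine objects are commented out in __init__, so genStage2 through genStage4 would raise AttributeError until those lines are restored.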
infer_api_bk.py ADDED
@@ -0,0 +1,889 @@
1
+ import spaces
2
+ from PIL import Image
3
+
4
+ import io
5
+ import argparse
6
+ import os
7
+ import random
8
+ import tempfile
9
+ from typing import Dict, Optional, Tuple
10
+ from omegaconf import OmegaConf
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ from diffusers import AutoencoderKL, DDIMScheduler
16
+ from diffusers.utils import check_min_version
17
+ from tqdm.auto import tqdm
18
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
19
+ from torchvision import transforms
20
+
21
+ from canonicalize.models.unet_mv2d_condition import UNetMV2DConditionModel
22
+ from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel
23
+ from canonicalize.pipeline_canonicalize import CanonicalizationPipeline
24
+ from einops import rearrange
25
+ from torchvision.utils import save_image
26
+ import json
27
+ import cv2
28
+
29
+ import onnxruntime as rt
30
+ from huggingface_hub.file_download import hf_hub_download
31
+ from huggingface_hub import list_repo_files
32
+ from rm_anime_bg.cli import get_mask, SCALE
33
+
34
+ import argparse
35
+ import os
36
+ import cv2
37
+ import glob
38
+ import numpy as np
39
+ import matplotlib.pyplot as plt
40
+ from typing import Dict, Optional, List
41
+ from omegaconf import OmegaConf, DictConfig
42
+ from PIL import Image
43
+ from pathlib import Path
44
+ from dataclasses import dataclass
45
+ from typing import Dict
46
+ import torch
47
+ import torch.nn.functional as F
48
+ import torch.utils.checkpoint
49
+ import torchvision.transforms.functional as TF
50
+ from torch.utils.data import Dataset, DataLoader
51
+ from torchvision import transforms
52
+ from torchvision.utils import make_grid, save_image
53
+ from accelerate.utils import set_seed
54
+ from tqdm.auto import tqdm
55
+ from einops import rearrange, repeat
56
+ from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline
57
+
58
+ import os
59
+ import imageio
60
+ import numpy as np
61
+ import torch
62
+ import cv2
63
+ import glob
64
+ import matplotlib.pyplot as plt
65
+ from PIL import Image
66
+ from torchvision.transforms import v2
67
+ from pytorch_lightning import seed_everything
68
+ from omegaconf import OmegaConf
69
+ from tqdm import tqdm
70
+
71
+ from slrm.utils.train_util import instantiate_from_config
72
+ from slrm.utils.camera_util import (
73
+ FOV_to_intrinsics,
74
+ get_circular_camera_poses,
75
+ )
76
+ from slrm.utils.mesh_util import save_obj, save_glb
77
+ from slrm.utils.infer_util import images_to_video
78
+
79
+ import cv2
80
+ import numpy as np
81
+ import os
82
+ import trimesh
83
+ import argparse
84
+ import torch
85
+ import scipy
86
+ from PIL import Image
87
+
88
+ from refine.mesh_refine import geo_refine
89
+ from refine.func import make_star_cameras_orthographic
90
+ from refine.render import NormalsRenderer, calc_vertex_normals
91
+
92
+ import pytorch3d
93
+ from pytorch3d.structures import Meshes
94
+ from sklearn.neighbors import KDTree
95
+
96
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
97
+
98
+ check_min_version("0.24.0")
99
+ weight_dtype = torch.float16
100
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
101
+ VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
102
+
103
+
104
+ @spaces.GPU
105
+ def set_seed(seed):
106
+ random.seed(seed)
107
+ np.random.seed(seed)
108
+ torch.manual_seed(seed)
109
+ torch.cuda.manual_seed_all(seed)
110
+
111
+
112
+ session_infer_path = hf_hub_download(
113
+ repo_id="skytnt/anime-seg", filename="isnetis.onnx",
114
+ )
115
+ providers: list[str] = ["CPUExecutionProvider"]
116
+ if "CUDAExecutionProvider" in rt.get_available_providers():
117
+ providers = ["CUDAExecutionProvider"]
118
+
119
+ bkg_remover_session_infer = rt.InferenceSession(
120
+ session_infer_path, providers=providers,
121
+ )
122
+
123
+ @spaces.GPU
124
+ def remove_background(
125
+ img: np.ndarray,
126
+ alpha_min: float,
127
+ alpha_max: float,
128
+ ) -> list:
129
+ img = np.array(img)
130
+ mask = get_mask(bkg_remover_session_infer, img)
131
+ mask[mask < alpha_min] = 0.0
132
+ mask[mask > alpha_max] = 1.0
133
+ img_after = (mask * img).astype(np.uint8)
134
+ mask = (mask * SCALE).astype(np.uint8)
135
+ img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
136
+ return Image.fromarray(img_after)
137
+
138
+
139
+ def process_image(image, totensor, width, height):
140
+ assert image.mode == "RGBA"
141
+
142
+ # Find non-transparent pixels
143
+ non_transparent = np.nonzero(np.array(image)[..., 3])
144
+ min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
145
+ min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
146
+ image = image.crop((min_x, min_y, max_x, max_y))
147
+
148
+ # paste to center
149
+ max_dim = max(image.width, image.height)
150
+ max_height = int(max_dim * 1.2)
151
+ max_width = int(max_dim / (height/width) * 1.2)
152
+ new_image = Image.new("RGBA", (max_width, max_height))
153
+ left = (max_width - image.width) // 2
154
+ top = (max_height - image.height) // 2
155
+ new_image.paste(image, (left, top))
156
+
157
+ image = new_image.resize((width, height), resample=Image.BICUBIC)
158
+ image = np.array(image)
159
+ image = image.astype(np.float32) / 255.
160
+ assert image.shape[-1] == 4 # RGBA
161
+ alpha = image[..., 3:4]
162
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
163
+ image = image[..., :3] * alpha + bg_color * (1 - alpha)
164
+ return totensor(image)
165
+
166
+
167
+ @spaces.GPU
168
+ @torch.no_grad()
169
+ def inference(validation_pipeline, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
170
+ text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
171
+ use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
172
+ set_seed(seed)
173
+ generator = torch.Generator(device=device).manual_seed(seed)
174
+
175
+ totensor = transforms.ToTensor()
176
+
177
+ prompts = "high quality, best quality"
178
+ prompt_ids = tokenizer(
179
+ prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
180
+ return_tensors="pt"
181
+ ).input_ids[0]
182
+
183
+ # (B*Nv, 3, H, W)
184
+ B = 1
185
+ if input_image.mode != "RGBA":
186
+ # remove background
187
+ input_image = remove_background(input_image, 0.1, 0.9)
188
+ imgs_in = process_image(input_image, totensor, val_width, val_height)
189
+ imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
190
+
191
+ with torch.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=weight_dtype):
192
+ imgs_in = imgs_in.to(device=device)
193
+ # B*Nv images
194
+ out = validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
195
+ num_inference_steps=timestep, prompt_ids=prompt_ids,
196
+ height=val_height, width=val_width, unet_condition_type=unet_condition_type,
197
+ use_noise=use_noise, **validation,)
198
+ out = rearrange(out, "B C f H W -> (B f) C H W", f=1)
199
+
200
+ print("OUT!!!!!!")
201
+
202
+ img_buf = io.BytesIO()
203
+ save_image(out[0], img_buf, format='PNG')
204
+ img_buf.seek(0)
205
+ img = Image.open(img_buf)
206
+
207
+ print("OUT2!!!!!!")
208
+
209
+ torch.cuda.empty_cache()
210
+ return img
211
+
212
+
213
+ ######### Multi View Part #############
214
+ weight_dtype = torch.float16
215
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
216
+
217
+ def tensor_to_numpy(tensor):
218
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
219
+
220
+
221
+ @dataclass
222
+ class TestConfig:
223
+ pretrained_model_name_or_path: str
224
+ pretrained_unet_path:Optional[str]
225
+ revision: Optional[str]
226
+ validation_dataset: Dict
227
+ save_dir: str
228
+ seed: Optional[int]
229
+ validation_batch_size: int
230
+ dataloader_num_workers: int
231
+ save_mode: str
232
+ local_rank: int
233
+
234
+ pipe_kwargs: Dict
235
+ pipe_validation_kwargs: Dict
236
+ unet_from_pretrained_kwargs: Dict
237
+ validation_grid_nrow: int
238
+ camera_embedding_lr_mult: float
239
+
240
+ num_views: int
241
+ camera_embedding_type: str
242
+
243
+ pred_type: str
244
+ regress_elevation: bool
245
+ enable_xformers_memory_efficient_attention: bool
246
+
247
+ cond_on_normals: bool
248
+ cond_on_colors: bool
249
+
250
+ regress_elevation: bool
251
+ regress_focal_length: bool
252
+
253
+
254
+
255
+ def convert_to_numpy(tensor):
256
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
257
+
258
+ def save_image(tensor):
259
+ ndarr = convert_to_numpy(tensor)
260
+ return save_image_numpy(ndarr)
261
+
262
+ def save_image_numpy(ndarr):
263
+ im = Image.fromarray(ndarr)
264
+ # pad to square
265
+ if im.size[0] != im.size[1]:
266
+ size = max(im.size)
267
+ new_im = Image.new("RGB", (size, size))
268
+ # set to white
269
+ new_im.paste((255, 255, 255), (0, 0, size, size))
270
+ new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2))
271
+ im = new_im
272
+ # resize to 1024x1024
273
+ im = im.resize((1024, 1024), Image.LANCZOS)
274
+ return im
275
+
276
+ @spaces.GPU
277
+ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
278
+ if cfg.seed is None:
279
+ generator = None
280
+ else:
281
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
282
+
283
+ images_cond = []
284
+ results = {}
285
+
286
+ torch.cuda.empty_cache()
287
+ images_cond.append(data['image_cond_rgb'][:, 0].cuda())
288
+ imgs_in = torch.cat([data['image_cond_rgb']]*2, dim=0).cuda()
289
+ num_views = imgs_in.shape[1]
290
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
291
+
292
+ target_h, target_w = imgs_in.shape[-2], imgs_in.shape[-1]
293
+
294
+ normal_prompt_embeddings, clr_prompt_embeddings = data['normal_prompt_embeddings'].cuda(), data['color_prompt_embeddings'].cuda()
295
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
296
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
297
+
298
+ # B*Nv images
299
+ unet_out = pipeline(
300
+ imgs_in, None, prompt_embeds=prompt_embeddings,
301
+ generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1,
302
+ height=cfg.height, width=cfg.width,
303
+ num_inference_steps=40, eta=1.0,
304
+ num_levels=num_levels,
305
+ )
306
+
307
+ for level in range(num_levels):
308
+ out = unet_out[level].images
309
+ bsz = out.shape[0] // 2
310
+
311
+ normals_pred = out[:bsz]
312
+ images_pred = out[bsz:]
313
+
314
+ if num_levels == 2:
315
+ results[level+1] = {'normals': [], 'images': []}
316
+ else:
317
+ results[level] = {'normals': [], 'images': []}
318
+
319
+ for i in range(bsz//num_views):
320
+ img_in_ = images_cond[-1][i].to(out.device)
321
+ for j in range(num_views):
322
+ view = VIEWS[j]
323
+ idx = i*num_views + j
324
+ normal = normals_pred[idx]
325
+ color = images_pred[idx]
326
+
327
+ ## save color and normal---------------------
328
+ new_normal = save_image(normal)
329
+ new_color = save_image(color)
330
+
331
+ if num_levels == 2:
332
+ results[level+1]['normals'].append(new_normal)
333
+ results[level+1]['images'].append(new_color)
334
+ else:
335
+ results[level]['normals'].append(new_normal)
336
+ results[level]['images'].append(new_color)
337
+
338
+ torch.cuda.empty_cache()
339
+ return results
340
+
341
+ @spaces.GPU
342
+ def load_multiview_pipeline(cfg):
343
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
344
+ cfg.pretrained_path,
345
+ torch_dtype=torch.float16,)
346
+ pipeline.unet.enable_xformers_memory_efficient_attention()
347
+ if torch.cuda.is_available():
348
+ pipeline.to(device)
349
+ return pipeline
350
+
351
+
352
+ class InferAPI:
353
+ def __init__(self,
354
+ canonical_configs,
355
+ multiview_configs,
356
+ slrm_configs,
357
+ refine_configs):
358
+ self.canonical_configs = canonical_configs
359
+ self.multiview_configs = multiview_configs
360
+ self.slrm_configs = slrm_configs
361
+ self.refine_configs = refine_configs
362
+
363
+ repo_id = "hyz317/StdGEN"
364
+ all_files = list_repo_files(repo_id, revision="main")
365
+ for file in all_files:
366
+ if os.path.exists(file):
367
+ continue
368
+ hf_hub_download(repo_id, file, local_dir="./ckpt")
369
+
370
+ self.canonical_infer = InferCanonicalAPI(self.canonical_configs)
371
+ # self.multiview_infer = InferMultiviewAPI(self.multiview_configs)
372
+ # self.slrm_infer = InferSlrmAPI(self.slrm_configs)
373
+ # self.refine_infer = InferRefineAPI(self.refine_configs)
374
+
375
+ def genStage1(self, img, seed):
376
+ return self.canonical_infer.gen(img, seed)
377
+
378
+ def genStage2(self, img, seed, num_levels):
379
+ return self.multiview_infer.gen(img, seed, num_levels)
380
+
381
+ def genStage3(self, img):
382
+ return self.slrm_infer.gen(img)
383
+
384
+ def genStage4(self, meshes, imgs):
385
+ return self.refine_infer.refine(meshes, imgs)
386
+
387
+
388
+ ############## Refine ##############
389
+ def fix_vert_color_glb(mesh_path):
390
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
391
+ obj1 = GLTF2().load(mesh_path)
392
+ obj1.meshes[0].primitives[0].material = 0
393
+ obj1.materials.append(Material(
394
+ pbrMetallicRoughness = PbrMetallicRoughness(
395
+ baseColorFactor = [1.0, 1.0, 1.0, 1.0],
396
+ metallicFactor = 0.,
397
+ roughnessFactor = 1.0,
398
+ ),
399
+ emissiveFactor = [0.0, 0.0, 0.0],
400
+ doubleSided = True,
401
+ ))
402
+ obj1.save(mesh_path)
403
+
404
+
405
+ def srgb_to_linear(c_srgb):
406
+ c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
407
+ return c_linear.clip(0, 1.)
408
+
409
+
410
+ def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
411
+ # convert from pytorch3d meshes to trimesh mesh
412
+ vertices = meshes.verts_packed().cpu().float().numpy()
413
+ triangles = meshes.faces_packed().cpu().long().numpy()
414
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
415
+ if save_glb_path.endswith(".glb"):
416
+ # rotate 180 along +Y
417
+ vertices[:, [0, 2]] = -vertices[:, [0, 2]]
418
+
419
+ if apply_sRGB_to_LinearRGB:
420
+ np_color = srgb_to_linear(np_color)
421
+ assert vertices.shape[0] == np_color.shape[0]
422
+ assert np_color.shape[1] == 3
423
+ assert 0 <= np_color.min() and np_color.max() <= 1.001, f"min={np_color.min()}, max={np_color.max()}"
424
+ np_color = np.clip(np_color, 0, 1)
425
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
426
+ mesh.remove_unreferenced_vertices()
427
+ # save mesh
428
+ mesh.export(save_glb_path)
429
+ if save_glb_path.endswith(".glb"):
430
+ fix_vert_color_glb(save_glb_path)
431
+ print(f"saving to {save_glb_path}")
432
+
433
+
434
+ def calc_horizontal_offset(target_img, source_img):
435
+ target_mask = target_img.astype(np.float32).sum(axis=-1) > 750
436
+ source_mask = source_img.astype(np.float32).sum(axis=-1) > 750
437
+ best_offset = -114514
438
+ for offset in range(-200, 200):
439
+ offset_mask = np.roll(source_mask, offset, axis=1)
440
+ overlap = (target_mask & offset_mask).sum()
441
+ if overlap > best_offset:
442
+ best_offset = overlap
443
+ best_offset_value = offset
444
+ return best_offset_value
445
+
446
+
447
+ def calc_horizontal_offset2(target_mask, source_img):
448
+ source_mask = source_img.astype(np.float32).sum(axis=-1) > 750
449
+ best_offset = -114514
450
+ for offset in range(-200, 200):
451
+ offset_mask = np.roll(source_mask, offset, axis=1)
452
+ overlap = (target_mask & offset_mask).sum()
453
+ if overlap > best_offset:
454
+ best_offset = overlap
455
+ best_offset_value = offset
456
+ return best_offset_value
457
+
458
+
459
+ @spaces.GPU
460
+ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20):
461
+ distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres
462
+ if normal_0 is not None and normal_1 is not None:
463
+ distract_area |= np.abs(normal_0 - normal_1).sum(axis=-1) > thres
464
+ labeled_array, num_features = scipy.ndimage.label(distract_area)
465
+ results = []
466
+
467
+ random_sampled_points = []
468
+
469
+ for i in range(num_features + 1):
470
+ if np.sum(labeled_array == i) > 1000 and np.sum(labeled_array == i) < 100000:
471
+ results.append((i, np.sum(labeled_array == i)))
472
+ # random sample a point in the area
473
+ points = np.argwhere(labeled_array == i)
474
+ random_sampled_points.append(points[np.random.randint(0, points.shape[0])])
475
+
476
+ results = sorted(results, key=lambda x: x[1], reverse=True) # [1:]
477
+ distract_mask = np.zeros_like(distract_area)
478
+ distract_bbox = np.zeros_like(distract_area)
479
+ for i, _ in results:
480
+ distract_mask |= labeled_array == i
481
+ bbox = np.argwhere(labeled_array == i)
482
+ min_x, min_y = bbox.min(axis=0)
483
+ max_x, max_y = bbox.max(axis=0)
484
+ distract_bbox[min_x:max_x, min_y:max_y] = 1
485
+
486
+ points = np.array(random_sampled_points)[:, ::-1]
487
+ labels = np.ones(len(points), dtype=np.int32)
488
+
489
+ masks = generator.generate((color_1 * 255).astype(np.uint8))
490
+
491
+ outside_area = np.abs(color_0 - color_1).sum(axis=-1) < outside_thres
492
+
493
+ final_mask = np.zeros_like(distract_mask)
494
+ for iii, mask in enumerate(masks):
495
+ mask['segmentation'] = cv2.resize(mask['segmentation'].astype(np.float32), (1024, 1024)) > 0.5
496
+ intersection = np.logical_and(mask['segmentation'], distract_mask).sum()
497
+ total = mask['segmentation'].sum()
498
+ iou = intersection / total
499
+ outside_intersection = np.logical_and(mask['segmentation'], outside_area).sum()
500
+ outside_total = mask['segmentation'].sum()
501
+ outside_iou = outside_intersection / outside_total
502
+ if iou > ratio and outside_iou < outside_ratio:
503
+ final_mask |= mask['segmentation']
504
+
505
+ # calculate coverage
506
+ intersection = np.logical_and(final_mask, distract_mask).sum()
507
+ total = distract_mask.sum()
508
+ coverage = intersection / total
509
+
510
+ if coverage < 0.8:
511
+ # use original distract mask
512
+ final_mask = (distract_mask.copy() * 255).astype(np.uint8)
513
+ final_mask = cv2.dilate(final_mask, np.ones((3, 3), np.uint8), iterations=3)
514
+ labeled_array_dilate, num_features_dilate = scipy.ndimage.label(final_mask)
515
+ for i in range(num_features_dilate + 1):
516
+ if np.sum(labeled_array_dilate == i) < 200:
517
+ final_mask[labeled_array_dilate == i] = 255
518
+
519
+ final_mask = cv2.erode(final_mask, np.ones((3, 3), np.uint8), iterations=3)
520
+ final_mask = final_mask > 127
521
+
522
+ return distract_mask, distract_bbox, random_sampled_points, final_mask
523
+
524
+
525
+ class InferRefineAPI:
526
+ @spaces.GPU
527
+ def __init__(self, config):
528
+ self.sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
529
+ self.generator = SamAutomaticMaskGenerator(
530
+ model=self.sam,
531
+ points_per_side=64,
532
+ pred_iou_thresh=0.80,
533
+ stability_score_thresh=0.92,
534
+ crop_n_layers=1,
535
+ crop_n_points_downscale_factor=2,
536
+ min_mask_region_area=100,
537
+ )
538
+ self.outside_ratio = 0.20
539
+
540
+ @spaces.GPU
541
+ def refine(self, meshes, imgs):
542
+ fixed_v, fixed_f, fixed_t = None, None, None
543
+ flow_vert, flow_vector = None, None
544
+ last_colors, last_normals = None, None
545
+ last_front_color, last_front_normal = None, None
546
+ distract_mask = None
547
+
548
+ mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
549
+ mv = mv[[4, 3, 2, 0, 6, 5]]
550
+ renderer = NormalsRenderer(mv,proj,(1024,1024))
551
+
552
+ results = []
553
+
554
+ for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
555
+ mesh = trimesh.load(meshes[name_idx])
556
+ new_mesh = mesh.split(only_watertight=False)
557
+ new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ]
558
+ mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
559
+ mesh_v, mesh_f = mesh.vertices, mesh.faces
560
+
561
+ if last_colors is None:
562
+ images = renderer.render(
563
+ torch.tensor(mesh_v, device='cuda').float(),
564
+ torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
565
+ torch.tensor(mesh_f, device='cuda'),
566
+ )
567
+ mask = (images[..., 3] < 0.9).cpu().numpy()
568
+
569
+ colors, normals = [], []
570
+ for i in range(6):
571
+ color = np.array(imgs[level]['images'][i])
572
+ normal = np.array(imgs[level]['normals'][i])
573
+
574
+ if last_colors is not None:
575
+ offset = calc_horizontal_offset(np.array(last_colors[i]), color)
576
+ # print('offset', i, offset)
577
+ else:
578
+ offset = calc_horizontal_offset2(mask[i], color)
579
+ # print('init offset', i, offset)
580
+
581
+ if offset != 0:
582
+ color = np.roll(color, offset, axis=1)
583
+ normal = np.roll(normal, offset, axis=1)
584
+
585
+ color = Image.fromarray(color)
586
+ normal = Image.fromarray(normal)
587
+ colors.append(color)
588
+ normals.append(normal)
589
+
590
+ if last_front_color is not None and level == 0:
591
+ original_mask, distract_bbox, _, distract_mask = get_distract_mask(self.generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=self.outside_ratio)
592
+ else:
593
+ distract_mask = None
594
+ distract_bbox = None
595
+
596
+ last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
597
+ last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0
598
+
599
+ if last_colors is None:
600
+ from copy import deepcopy
601
+ last_colors, last_normals = deepcopy(colors), deepcopy(normals)
602
+
603
+ # my mesh flow weight by nearest vertexs
604
+ if fixed_v is not None and fixed_f is not None and level == 1:
605
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
606
+
607
+ fixed_v_cpu = fixed_v.cpu().numpy()
608
+ kdtree_anchor = KDTree(fixed_v_cpu)
609
+ kdtree_mesh_v = KDTree(mesh_v)
610
+ _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
611
+ _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
612
+ idx_anchor = idx_anchor.squeeze()
613
+ neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
614
+ # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
615
+ neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
616
+ neighbor_dists[neighbor_dists > 0.06] = 114514.
617
+ neighbor_weights = torch.exp(-neighbor_dists * 1.)
618
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
619
+ anchors = fixed_v[idx_anchor] # V, 3
620
+ anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
621
+ dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
622
+ vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
623
+ vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
624
+ weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
625
+ mesh_v += weighted_vec_anchor.cpu().numpy()
626
+
627
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
628
+
629
+ mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
630
+ mesh_f = torch.tensor(mesh_f, device='cuda')
631
+
632
+ new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
633
+
634
+ # my mesh flow weight by nearest vertexs
635
+ try:
636
+ if fixed_v is not None and fixed_f is not None and level != 0:
637
+ new_mesh_v = new_mesh.verts_packed().cpu().numpy()
638
+
639
+ fixed_v_cpu = fixed_v.cpu().numpy()
640
+ kdtree_anchor = KDTree(fixed_v_cpu)
641
+ kdtree_mesh_v = KDTree(new_mesh_v)
642
+ _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
643
+ _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
644
+ idx_anchor = idx_anchor.squeeze()
645
+ neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
646
+ # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
647
+ neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
648
+ neighbor_dists[neighbor_dists > 0.06] = 114514.
649
+ neighbor_weights = torch.exp(-neighbor_dists * 1.)
650
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
651
+ anchors = fixed_v[idx_anchor] # V, 3
652
+ anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
653
+ dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
654
+ vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
655
+ vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
656
+ weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
657
+ new_mesh_v += weighted_vec_anchor.cpu().numpy()
658
+
659
+ # replace new_mesh verts with new_mesh_v
660
+ new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
661
+
662
+ except Exception as e:
663
+ pass
664
+
665
+ notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed()
666
+
667
+ if fixed_v is None:
668
+ fixed_v, fixed_f = simp_v, simp_f
669
+ complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t
670
+ else:
671
+ fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
672
+ fixed_v = torch.cat([fixed_v, simp_v], dim=0)
673
+
674
+ complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0)
675
+ complete_v = torch.cat([complete_v, notsimp_v], dim=0)
676
+ complete_t = torch.cat([complete_t, notsimp_t], dim=0)
677
+
678
+ if level == 2:
679
+ new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5]))
680
+
681
+ save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False)
682
+ results.append(meshes[name_idx].replace('.obj', '_refined.obj'))
683
+
684
+ # save whole mesh
685
+ save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False)
686
+ results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj'))
687
+
688
+ return results
689
+
690
+
691
+ class InferSlrmAPI:
692
+ @spaces.GPU
693
+ def __init__(self, config):
694
+ self.config_path = config['config_path']
695
+ self.config = OmegaConf.load(self.config_path)
696
+ self.config_name = os.path.basename(self.config_path).replace('.yaml', '')
697
+ self.model_config = self.config.model_config
698
+ self.infer_config = self.config.infer_config
699
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
700
+ self.model = instantiate_from_config(self.model_config)
701
+ state_dict = torch.load(self.infer_config.model_path, map_location='cpu')
702
+ self.model.load_state_dict(state_dict, strict=False)
703
+ self.model = self.model.to(self.device)
704
+ self.model.init_flexicubes_geometry(self.device, fovy=30.0, is_ortho=self.model.is_ortho)
705
+ self.model = self.model.eval()
706
+
707
+ @spaces.GPU
708
+ def gen(self, imgs):
709
+ imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
710
+ imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
711
+ imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024)
712
+ mesh_glb_fpaths = self.make3d(imgs)
713
+ return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
714
+
715
+ @spaces.GPU
716
+ def make3d(self, images):
717
+ input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
718
+
719
+ images = images.unsqueeze(0).to(device)
720
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
721
+
722
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
723
+ print(mesh_fpath)
724
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
725
+ mesh_dirname = os.path.dirname(mesh_fpath)
726
+
727
+ with torch.no_grad():
728
+ # get triplane
729
+ planes = self.model.forward_planes(images, input_cameras.float())
730
+
731
+ # get mesh
732
+ mesh_glb_fpaths = []
733
+ for j in range(4):
734
+ mesh_glb_fpath = self.make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j])
735
+ mesh_glb_fpaths.append(mesh_glb_fpath)
736
+
737
+ return mesh_glb_fpaths
738
+
739
+ @spaces.GPU
740
+ def make_mesh(self, mesh_fpath, planes, level=None):
741
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
742
+ mesh_dirname = os.path.dirname(mesh_fpath)
743
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
744
+
745
+ with torch.no_grad():
746
+ # get mesh
747
+ mesh_out = self.model.extract_mesh(
748
+ planes,
749
+ use_texture_map=False,
750
+ levels=torch.tensor([level]).to(device),
751
+ **self.infer_config,
752
+ )
753
+
754
+ vertices, faces, vertex_colors = mesh_out
755
+ vertices = vertices[:, [1, 2, 0]]
756
+
757
+ if level == 2:
758
+ # fill all vertex_colors with 127
759
+ vertex_colors = np.ones_like(vertex_colors) * 127
760
+
761
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
762
+
763
+ return mesh_fpath
764
+
765
+ class InferMultiviewAPI:
766
+ def __init__(self, config):
767
+ parser = argparse.ArgumentParser()
768
+ parser.add_argument("--seed", type=int, default=42)
769
+ parser.add_argument("--num_views", type=int, default=6)
770
+ parser.add_argument("--num_levels", type=int, default=3)
771
+ parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
772
+ parser.add_argument("--height", type=int, default=1024)
773
+ parser.add_argument("--width", type=int, default=576)
774
+ self.cfg = parser.parse_args()
775
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
776
+ self.pipeline = load_multiview_pipeline(self.cfg)
777
+ self.results = {}
778
+ if torch.cuda.is_available():
779
+ self.pipeline.to(device)
780
+
781
+ self.image_transforms = [transforms.Resize(int(max(self.cfg.height, self.cfg.width))),
782
+ transforms.CenterCrop((self.cfg.height, self.cfg.width)),
783
+ transforms.ToTensor(),
784
+ transforms.Lambda(lambda x: x * 2. - 1),
785
+ ]
786
+ self.image_transforms = transforms.Compose(self.image_transforms)
787
+
788
+ prompt_embeds_path = './multiview/fixed_prompt_embeds_6view'
789
+ self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
790
+ self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
791
+ self.total_views = self.cfg.num_views
792
+
793
+
794
+ def process_im(self, im):
795
+ im = self.image_transforms(im)
796
+ return im
797
+
798
+ def gen(self, img, seed, num_levels):
799
+ set_seed(seed)
800
+ data = {}
801
+
802
+ cond_im_rgb = self.process_im(img)
803
+ cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0)
804
+ data["image_cond_rgb"] = cond_im_rgb[None, ...]
805
+ data["normal_prompt_embeddings"] = self.normal_text_embeds[None, ...]
806
+ data["color_prompt_embeddings"] = self.color_text_embeds[None, ...]
807
+
808
+ results = run_multiview_infer(data, self.pipeline, self.cfg, num_levels=num_levels)
809
+ for k in results:
810
+ self.results[k] = results[k]
811
+ return results
812
+
813
+
814
+ class InferCanonicalAPI:
815
+ def __init__(self, config):
816
+ self.config = config
817
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
818
+
819
+ self.config_path = config['config_path']
820
+ self.loaded_config = OmegaConf.load(self.config_path)
821
+
822
+ self.setup(**self.loaded_config)
823
+
824
+ def setup(self,
825
+ validation: Dict,
826
+ pretrained_model_path: str,
827
+ local_crossattn: bool = True,
828
+ unet_from_pretrained_kwargs=None,
829
+ unet_condition_type=None,
830
+ use_noise=True,
831
+ noise_d=256,
832
+ timestep: int = 40,
833
+ width_input: int = 640,
834
+ height_input: int = 1024,
835
+ ):
836
+ self.width_input = width_input
837
+ self.height_input = height_input
838
+ self.timestep = timestep
839
+ self.use_noise = use_noise
840
+ self.noise_d = noise_d
841
+ self.validation = validation
842
+ self.unet_condition_type = unet_condition_type
843
+ self.pretrained_model_path = pretrained_model_path
844
+
845
+ self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
846
+ self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
847
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
848
+ self.feature_extractor = CLIPImageProcessor()
849
+ self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
850
+ self.unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
851
+ self.ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
852
+
853
+ self.text_encoder.to(device, dtype=weight_dtype)
854
+ self.image_encoder.to(device, dtype=weight_dtype)
855
+ self.vae.to(device, dtype=weight_dtype)
856
+ self.ref_unet.to(device, dtype=weight_dtype)
857
+ self.unet.to(device, dtype=weight_dtype)
858
+
859
+ self.vae.requires_grad_(False)
860
+ self.ref_unet.requires_grad_(False)
861
+ self.unet.requires_grad_(False)
862
+
863
+ self.noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr")
864
+ self.validation_pipeline = CanonicalizationPipeline(
865
+ vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, ref_unet=self.ref_unet,feature_extractor=self.feature_extractor,image_encoder=self.image_encoder,
866
+ scheduler=self.noise_scheduler
867
+ )
868
+ self.validation_pipeline.set_progress_bar_config(disable=True)
869
+
870
+ def canonicalize(self, image, seed):
871
+ return inference(
872
+ self.validation_pipeline, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
873
+ self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
874
+ use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
875
+ )
876
+
877
+ def gen(self, img_input, seed=0):
878
+ if np.array(img_input).shape[-1] == 4 and np.array(img_input)[..., 3].min() == 255:
879
+ # convert to RGB
880
+ img_input = img_input.convert("RGB")
881
+ img_output = self.canonicalize(img_input, seed)
882
+
883
+ max_dim = max(img_output.width, img_output.height)
884
+ new_image = Image.new("RGBA", (max_dim, max_dim))
885
+ left = (max_dim - img_output.width) // 2
886
+ top = (max_dim - img_output.height) // 2
887
+ new_image.paste(img_output, (left, top))
888
+
889
+ return new_image
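
For reference, the backup keeps the previous class-based version of the same stage: InferCanonicalAPI.setup loads the tokenizer, text/image encoders, VAE and the two UNets and builds the CanonicalizationPipeline, and gen pads the canonicalized output onto a square RGBA canvas, which is what the new module-level infer_canonicalize_* code in infer_api.py reproduces. A rough stage-1-only sketch against the backed-up API (a sketch only; the config path matches the one now hard-coded in infer_api.py, and model loading happens in the constructor):

# Stage 1 via the backed-up class-based API.
from PIL import Image
from infer_api_bk import InferCanonicalAPI

canon = InferCanonicalAPI({'config_path': './configs/canonicalization-infer.yaml'})
out = canon.gen(Image.open("input.png"), seed=0)   # returns a square RGBA PIL image
out.save("canonical.png")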