# Hunyuan3D-1 / third_party / gen_baking.py
# Uploaded by Huiwenshi with huggingface_hub (commit 68cd723, verified); file size 10.7 kB.
import os, sys, time
from typing import List, Optional
from iopath.common.file_io import PathManager
import cv2
import imageio
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torchvision import transforms
import trimesh
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.common.datatypes import Device
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
PointLights,
DirectionalLights,
AmbientLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
TexturesUV,
TexturesVertex,
camera_position_from_spherical_angles,
BlendParams,
)
def erode_mask(src_mask, p=1 / 20.0):
    """Erode a mask with a square kernel sized relative to the mask content.

    The kernel side is ``p`` times the shorter side of the bounding rectangle
    that ``cv2.boundingRect`` reports for the first channel of ``src_mask``.

    Args:
        src_mask: Mask image accepted by ``cv2.split`` and ``cv2.erode``.
        p: Fraction of the shorter bounding-rectangle side used as the
            kernel side length.

    Returns:
        The eroded mask returned by ``cv2.erode``.
    """
    mono_mask = cv2.split(src_mask)[0]
    _, _, box_w, box_h = cv2.boundingRect(mono_mask)
    # Clamp to 1 so a small or empty mask never yields a 0x0 kernel; a 1x1
    # kernel of ones leaves the mask unchanged.
    k = max(1, int(min(box_w, box_h) * p))
    kernel = np.ones((k, k), dtype=np.uint8)
    # The third positional parameter of cv2.erode is ``dst``, so the
    # iteration count has to be passed by keyword.
    return cv2.erode(src_mask, kernel, iterations=1)
def load_objs_as_meshes_fast(
    verts,
    faces,
    aux,
    device: Optional[Device] = None,
    load_textures: bool = True,
    create_texture_atlas: bool = False,
    texture_atlas_size: int = 4,
    texture_wrap: Optional[str] = "repeat",
    path_manager: Optional[PathManager] = None,
):
    """Build a single-mesh ``Meshes`` object from already-loaded OBJ data.

    Takes the ``(verts, faces, aux)`` triple that ``load_obj`` returns, so the
    OBJ file does not have to be parsed a second time.

    Args:
        verts: Vertex tensor; moved to ``device``.
        faces: Faces object providing ``verts_idx`` and ``textures_idx``.
        aux: Auxiliary object providing ``texture_images``, ``verts_uvs`` and,
            when ``create_texture_atlas`` is set, ``texture_atlas``.
        device: Device that all tensors are moved to.
        load_textures: Accepted for signature compatibility; not used here.
        create_texture_atlas: Build a ``TexturesAtlas`` from
            ``aux.texture_atlas`` instead of a ``TexturesUV``.
        texture_atlas_size: Accepted for signature compatibility; not used here.
        texture_wrap: Accepted for signature compatibility; not used here.
        path_manager: Accepted for signature compatibility; not used here.

    Returns:
        A ``Meshes`` holding one mesh. Its texture is ``None`` when the UV
        path is taken and ``aux.texture_images`` is missing or empty.

    Raises:
        ValueError: If ``create_texture_atlas`` is set but ``aux`` carries no
            texture atlas.
    """
    tex = None
    if create_texture_atlas:
        # TexturesAtlas is not part of this module's top-level pytorch3d
        # imports, so it is imported where it is needed.
        from pytorch3d.renderer import TexturesAtlas

        if aux.texture_atlas is None:
            raise ValueError(
                "create_texture_atlas=True requires aux.texture_atlas to be set"
            )
        tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
    else:
        # TexturesUV type
        tex_maps = aux.texture_images
        if tex_maps is not None and len(tex_maps) > 0:
            verts_uvs = aux.verts_uvs.to(device)  # (V, 2)
            faces_uvs = faces.textures_idx.to(device)  # (F, 3)
            # Only the first texture image is used.
            image = list(tex_maps.values())[0].to(device)[None]
            tex = TexturesUV(verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image)

    return Meshes(
        verts=[verts.to(device)],
        faces=[faces.verts_idx.to(device)],
        textures=tex,
    )
def get_triangle_to_triangle(tri_1, tri_2, img_refined):
    """Warp the image patch under ``tri_1`` onto the bounding box of ``tri_2``.

    Each triangle is a sequence of three points. Component 0 of a point
    indexes the first axis of the image it refers to and component 1 the
    second axis; the components are swapped before being handed to the
    OpenCV point-based calls.

    Args:
        tri_1: Source triangle, addressing ``img_refined``.
        tri_2: Destination triangle.
        img_refined: Image the source patch is cropped from.

    Returns:
        A tuple ``(mask, warped, rect)`` where ``mask`` is a float32 array of
        shape ``(rect[2], rect[3], 3)`` filled with ones inside the
        destination triangle, ``warped`` is the affine-warped source patch
        for the same region, and ``rect`` is ``cv2.boundingRect(tri_2)``.
    """
    src_rect = cv2.boundingRect(tri_1)
    dst_rect = cv2.boundingRect(tri_2)

    # Express every vertex relative to its bounding-box origin, with the two
    # components swapped.
    src_pts = [
        (tri_1[v][1] - src_rect[1], tri_1[v][0] - src_rect[0]) for v in range(3)
    ]
    dst_pts = [
        (tri_2[v][1] - dst_rect[1], tri_2[v][0] - dst_rect[0]) for v in range(3)
    ]
    affine = cv2.getAffineTransform(np.float32(src_pts), np.float32(dst_pts))

    # Crop the source patch covered by the source bounding box.
    src_off0, src_off1, src_len0, src_len1 = src_rect
    patch = np.float32(
        img_refined[src_off0:src_off0 + src_len0, src_off1:src_off1 + src_len1]
    )

    # Rasterise the destination triangle into a three-channel weight mask.
    mask = np.zeros((dst_rect[2], dst_rect[3], 3), dtype=np.float32)
    cv2.fillConvexPoly(mask, np.int32(dst_pts), (1.0, 1.0, 1.0), 16, 0)

    warped = cv2.warpAffine(
        patch,
        affine,
        (dst_rect[3], dst_rect[2]),
        None,
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT_101,
    )
    return mask, warped, dst_rect
def _uv_triangles_to_pixels(verts_uvs, faces_uvs_idx, uv_resolution):
    """Return per-face UV corners as integer texture pixels ordered (row, col)."""
    tri_uvs = verts_uvs[faces_uvs_idx]
    tri_px = torch.stack(
        [
            (1 - tri_uvs[..., 1]) * uv_resolution,  # v is flipped to get the row
            tri_uvs[..., 0] * uv_resolution,  # u gives the column
        ],
        dim=-1,
    )
    tri_px = np.round(np.float32(tri_px.cpu())).astype(np.int64)
    return np.clip(tri_px, 0, uv_resolution - 1)


def _find_mtl_file(obj_file):
    """Return the first existing material-file candidate for ``obj_file``."""
    candidates = (
        obj_file.replace('.obj', '.mtl'),
        obj_file.replace("mesh.obj", "material.mtl"),
        obj_file.replace("mesh.obj", "texture.mtl"),
    )
    for candidate in candidates:
        # A replace() that matched nothing returns the OBJ path itself, which
        # must not be mistaken for a material file.
        if candidate != obj_file and os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        f"no material file found for {obj_file}; tried: {', '.join(candidates)}"
    )


def _copy_into_dir(src_path, dst_folder):
    """Copy ``src_path`` into ``dst_folder``; a file already there is left alone."""
    import shutil

    try:
        shutil.copy(src_path, dst_folder)
    except shutil.SameFileError:
        pass


def back_projection(
    obj_file,
    init_texture_file,
    front_view_file,
    dst_dir,
    render_resolution=512,
    uv_resolution=600,
    normalThreshold=0.3, # 0.3
    rgb_thresh=820, # 520
    views=None,
    camera_dist=1.5,
    erode_scale=1/100.0,
    device="cuda:0"
):
    """Bake a front-view image into the UV texture of a mesh.

    The mesh is rendered from the first view in ``views``, the front-view
    image is blended onto that render with ``cv2.seamlessClone``, and every
    face visible in the render has its blended screen-space triangle warped
    into the texture. Debug images ``mask.jpeg``, ``src.jpeg``, ``dst.jpeg``
    and ``blend.jpeg`` are written to ``dst_dir``. The OBJ, its material
    file, the baked texture and a GLB export are written to ``dst_dir/bake``.

    Args:
        obj_file: Path to an OBJ file that carries UV coordinates.
        init_texture_file: Path to the initial unwrapped UV texture. The
            texture is not resized, so ``uv_resolution`` has to match its
            pixel size.
        front_view_file: Front-view image, as a path or a ``PIL.Image.Image``.
            NOTE(review): the mask and the blend index it with render-space
            pixel coordinates, so it presumably has to be
            ``render_resolution`` square - confirm with callers.
        dst_dir: Output directory; created if missing.
        render_resolution: Side length of the square front-view render.
        uv_resolution: Side length of the texture in pixels.
        normalThreshold: Threshold on the face/camera alignment term, used to
            down-weight faces that are turned away from the camera.
        rgb_thresh: A face is skipped when the largest per-pixel summed
            absolute colour difference between the existing texture and the
            new colour exceeds this value.
        views: Sequence of views. Only ``views[0]`` is used; its first two
            entries are passed as the elevation and azimuth arguments of
            ``look_at_view_transform``.
        camera_dist: Camera distance passed to ``look_at_view_transform``.
        erode_scale: Not used in this function.
        device: Torch device used for loading and rendering.

    Raises:
        ValueError: If ``views`` is ``None`` or empty.
        TypeError: If ``front_view_file`` is neither a path nor a PIL image.
        FileNotFoundError: If no material file is found next to ``obj_file``.
    """
    if views is None or len(views) == 0:
        raise ValueError("views must contain at least one (elevation, azimuth) entry")

    os.makedirs(dst_dir, exist_ok=True)

    if isinstance(front_view_file, (str, os.PathLike)):
        with Image.open(front_view_file) as front_img:
            src = np.array(front_img.convert("RGB"))
    elif isinstance(front_view_file, Image.Image):
        src = np.array(front_view_file.convert("RGB"))
    else:
        raise TypeError("need file_path or pil")

    # Resolve the material file before the expensive bake so that a missing
    # file fails immediately.
    mtl_file = _find_mtl_file(obj_file)

    image_size = (render_resolution, render_resolution)

    with Image.open(init_texture_file) as texture_img:
        init_texture = np.array(texture_img.convert("RGB")).astype(np.float32)

    print("load obj", obj_file)
    verts, faces, aux = load_obj(obj_file, device=device)
    mesh = load_objs_as_meshes_fast(verts, faces, aux, device=device)

    # Per-face UV corners in texture pixels, ordered (row, col).
    triangle_uvs = _uv_triangles_to_pixels(aux.verts_uvs, faces.textures_idx, uv_resolution)

    elev, azim = views[0][0], views[0][1]
    R0, T0 = look_at_view_transform(camera_dist, elev, azim)
    cameras = FoVPerspectiveCameras(device=device, R=R0, T=T0, fov=49.1)
    camera_normal = camera_position_from_spherical_angles(1, elev, azim).to(device)

    # Per-face projected corners in render pixels, swapped to (row, col).
    screen_coords = cameras.transform_points_screen(verts, image_size=image_size)[:, :2]
    screen_coords = torch.cat([screen_coords[..., 1, None], screen_coords[..., 0, None]], dim=-1)
    triangle_screen_coords = np.round(np.float32(screen_coords[faces.verts_idx].cpu()))
    triangle_screen_coords = np.clip(triangle_screen_coords.astype(np.int64), 0, render_resolution-1)

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(
                image_size=image_size,
                blur_radius=0.0,
                faces_per_pixel=1,
            ),
        ),
        shader=SoftPhongShader(
            device=device,
            cameras=cameras,
            lights=AmbientLights(device=device),
            blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
        )
    )

    dst = renderer(mesh)
    dst = (dst[..., :3] * 255).squeeze(0).cpu().numpy().astype(np.uint8)

    # The subject mask excludes near-white pixels of the render and of the
    # front view.
    src_mask = np.ones((src.shape[0], src.shape[1]), dst.dtype)
    ids = np.where(dst.sum(-1) > 253 * 3)
    ids2 = np.where(src.sum(-1) > 250 * 3)
    src_mask[ids[0], ids[1]] = 0
    src_mask[ids2[0], ids2[1]] = 0
    src_mask = (src_mask > 0).astype(np.uint8) * 255

    monoMaskImage = cv2.split(src_mask)[0]  # reducing the mask to a monochrome
    br = cv2.boundingRect(monoMaskImage)  # bounding rect (x,y,width,height)
    center = (br[0] + br[2] // 2, br[1] + br[3] // 2)

    # Blend the front view onto the render. This is best effort: on failure
    # the raw front view is baked instead.
    try:
        # NORMAL_CLONE gives a sharper result than MIXED_CLONE.
        images = cv2.seamlessClone(src, dst, src_mask, center, cv2.NORMAL_CLONE)
    except Exception as err:
        print(f"\n\n Warning seamlessClone error: {err} \n\n")
        images = src

    Image.fromarray(src_mask).save(os.path.join(dst_dir, 'mask.jpeg'))
    Image.fromarray(src).save(os.path.join(dst_dir, 'src.jpeg'))
    Image.fromarray(dst).save(os.path.join(dst_dir, 'dst.jpeg'))
    Image.fromarray(images).save(os.path.join(dst_dir, 'blend.jpeg'))

    fragments_scaled = renderer.rasterizer(mesh)
    covered = fragments_scaled.pix_to_face.unique()
    # Background pixels are marked -1. Filtering by value instead of dropping
    # the first entry keeps every face when no background pixel exists, and
    # plain ints avoid a device sync per loop iteration.
    faces_covered = covered[covered >= 0].tolist()

    face_normals = mesh.faces_normals_packed().to(device)
    # NOTE(review): this is the sum of squared element-wise products, not a
    # squared dot product; the expression is kept exactly as written -
    # confirm the intent. All faces are computed at once and moved to the
    # CPU a single time.
    face_cosines = ((face_normals * camera_normal) ** 2).sum(dim=-1).detach().cpu().numpy()

    t0 = time.time()
    # True: down-weight faces by their alignment term.
    # False: skip faces whose alignment term is below normalThreshold.
    SOFT_NORM = True
    for k in faces_covered:
        # todo: accelerate this for-loop
        cosine = face_cosines[k]
        if not SOFT_NORM and cosine < normalThreshold:
            continue

        # Skip the face when its first corner projects outside the subject.
        first_row, first_col = triangle_screen_coords[k][0]
        if src_mask[first_row, first_col] == 0:
            continue

        coeff, img_2_cropped, r2 = get_triangle_to_triangle(
            triangle_screen_coords[k], triangle_uvs[k], images
        )
        region = init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]]

        # Skip the face when the new colour differs too much from the
        # existing texture.
        err = (np.abs(region - img_2_cropped) * coeff).sum(-1)
        if np.max(err) > rgb_thresh:
            continue

        coeff = coeff.clip(0, 1)
        if SOFT_NORM:
            coeff *= ((cosine - normalThreshold) / normalThreshold).clip(0, 1)
        coeff *= (((rgb_thresh - err[..., None]) / rgb_thresh) ** 0.4).clip(0, 1)

        init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] = \
            region * ((1.0, 1.0, 1.0) - coeff) + img_2_cropped * coeff

    print(f'View baking time: {time.time() - t0}')

    bake_dir = os.path.join(dst_dir, 'bake')
    os.makedirs(bake_dir, exist_ok=True)
    _copy_into_dir(obj_file, bake_dir)

    baked_texture = Image.fromarray(init_texture.astype(np.uint8))
    baked_texture.save(os.path.join(bake_dir, os.path.basename(init_texture_file)))

    _copy_into_dir(mtl_file, bake_dir)

    # convert .obj to .glb file
    new_obj_pth = os.path.join(bake_dir, os.path.basename(obj_file))
    new_glb_path = os.path.splitext(new_obj_pth)[0] + '.glb'
    mesh = trimesh.load_mesh(new_obj_pth)
    mesh.export(new_glb_path, file_type='glb')