from typing import *
import numpy as np
import torch
import utils3d
import nvdiffrast.torch as dr
from tqdm import tqdm
import trimesh
import trimesh.visual
import xatlas
import pyvista as pv
from pymeshfix import _meshfix
import igraph
import cv2
from PIL import Image
from .random_utils import sphere_hammersley_sequence
from .render_utils import render_multiview
from ..renderers import GaussianRenderer
from ..representations import Strivec, Gaussian, MeshExtractResult


def _cosine_annealing(step, total_steps, start_lr, end_lr):
    """Return the cosine-annealed learning rate at ``step`` out of ``total_steps``.

    The value starts at ``start_lr`` for ``step == 0`` and decays towards
    ``end_lr`` as ``step`` approaches ``total_steps``.
    """
    return end_lr + 0.5 * (start_lr - end_lr) * (
        1 + np.cos(np.pi * step / total_steps)
    )


@torch.no_grad()
def _fill_holes(
    verts,
    faces,
    max_hole_size=0.04,
    max_hole_nbe=32,
    resolution=128,
    num_views=500,
    debug=False,
    verbose=False,
):
    """
    Rasterize a mesh from multiple views and remove invisible faces.
    Also includes postprocessing to:
        1. Remove connected components that have low visibility.
        2. Mincut to remove faces at the inner side of the mesh connected to the
           outer side with a small hole.
    Finally, small boundary loops are closed with pymeshfix.

    Args:
        verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
        faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
        max_hole_size (float): Maximum area of a cutting loop; a removal candidate
            whose cut opens a larger loop is kept.
        max_hole_nbe (int): Maximum number of boundary edges of a hole that
            pymeshfix is asked to fill.
        resolution (int): Resolution of the rasterization.
        num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to print extra information and write the
            ``dbg_dual.ply`` / ``dbg_cut.ply`` debug files.
        verbose (bool): Whether to print progress.

    Returns:
        Tuple of (verts, faces) tensors. When no face is invisible the inputs are
        returned unchanged (no hole filling is run in that case).
    """
    # Construct cameras: view directions come from a Hammersley sequence on the sphere
    yaws = []
    pitchs = []
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views)
        yaws.append(y)
        pitchs.append(p)
    yaws = torch.tensor(yaws).cuda()
    pitchs = torch.tensor(pitchs).cuda()
    radius = 2.0
    fov = torch.deg2rad(torch.tensor(40)).cuda()
    projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
    # The look-at target (origin) and up vector (+z) are the same for every view
    look_at_target = torch.tensor([0, 0, 0]).float().cuda()
    up_vector = torch.tensor([0, 0, 1]).float().cuda()
    views = []
    for yaw, pitch in zip(yaws, pitchs):
        orig = (
            torch.tensor(
                [
                    torch.sin(yaw) * torch.cos(pitch),
                    torch.cos(yaw) * torch.cos(pitch),
                    torch.sin(pitch),
                ]
            )
            .cuda()
            .float()
            * radius
        )
        views.append(utils3d.torch.view_look_at(orig, look_at_target, up_vector))
    views = torch.stack(views, dim=0)

    # Rasterize: count in how many views each face is hit by at least one pixel
    visibility = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
    rastctx = utils3d.torch.RastContext(backend="cuda")
    for i in tqdm(
        range(views.shape[0]),
        total=views.shape[0],
        disable=not verbose,
        desc="Rasterizing",
    ):
        view = views[i]
        buffers = utils3d.torch.rasterize_triangle_faces(
            rastctx,
            verts[None],
            faces,
            resolution,
            resolution,
            view=view,
            projection=projection,
        )
        # presumably face ids are 1-based with 0 meaning background -- TODO confirm
        face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
        face_id = torch.unique(face_id).long()
        visibility[face_id] += 1
    visibility = visibility.float() / num_views

    # Mincut
    ## construct outer faces
    edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
    boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
    connected_components = utils3d.torch.compute_connected_components(
        faces, edges, face2edge
    )
    outer_face_indices = torch.zeros(
        faces.shape[0], dtype=torch.bool, device=faces.device
    )
    for component in connected_components:
        # Per-component threshold: the 75% visibility quantile clamped to [0.25, 0.5]
        component_visibility = visibility[component]
        threshold = min(max(component_visibility.quantile(0.75).item(), 0.25), 0.5)
        outer_face_indices[component] = component_visibility > threshold
    outer_face_indices = outer_face_indices.nonzero().reshape(-1)

    ## construct inner faces
    inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
    if verbose:
        tqdm.write(f"Found {inner_face_indices.shape[0]} invisible faces")
    if inner_face_indices.shape[0] == 0:
        return verts, faces

    ## Construct dual graph (faces as nodes, edges as edges)
    dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
    dual_edge2edge = edges[dual_edge2edge]
    # Each dual edge is weighted by the length of the mesh edge shared by its faces
    dual_edges_weights = torch.norm(
        verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1
    )
    if verbose:
        tqdm.write(f"Dual graph: {dual_edges.shape[0]} edges")

    ## solve mincut problem
    ### construct main graph
    g = igraph.Graph()
    g.add_vertices(faces.shape[0])
    g.add_edges(dual_edges.cpu().numpy())
    g.es["weight"] = dual_edges_weights.cpu().numpy()

    ### source and target
    g.add_vertex("s")
    g.add_vertex("t")

    ### connect invisible faces to source
    # Convert to Python ints once instead of iterating a CUDA tensor per element
    inner_ids = inner_face_indices.tolist()
    g.add_edges(
        [(f, "s") for f in inner_ids],
        attributes={"weight": np.ones(len(inner_ids), dtype=np.float32)},
    )

    ### connect outer faces to target
    outer_ids = outer_face_indices.tolist()
    g.add_edges(
        [(f, "t") for f in outer_ids],
        attributes={"weight": np.ones(len(outer_ids), dtype=np.float32)},
    )

    ### solve mincut
    cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
    # partition[0] is used as the removal candidate set; ids >= F are the s/t nodes
    remove_face_indices = torch.tensor(
        [v for v in cut.partition[0] if v < faces.shape[0]],
        dtype=torch.long,
        device=faces.device,
    )
    if verbose:
        tqdm.write(f"Mincut solved, start checking the cut")

    ### check if the cut is valid with each connected component
    to_remove_cc = utils3d.torch.compute_connected_components(
        faces[remove_face_indices]
    )
    if debug:
        tqdm.write(f"Number of connected components of the cut: {len(to_remove_cc)}")
    valid_remove_cc = []
    cutting_edges = []
    for cc in to_remove_cc:
        #### check if the connected component has low visibility
        visibility_median = visibility[remove_face_indices[cc]].median()
        if debug:
            tqdm.write(f"visblity_median: {visibility_median}")
        if visibility_median > 0.25:
            continue

        #### check if the cutting loop is small enough
        cc_edge_indices, cc_edges_degree = torch.unique(
            face2edge[remove_face_indices[cc]], return_counts=True
        )
        cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
        # Edges that were not already on the mesh boundary are created by the cut
        cc_new_boundary_edge_indices = cc_boundary_edge_indices[
            ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
        ]
        if len(cc_new_boundary_edge_indices) > 0:
            cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(
                edges[cc_new_boundary_edge_indices]
            )
            cc_new_boundary_edges_cc_center = [
                verts[edges[cc_new_boundary_edge_indices[edge_cc]]]
                .mean(dim=1)
                .mean(dim=0)
                for edge_cc in cc_new_boundary_edge_cc
            ]
            # Loop area = sum of triangle areas spanned by each edge and the loop center
            cc_new_boundary_edges_cc_area = []
            for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
                loop_edges = edges[cc_new_boundary_edge_indices[edge_cc]]
                _e1 = verts[loop_edges[:, 0]] - cc_new_boundary_edges_cc_center[i]
                _e2 = verts[loop_edges[:, 1]] - cc_new_boundary_edges_cc_center[i]
                cc_new_boundary_edges_cc_area.append(
                    torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5
                )
            if debug:
                cutting_edges.append(cc_new_boundary_edge_indices)
                tqdm.write(f"Area of the cutting loop: {cc_new_boundary_edges_cc_area}")
            if any(area > max_hole_size for area in cc_new_boundary_edges_cc_area):
                continue

        valid_remove_cc.append(cc)

    if debug:
        face_v = verts[faces].mean(dim=1).cpu().numpy()
        vis_dual_edges = dual_edges.cpu().numpy()
        vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
        vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
        vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
        vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
        if len(valid_remove_cc) > 0:
            vis_colors[
                remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()
            ] = [255, 0, 0]
        utils3d.io.write_ply(
            "dbg_dual.ply", face_v, edges=vis_dual_edges, vertex_colors=vis_colors
        )

        # torch.cat raises on an empty list, so only dump the cut when one exists
        if cutting_edges:
            vis_verts = verts.cpu().numpy()
            vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
            utils3d.io.write_ply("dbg_cut.ply", vis_verts, edges=vis_edges)

    if len(valid_remove_cc) > 0:
        remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
        mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
        mask[remove_face_indices] = 0
        faces = faces[mask]
        faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
        if verbose:
            tqdm.write(f"Removed {(~mask).sum()} faces by mincut")
    else:
        if verbose:
            tqdm.write(f"Removed 0 faces by mincut")

    # Close the remaining small boundary loops with pymeshfix
    mesh = _meshfix.PyTMesh()
    mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
    mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
    verts, faces = mesh.return_arrays()
    verts = torch.tensor(verts, device="cuda", dtype=torch.float32)
    faces = torch.tensor(faces, device="cuda", dtype=torch.int32)

    return verts, faces


def postprocess_mesh(
    vertices: np.array,
    faces: np.array,
    simplify: bool = True,
    simplify_ratio: float = 0.9,
    fill_holes: bool = True,
    fill_holes_max_hole_size: float = 0.04,
    fill_holes_max_hole_nbe: int = 32,
    fill_holes_resolution: int = 1024,
    fill_holes_num_views: int = 1000,
    debug: bool = False,
    verbose: bool = False,
):
    """
    Postprocess a mesh by simplifying, removing invisible faces, and filling small holes.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh with pyvista's ``decimate``.
        simplify_ratio (float): Target reduction passed to ``decimate``, i.e. the
            ratio of faces to remove (not to keep).
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to emit debug output from the hole filling step.
        verbose (bool): Whether to print progress.

    Returns:
        Tuple of (vertices, faces) as numpy arrays.
    """

    if verbose:
        tqdm.write(
            f"Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces"
        )

    # Simplify
    if simplify and simplify_ratio > 0:
        # pyvista expects every face prefixed with its vertex count (3 for triangles)
        mesh = pv.PolyData(
            vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)
        )
        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
        if verbose:
            tqdm.write(
                f"After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces"
            )

    # Remove invisible faces
    if fill_holes:
        vertices, faces = (
            torch.tensor(vertices).cuda(),
            torch.tensor(faces.astype(np.int32)).cuda(),
        )
        vertices, faces = _fill_holes(
            vertices,
            faces,
            max_hole_size=fill_holes_max_hole_size,
            max_hole_nbe=fill_holes_max_hole_nbe,
            resolution=fill_holes_resolution,
            num_views=fill_holes_num_views,
            debug=debug,
            verbose=verbose,
        )
        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
        if verbose:
            tqdm.write(
                f"After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces"
            )

    return vertices, faces


def parametrize_mesh(vertices: np.array, faces: np.array):
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).

    Returns:
        Tuple of (vertices, faces, uvs): the vertices re-indexed through the
        xatlas vertex mapping, the new face indices, and the UV coordinates.
    """

    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)

    # xatlas may duplicate vertices; vmapping maps each new vertex to its source
    vertices = vertices[vmapping]
    faces = indices

    return vertices, faces, uvs


def bake_texture(
    vertices: np.array,
    faces: np.array,
    uvs: np.array,
    observations: List[np.array],
    masks: List[np.array],
    extrinsics: List[np.array],
    intrinsics: List[np.array],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal["fast", "opt"] = "opt",
    lambda_tv: float = 1e-2,
    verbose: bool = False,
):
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
        masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
        extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking. ``'fast'`` averages
            the observed colors per texel; ``'opt'`` optimizes the texture with Adam.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.

    Returns:
        np.ndarray: The baked texture as uint8, shape (texture_size, texture_size, 3).

    Raises:
        ValueError: If ``mode`` is neither ``'fast'`` nor ``'opt'``.
    """
    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
    views = [
        utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda())
        for extr in extrinsics
    ]
    projections = [
        utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far)
        for intr in intrinsics
    ]

    if mode == "fast":
        texture = torch.zeros(
            (texture_size * texture_size, 3), dtype=torch.float32
        ).cuda()
        texture_weights = torch.zeros(
            (texture_size * texture_size), dtype=torch.float32
        ).cuda()
        rastctx = utils3d.torch.RastContext(backend="cuda")
        for observation, obs_mask, view, projection in tqdm(
            zip(observations, masks, views, projections),
            total=len(observations),
            disable=not verbose,
            desc="Texture baking (fast)",
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    vertices[None],
                    faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=uvs[None],
                    view=view,
                    projection=projection,
                )
                uv_map = rast["uv"][0].detach().flip(0)
                # Combine the raster coverage with THIS view's mask (not the first view's)
                # NOTE(review): the uv map is flipped vertically but the raster mask is
                # not -- confirm against utils3d's image orientation convention.
                mask = rast["mask"][0].detach().bool() & obs_mask

            # nearest neighbor interpolation; clamp so uv == 1.0 stays inside the texture
            uv_map = (uv_map * texture_size).floor().long().clamp_(0, texture_size - 1)
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
            texture_weights = texture_weights.scatter_add(
                0,
                idx,
                torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device),
            )

        # Average the accumulated colors over the number of contributing samples
        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(
            texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)

        # inpaint texels that received no observation
        mask = (
            (texture_weights == 0)
            .cpu()
            .numpy()
            .astype(np.uint8)
            .reshape(texture_size, texture_size)
        )
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    elif mode == "opt":
        rastctx = utils3d.torch.RastContext(backend="cuda")
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        # Pre-compute the per-view UV maps (and their derivatives) once
        _uv = []
        _uv_dr = []
        for observation, view, projection in tqdm(
            zip(observations, views, projections),
            total=len(views),
            disable=not verbose,
            desc="Texture baking (opt): UV",
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    vertices[None],
                    faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=uvs[None],
                    view=view,
                    projection=projection,
                )
                _uv.append(rast["uv"].detach())
                _uv_dr.append(rast["uv_dr"].detach())

        texture = torch.nn.Parameter(
            torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()
        )
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def tv_loss(texture):
            # L1 difference between vertically and horizontally adjacent texels
            return torch.nn.functional.l1_loss(
                texture[:, :-1, :, :], texture[:, 1:, :, :]
            ) + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])

        total_steps = 2500
        with tqdm(
            total=total_steps,
            disable=not verbose,
            desc="Texture baking (opt): optimizing",
        ) as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                # One randomly chosen view per optimization step
                selected = np.random.randint(0, len(views))
                uv, uv_dr, observation, mask = (
                    _uv[selected],
                    _uv_dr[selected],
                    observations[selected],
                    masks[selected],
                )
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()
                # annealing
                optimizer.param_groups[0]["lr"] = _cosine_annealing(
                    step, total_steps, 1e-2, 1e-5
                )
                pbar.set_postfix({"loss": loss.item()})
                pbar.update()
        texture = np.clip(
            texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)
        # Inpaint the texels not covered by any triangle in UV space
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
        )["mask"][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    return texture


def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Convert a generated asset to a textured trimesh mesh (exportable as glb).

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Whether to print debug information.
        verbose (bool): Whether to print progress.

    Returns:
        trimesh.Trimesh: The postprocessed mesh with UVs and a PBR material
        carrying the baked texture.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    # mesh postprocess
    vertices, faces = postprocess_mesh(
        vertices,
        faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        # The boundary-edge budget shrinks as more faces are removed
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    # parametrize mesh
    vertices, faces, uvs = parametrize_mesh(vertices, faces)

    # bake texture
    observations, extrinsics, intrinsics = render_multiview(
        app_rep, resolution=1024, nviews=100
    )
    # A pixel counts as foreground when any of its channels is non-zero
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
    intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
    texture = bake_texture(
        vertices,
        faces,
        uvs,
        observations,
        masks,
        extrinsics,
        intrinsics,
        texture_size=texture_size,
        mode="opt",
        lambda_tv=0.01,
        verbose=verbose,
    )
    texture = Image.fromarray(texture)

    # rotate mesh (from z-up to y-up)
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=texture,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
    )
    mesh = trimesh.Trimesh(
        vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)
    )
    return mesh


def simplify_gs(
    gs: Gaussian,
    simplify: float = 0.95,
    verbose: bool = True,
):
    """
    Simplify 3D Gaussians
    NOTE: this function is not used in the current implementation for the unsatisfactory performance.

    Args:
        gs (Gaussian): 3D Gaussian.
        simplify (float): Ratio of Gaussians to remove in simplification.
        verbose (bool): Whether to show the progress bar.

    Returns:
        Gaussian: ``gs`` itself when ``simplify <= 0``, otherwise a new Gaussian
        holding the optimized and pruned parameters.
    """
    if simplify <= 0:
        return gs

    # simplify
    observations, extrinsics, intrinsics = render_multiview(
        gs, resolution=1024, nviews=100
    )
    # Scale to [0, 1] and move channels first
    observations = [
        torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1)
        for obs in observations
    ]

    # Following https://arxiv.org/pdf/2411.06019
    renderer = GaussianRenderer(
        {
            "resolution": 1024,
            "near": 0.8,
            "far": 1.6,
            "ssaa": 1,
            "bg_color": (0, 0, 0),
        }
    )
    new_gs = Gaussian(**gs.init_params)
    new_gs._features_dc = gs._features_dc.clone()
    new_gs._features_rest = (
        gs._features_rest.clone() if gs._features_rest is not None else None
    )
    new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
    new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
    new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
    new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())

    # One learning-rate schedule per parameter group: xyz, rotation, scaling, opacity
    start_lr = [1e-4, 1e-3, 5e-3, 0.025]
    end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
    optimizer = torch.optim.Adam(
        [
            {"params": new_gs._xyz, "lr": start_lr[0]},
            {"params": new_gs._rotation, "lr": start_lr[1]},
            {"params": new_gs._scaling, "lr": start_lr[2]},
            {"params": new_gs._opacity, "lr": start_lr[3]},
        ],
        lr=start_lr[0],
    )

    # reshape(-1) instead of squeeze(): squeeze() would collapse a single remaining
    # Gaussian to a 0-dim tensor and break the indexing below
    _zeta = new_gs.get_opacity.clone().detach().reshape(-1)
    _lambda = torch.zeros_like(_zeta)
    _delta = 1e-7
    _interval = 10
    num_target = int((1 - simplify) * _zeta.shape[0])
    total_steps = 2500

    with tqdm(
        total=total_steps, disable=not verbose, desc="Simplifying Gaussian"
    ) as pbar:
        for i in range(total_steps):
            # prune
            if i % 100 == 0:
                mask = new_gs.get_opacity.reshape(-1) > 0.05
                # Keep the index 1-D even when exactly one Gaussian survives
                mask = torch.nonzero(mask).reshape(-1)
                new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
                new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
                new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
                new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
                new_gs._features_dc = new_gs._features_dc[mask]
                new_gs._features_rest = (
                    new_gs._features_rest[mask]
                    if new_gs._features_rest is not None
                    else None
                )
                _zeta = _zeta[mask]
                _lambda = _lambda[mask]
                # update optimizer state: carry the Adam moments over to the new params
                for param_group, new_param in zip(
                    optimizer.param_groups,
                    [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity],
                ):
                    stored_state = optimizer.state[param_group["params"][0]]
                    if "exp_avg" in stored_state:
                        stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                        stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
                    del optimizer.state[param_group["params"][0]]
                    param_group["params"][0] = new_param
                    optimizer.state[param_group["params"][0]] = stored_state

            opacity = new_gs.get_opacity.reshape(-1)

            # sparsify: keep only the num_target largest entries of _zeta
            if i % _interval == 0:
                _zeta = _lambda + opacity.detach()
                if opacity.shape[0] > num_target:
                    index = _zeta.topk(num_target)[1]
                    _m = torch.ones_like(_zeta, dtype=torch.bool)
                    _m[index] = 0
                    _zeta[_m] = 0
                _lambda = _lambda + opacity.detach() - _zeta

            # sample a random view
            view_idx = np.random.randint(len(observations))
            observation = observations[view_idx]
            extrinsic = extrinsics[view_idx]
            intrinsic = intrinsics[view_idx]

            color = renderer.render(new_gs, extrinsic, intrinsic)["color"]
            rgb_loss = torch.nn.functional.l1_loss(color, observation)
            # Photometric loss plus a quadratic penalty pulling opacity towards _zeta
            loss = rgb_loss + _delta * torch.sum(
                torch.pow(_lambda + opacity - _zeta, 2)
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr
            for j in range(len(optimizer.param_groups)):
                optimizer.param_groups[j]["lr"] = _cosine_annealing(
                    i, total_steps, start_lr[j], end_lr[j]
                )

            pbar.set_postfix(
                {
                    "loss": rgb_loss.item(),
                    "num": opacity.shape[0],
                    "lambda": _lambda.mean().item(),
                }
            )
            pbar.update()

    # Strip the Parameter wrappers before returning
    new_gs._xyz = new_gs._xyz.data
    new_gs._rotation = new_gs._rotation.data
    new_gs._scaling = new_gs._scaling.data
    new_gs._opacity = new_gs._opacity.data

    return new_gs