Spaces:
Sleeping
Sleeping
FrozenBurning
commited on
Commit
·
670f57e
1
Parent(s):
9d573a0
update inference
Browse files- inference.py +10 -7
inference.py
CHANGED
@@ -86,17 +86,18 @@ def extract_texmesh(args, model, output_path, device):
|
|
86 |
# Prepare directory
|
87 |
ins_dir = output_path
|
88 |
# Noise Filter
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
92 |
dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
|
93 |
-
dist += torch.eye(prim_position.shape[0]).to(
|
94 |
min_dist, min_indices = dist.min(1)
|
95 |
dst_prim_scale = prim_scale[min_indices, :]
|
96 |
-
min_scale_converage = prim_scale * 1.
|
97 |
prim_mask = min_dist < min_scale_converage[:, 0]
|
98 |
-
filtered_srt_param =
|
99 |
-
filtered_feat_param =
|
100 |
model.srt_param.data = filtered_srt_param
|
101 |
model.feat_param.data = filtered_feat_param
|
102 |
print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')
|
@@ -210,6 +211,8 @@ def extract_texmesh(args, model, output_path, device):
|
|
210 |
|
211 |
target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
|
212 |
target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
|
|
|
|
|
213 |
|
214 |
def main(config):
|
215 |
logging.basicConfig(level=logging.INFO)
|
|
|
86 |
# Prepare directory
|
87 |
ins_dir = output_path
|
88 |
# Noise Filter
|
89 |
+
raw_srt_param = model.srt_param.clone()
|
90 |
+
raw_feat_param = model.feat_param.clone()
|
91 |
+
prim_position = raw_srt_param[:, 1:4]
|
92 |
+
prim_scale = raw_srt_param[:, 0:1]
|
93 |
dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
|
94 |
+
dist += torch.eye(prim_position.shape[0]).to(raw_srt_param)
|
95 |
min_dist, min_indices = dist.min(1)
|
96 |
dst_prim_scale = prim_scale[min_indices, :]
|
97 |
+
min_scale_converage = prim_scale * 1. + dst_prim_scale * 1.
|
98 |
prim_mask = min_dist < min_scale_converage[:, 0]
|
99 |
+
filtered_srt_param = raw_srt_param[prim_mask, :]
|
100 |
+
filtered_feat_param = raw_feat_param[prim_mask, ...]
|
101 |
model.srt_param.data = filtered_srt_param
|
102 |
model.feat_param.data = filtered_feat_param
|
103 |
print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')
|
|
|
211 |
|
212 |
target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
|
213 |
target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
|
214 |
+
model.srt_param.data = raw_srt_param
|
215 |
+
model.feat_param.data = raw_feat_param
|
216 |
|
217 |
def main(config):
|
218 |
logging.basicConfig(level=logging.INFO)
|