Update pages/01_🦷 Segment.py

pages/01_🦷 Segment.py  (+217 −194)
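This change pulls the MeshSegNet inference and graph-cut refinement code out of the Segment page class into a standalone segmentation_main(obj_path) helper, wraps the slow edge-building loop in a stqdm progress bar, notes the expected runtime of 7-10 min per prediction, and adds an "Apply Segmentation" button for both the example scan and uploaded scans, with the input scan and the segmentation result shown inside collapsible st.expander sections.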
@@ -9,6 +9,7 @@ from pygco import cut_from_graph
 import open3d as o3d
 import matplotlib.pyplot as plt
 import matplotlib.colors as mcolors
+from stqdm import stqdm
 import json
 from stpyvista import stpyvista
 import torch
@@ -662,6 +663,201 @@ def obj2pcd(obj_path):
     pcd_points = np.array(pcd_list).astype(np.float64)
     return pcd_points, jaw
 
+
+def segmentation_main(obj_path):
+    upsampling_method = 'KNN'
+
+    model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
+    num_classes = 17
+    num_channels = 15
+
+    # set model
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = MeshSegNet(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
+
+    # load trained model
+    # checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
+    checkpoint = torch.load(model_path, map_location='cpu')
+    model.load_state_dict(checkpoint['model_state_dict'])
+    del checkpoint
+    model = model.to(device, dtype=torch.float)
+
+    # cudnn
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.enabled = True
+
+    # Predicting
+    model.eval()
+    with torch.no_grad():
+        pcd_points, jaw = obj2pcd(obj_path)
+        mesh = mesh_grid(pcd_points)
+
+        # move mesh to origin
+        with st.spinner("Patience please, AI at work. Grab a coffee while you wait ☕."):
+            vertices_points = np.asarray(mesh.vertices)
+            triangles_points = np.asarray(mesh.triangles)
+            N = triangles_points.shape[0]
+            cells = np.zeros((triangles_points.shape[0], 9))
+            cells = vertices_points[triangles_points].reshape(triangles_points.shape[0], 9)
+
+            mean_cell_centers = mesh.get_center()
+            cells[:, 0:3] -= mean_cell_centers[0:3]
+            cells[:, 3:6] -= mean_cell_centers[0:3]
+            cells[:, 6:9] -= mean_cell_centers[0:3]
+
+            v1 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
+            v2 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
+            v1[:, 0] = cells[:, 0] - cells[:, 3]
+            v1[:, 1] = cells[:, 1] - cells[:, 4]
+            v1[:, 2] = cells[:, 2] - cells[:, 5]
+            v2[:, 0] = cells[:, 3] - cells[:, 6]
+            v2[:, 1] = cells[:, 4] - cells[:, 7]
+            v2[:, 2] = cells[:, 5] - cells[:, 8]
+            mesh_normals = np.cross(v1, v2)
+            mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
+            mesh_normals[:, 0] /= mesh_normal_length[:]
+            mesh_normals[:, 1] /= mesh_normal_length[:]
+            mesh_normals[:, 2] /= mesh_normal_length[:]
+
+            # prepare input
+            points = vertices_points.copy()
+            points[:, 0:3] -= mean_cell_centers[0:3]
+            normals = np.nan_to_num(mesh_normals).copy()
+            barycenters = np.zeros((triangles_points.shape[0], 3))
+            s = np.sum(vertices_points[triangles_points], 1)
+            barycenters = 1 / 3 * s
+            center_points = barycenters.copy()
+            barycenters -= mean_cell_centers[0:3]
+
+            # normalized data
+            maxs = points.max(axis=0)
+            mins = points.min(axis=0)
+            means = points.mean(axis=0)
+            stds = points.std(axis=0)
+            nmeans = normals.mean(axis=0)
+            nstds = normals.std(axis=0)
+
+            for i in range(3):
+                cells[:, i] = (cells[:, i] - means[i]) / stds[i]  # point 1
+                cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i]  # point 2
+                cells[:, i + 6] = (cells[:, i + 6] - means[i]) / stds[i]  # point 3
+                barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
+                normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
+
+            X = np.column_stack((cells, barycenters, normals))
+
+            # computing A_S and A_L
+            A_S = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
+            A_L = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
+            D = distance_matrix(X[:, 9:12], X[:, 9:12])
+            A_S[D < 0.1] = 1.0
+            A_S = A_S / np.dot(np.sum(A_S, axis=1, keepdims=True), np.ones((1, X.shape[0])))
+
+            A_L[D < 0.2] = 1.0
+            A_L = A_L / np.dot(np.sum(A_L, axis=1, keepdims=True), np.ones((1, X.shape[0])))
+
+            # numpy -> torch.tensor
+            X = X.transpose(1, 0)
+            X = X.reshape([1, X.shape[0], X.shape[1]])
+            X = torch.from_numpy(X).to(device, dtype=torch.float)
+            A_S = A_S.reshape([1, A_S.shape[0], A_S.shape[1]])
+            A_L = A_L.reshape([1, A_L.shape[0], A_L.shape[1]])
+            A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
+            A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)
+
+            tensor_prob_output = model(X, A_S, A_L).to(device, dtype=torch.float)
+            patch_prob_output = tensor_prob_output.cpu().numpy()
+
+        # refinement
+        with st.spinner("Refining..."):
+            round_factor = 100
+            patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6
+
+            # unaries
+            unaries = -round_factor * np.log10(patch_prob_output)
+            unaries = unaries.astype(np.int32)
+            unaries = unaries.reshape(-1, num_classes)
+
+            # pairwise
+            pairwise = (1 - np.eye(num_classes, dtype=np.int32))
+
+            cells = cells.copy()
+
+            cell_ids = np.asarray(triangles_points)
+            lambda_c = 20
+            edges = np.empty([1, 3], order='C')
+            for i_node in stqdm(range(cells.shape[0])):
+                # Find neighbors
+                nei = np.sum(np.isin(cell_ids, cell_ids[i_node, :]), axis=1)
+                nei_id = np.where(nei == 2)
+                for i_nei in nei_id[0][:]:
+                    if i_node < i_nei:
+                        cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
+                            normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
+                        if cos_theta >= 1.0:
+                            cos_theta = 0.9999
+                        theta = np.arccos(cos_theta)
+                        phi = np.linalg.norm(barycenters[i_node, :] - barycenters[i_nei, :])
+                        if theta > np.pi / 2.0:
+                            edges = np.concatenate(
+                                (edges, np.array([i_node, i_nei, -np.log10(theta / np.pi) * phi]).reshape(1, 3)), axis=0)
+                        else:
+                            beta = 1 + np.linalg.norm(np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]))
+                            edges = np.concatenate(
+                                (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
+                                axis=0)
+            edges = np.delete(edges, 0, 0)
+            edges[:, 2] *= lambda_c * round_factor
+            edges = edges.astype(np.int32)
+
+            refine_labels = cut_from_graph(edges, unaries, pairwise)
+            refine_labels = refine_labels.reshape([-1, 1])
+
+            predicted_labels_3 = refine_labels.reshape(refine_labels.shape[0])
+            mesh_to_points_main(jaw, pcd_points, center_points, predicted_labels_3)
+
+    import pyvista as pv
+    with st.spinner("Rendering..."):
+        # Load the .obj file
+        mesh = pv.read('file.obj')
+
+        # Load the JSON file
+        with open('dental-labels4.json', 'r') as file:
+            labels_data = json.load(file)
+
+        # Assuming labels_data['labels'] is a list of labels
+        labels = labels_data['labels']
+
+        # Make sure the number of labels matches the number of vertices or faces
+        assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
+
+        # If labels correspond to vertices
+        if len(labels) == mesh.n_points:
+            mesh.point_data['Labels'] = labels
+        # If labels correspond to faces
+        elif len(labels) == mesh.n_cells:
+            mesh.cell_data['Labels'] = labels
+
+        # Create a pyvista plotter
+        plotter = pv.Plotter()
+
+        cmap = plt.cm.get_cmap('jet', 27)  # Using a colormap with sufficient distinct colors
+
+        colors = cmap(np.linspace(0, 1, 27))  # Generate colors
+
+        # Convert colors to a format acceptable by PyVista
+        colormap = mcolors.ListedColormap(colors)
+
+        # Add the mesh to the plotter with labels as a scalar field
+        #plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
+        plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
+
+        # Show the plot
+        #plotter.show()
+        ## Send to streamlit
+        with st.expander("View Segmentation Result", expanded=False):
+            stpyvista(plotter)
+
 # Configure Streamlit page
 st.set_page_config(page_title="Teeth Segmentation", page_icon="🦷")
 
@@ -681,218 +877,45 @@ class Segment(TeethApp):
         )
         import pyvista as pv
         if inputs == "Example Scan":
+            st.markdown("Expected time per prediction: 7-10 min.")
             mesh = pv.read("ZOUIF2W4_upper.obj")
             plotter = pv.Plotter()
 
             # Add the mesh to the plotter
-            plotter.add_mesh(mesh, color='
-
-
+            plotter.add_mesh(mesh, color='white', show_edges=True)
+            segment = st.button("Apply Segmentation")
+            with st.expander("View Scan", expanded=False):
                 stpyvista(plotter)
 
+            if segment:
+                segmentation_main("ZOUIF2W4_upper.obj")
+
+
+
         elif inputs == "Upload Scan":
             file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
-
+            st.markdown("Expected time per prediction: 7-10 min.")
             if file is not None:
                 # save the uploaded file to disk
                 with open("file.obj", "wb") as buffer:
                     shutil.copyfileobj(file, buffer)
                     # copy the data
-
-
                 obj_path = "file.obj"
-                upsampling_method = 'KNN'
-
-                model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
-                num_classes = 17
-                num_channels = 15
-
-                # set model
-                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-                model = MeshSegNet(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
-
-                # load trained model
-                # checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
-                checkpoint = torch.load(model_path, map_location='cpu')
-                model.load_state_dict(checkpoint['model_state_dict'])
-                del checkpoint
-                model = model.to(device, dtype=torch.float)
-
-                # cudnn
-                torch.backends.cudnn.benchmark = True
-                torch.backends.cudnn.enabled = True
-
-                # Predicting
-                model.eval()
-                with torch.no_grad():
-                    pcd_points, jaw = obj2pcd(obj_path)
-                    mesh = mesh_grid(pcd_points)
-
-                    # move mesh to origin
-                    with st.spinner("Patience please, AI at work. Grab a coffee while you wait☕!"):
-                        vertices_points = np.asarray(mesh.vertices)
-                        triangles_points = np.asarray(mesh.triangles)
-                        N = triangles_points.shape[0]
-                        cells = np.zeros((triangles_points.shape[0], 9))
-                        cells = vertices_points[triangles_points].reshape(triangles_points.shape[0], 9)
-
-                        mean_cell_centers = mesh.get_center()
-                        cells[:, 0:3] -= mean_cell_centers[0:3]
-                        cells[:, 3:6] -= mean_cell_centers[0:3]
-                        cells[:, 6:9] -= mean_cell_centers[0:3]
-
-                        v1 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
-                        v2 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
-                        v1[:, 0] = cells[:, 0] - cells[:, 3]
-                        v1[:, 1] = cells[:, 1] - cells[:, 4]
-                        v1[:, 2] = cells[:, 2] - cells[:, 5]
-                        v2[:, 0] = cells[:, 3] - cells[:, 6]
-                        v2[:, 1] = cells[:, 4] - cells[:, 7]
-                        v2[:, 2] = cells[:, 5] - cells[:, 8]
-                        mesh_normals = np.cross(v1, v2)
-                        mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
-                        mesh_normals[:, 0] /= mesh_normal_length[:]
-                        mesh_normals[:, 1] /= mesh_normal_length[:]
-                        mesh_normals[:, 2] /= mesh_normal_length[:]
-
-                        # prepare input
-                        points = vertices_points.copy()
-                        points[:, 0:3] -= mean_cell_centers[0:3]
-                        normals = np.nan_to_num(mesh_normals).copy()
-                        barycenters = np.zeros((triangles_points.shape[0], 3))
-                        s = np.sum(vertices_points[triangles_points], 1)
-                        barycenters = 1 / 3 * s
-                        center_points = barycenters.copy()
-                        barycenters -= mean_cell_centers[0:3]
-
-                        # normalized data
-                        maxs = points.max(axis=0)
-                        mins = points.min(axis=0)
-                        means = points.mean(axis=0)
-                        stds = points.std(axis=0)
-                        nmeans = normals.mean(axis=0)
-                        nstds = normals.std(axis=0)
-
-                        for i in range(3):
-                            cells[:, i] = (cells[:, i] - means[i]) / stds[i]  # point 1
-                            cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i]  # point 2
-                            cells[:, i + 6] = (cells[:, i + 6] - means[i]) / stds[i]  # point 3
-                            barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
-                            normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
-
-                        X = np.column_stack((cells, barycenters, normals))
-
-                        # computing A_S and A_L
-                        A_S = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
-                        A_L = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
-                        D = distance_matrix(X[:, 9:12], X[:, 9:12])
-                        A_S[D < 0.1] = 1.0
-                        A_S = A_S / np.dot(np.sum(A_S, axis=1, keepdims=True), np.ones((1, X.shape[0])))
-
-                        A_L[D < 0.2] = 1.0
-                        A_L = A_L / np.dot(np.sum(A_L, axis=1, keepdims=True), np.ones((1, X.shape[0])))
-
-                        # numpy -> torch.tensor
-                        X = X.transpose(1, 0)
-                        X = X.reshape([1, X.shape[0], X.shape[1]])
-                        X = torch.from_numpy(X).to(device, dtype=torch.float)
-                        A_S = A_S.reshape([1, A_S.shape[0], A_S.shape[1]])
-                        A_L = A_L.reshape([1, A_L.shape[0], A_L.shape[1]])
-                        A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
-                        A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)
-
-                        tensor_prob_output = model(X, A_S, A_L).to(device, dtype=torch.float)
-                        patch_prob_output = tensor_prob_output.cpu().numpy()
-
-                    # refinement
-                    with st.spinner("Refining..."):
-                        round_factor = 100
-                        patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6
-
-                        # unaries
-                        unaries = -round_factor * np.log10(patch_prob_output)
-                        unaries = unaries.astype(np.int32)
-                        unaries = unaries.reshape(-1, num_classes)
-
-                        # pairwise
-                        pairwise = (1 - np.eye(num_classes, dtype=np.int32))
-
-                        cells = cells.copy()
-
-                        cell_ids = np.asarray(triangles_points)
-                        lambda_c = 20
-                        edges = np.empty([1, 3], order='C')
-                        for i_node in range(cells.shape[0]):
-                            # Find neighbors
-                            nei = np.sum(np.isin(cell_ids, cell_ids[i_node, :]), axis=1)
-                            nei_id = np.where(nei == 2)
-                            for i_nei in nei_id[0][:]:
-                                if i_node < i_nei:
-                                    cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
-                                        normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
-                                    if cos_theta >= 1.0:
-                                        cos_theta = 0.9999
-                                    theta = np.arccos(cos_theta)
-                                    phi = np.linalg.norm(barycenters[i_node, :] - barycenters[i_nei, :])
-                                    if theta > np.pi / 2.0:
-                                        edges = np.concatenate(
-                                            (edges, np.array([i_node, i_nei, -np.log10(theta / np.pi) * phi]).reshape(1, 3)), axis=0)
-                                    else:
-                                        beta = 1 + np.linalg.norm(np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]))
-                                        edges = np.concatenate(
-                                            (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
-                                            axis=0)
-                        edges = np.delete(edges, 0, 0)
-                        edges[:, 2] *= lambda_c * round_factor
-                        edges = edges.astype(np.int32)
-
-                        refine_labels = cut_from_graph(edges, unaries, pairwise)
-                        refine_labels = refine_labels.reshape([-1, 1])
-
-                        predicted_labels_3 = refine_labels.reshape(refine_labels.shape[0])
-                        mesh_to_points_main(jaw, pcd_points, center_points, predicted_labels_3)
-
-                import pyvista as pv
-
-                with st.spinner("Rendering..."):
-                    # Load the .obj file
-                    mesh = pv.read('file.obj')
-
-                    # Load the JSON file
-                    with open('dental-labels4.json', 'r') as file:
-                        labels_data = json.load(file)
-
-                    # Assuming labels_data['labels'] is a list of labels
-                    labels = labels_data['labels']
-
-                    # Make sure the number of labels matches the number of vertices or faces
-                    assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
-
-                    # If labels correspond to vertices
-                    if len(labels) == mesh.n_points:
-                        mesh.point_data['Labels'] = labels
-                    # If labels correspond to faces
-                    elif len(labels) == mesh.n_cells:
-                        mesh.cell_data['Labels'] = labels
-
-                    # Create a pyvista plotter
-                    plotter = pv.Plotter()
-
-                    cmap = plt.cm.get_cmap('jet', 27)  # Using a colormap with sufficient distinct colors
 
-                    colors = cmap(np.linspace(0, 1, 27))  # Generate colors
+                mesh = pv.read(obj_path)
+                plotter = pv.Plotter()
 
-                    # Convert colors to a format acceptable by PyVista
-                    colormap = mcolors.ListedColormap(colors)
+                # Add the mesh to the plotter
+                plotter.add_mesh(mesh, color='white', show_edges=True)
+                segment = st.button("Apply Segmentation")
+                with st.expander("View Scan", expanded=False):
+                    stpyvista(plotter)
 
-                    # Add the mesh to the plotter with labels as a scalar field
-                    #plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
-                    plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
+                if segment:
+                    segmentation_main(obj_path)
 
-                    # Show the plot
-                    #plotter.show()
-                    ## Send to streamlit
-                    stpyvista(plotter)
+
+
 
 if __name__ == "__main__":
     app = Segment()
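For orientation, the following is a minimal sketch of the page flow this diff converges on for the example scan. It assumes the segmentation_main helper defined above is in scope and that the bundled ZOUIF2W4_upper.obj is present, and it leaves out the TeethApp/Segment class wrapper; it is not a drop-in copy of the page.

```python
# Minimal sketch of the new UI flow (assumes segmentation_main() from this file,
# the example scan ZOUIF2W4_upper.obj, and an installed pyvista/stpyvista stack).
import streamlit as st
import pyvista as pv
from stpyvista import stpyvista

st.markdown("Expected time per prediction: 7-10 min.")

mesh = pv.read("ZOUIF2W4_upper.obj")
plotter = pv.Plotter()
plotter.add_mesh(mesh, color='white', show_edges=True)

segment = st.button("Apply Segmentation")     # clicking reruns the script with segment == True
with st.expander("View Scan", expanded=False):
    stpyvista(plotter)                        # interactive view of the raw scan

if segment:
    segmentation_main("ZOUIF2W4_upper.obj")   # inference, graph-cut refinement, result rendering
```

Because the button triggers a rerun in which segment is True, the expensive segmentation only runs on demand, while the scan preview stays available inside the expander.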
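The edge-building loop in the refinement stage runs once per mesh cell, which is why the diff swaps range(...) for stqdm(range(...)): stqdm renders a tqdm-style progress bar directly in the Streamlit page instead of only in the server console. A minimal, self-contained sketch of that pattern, with time.sleep standing in for the real per-cell work:

```python
# Minimal sketch of stqdm progress reporting in a Streamlit script.
# time.sleep() stands in for the per-cell neighbor search of the real loop,
# which iterates over stqdm(range(cells.shape[0])).
import time

import streamlit as st
from stqdm import stqdm

st.write("Building graph edges...")
for i_node in stqdm(range(200)):
    time.sleep(0.01)  # placeholder for the neighbor lookup and edge-weight computation
st.success("Edge construction finished.")
```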