huathedev commited on
Commit
078395a
·
1 Parent(s): 07dae79

Update pages/01_🦷 Segment.py

Browse files
Files changed (1) hide show
  1. pages/01_🦷 Segment.py +217 -194
pages/01_🦷 Segment.py CHANGED
@@ -9,6 +9,7 @@ from pygco import cut_from_graph
9
  import open3d as o3d
10
  import matplotlib.pyplot as plt
11
  import matplotlib.colors as mcolors
 
12
  import json
13
  from stpyvista import stpyvista
14
  import torch
@@ -662,6 +663,201 @@ def obj2pcd(obj_path):
662
  pcd_points = np.array(pcd_list).astype(np.float64)
663
  return pcd_points, jaw
664
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
  # Configure Streamlit page
666
  st.set_page_config(page_title="Teeth Segmentation", page_icon="🦷")
667
 
@@ -681,218 +877,45 @@ class Segment(TeethApp):
681
  )
682
  import pyvista as pv
683
  if inputs == "Example Scan":
 
684
  mesh = pv.read("ZOUIF2W4_upper.obj")
685
  plotter = pv.Plotter()
686
 
687
  # Add the mesh to the plotter
688
- plotter.add_mesh(mesh, color='black', show_edges=True)
689
- visualize = st.button("Segment")
690
- if visualize:
691
  stpyvista(plotter)
692
 
 
 
 
 
 
693
  elif inputs == "Upload Scan":
694
  file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
695
-
696
  if file is not None:
697
  # save the uploaded file to disk
698
  with open("file.obj", "wb") as buffer:
699
  shutil.copyfileobj(file, buffer)
700
  # 复制数据
701
-
702
-
703
  obj_path = "file.obj"
704
- upsampling_method = 'KNN'
705
-
706
- model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
707
- num_classes = 17
708
- num_channels = 15
709
-
710
- # set model
711
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
712
- model = MeshSegNet(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
713
-
714
- # load trained model
715
- # checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
716
- checkpoint = torch.load(model_path, map_location='cpu')
717
- model.load_state_dict(checkpoint['model_state_dict'])
718
- del checkpoint
719
- model = model.to(device, dtype=torch.float)
720
-
721
- # cudnn
722
- torch.backends.cudnn.benchmark = True
723
- torch.backends.cudnn.enabled = True
724
-
725
- # Predicting
726
- model.eval()
727
- with torch.no_grad():
728
- pcd_points, jaw = obj2pcd(obj_path)
729
- mesh = mesh_grid(pcd_points)
730
-
731
- # move mesh to origin
732
- with st.spinner("Patience please, AI at work. Grab a coffee while you wait☕!"):
733
- vertices_points = np.asarray(mesh.vertices)
734
- triangles_points = np.asarray(mesh.triangles)
735
- N = triangles_points.shape[0]
736
- cells = np.zeros((triangles_points.shape[0], 9))
737
- cells = vertices_points[triangles_points].reshape(triangles_points.shape[0], 9)
738
-
739
- mean_cell_centers = mesh.get_center()
740
- cells[:, 0:3] -= mean_cell_centers[0:3]
741
- cells[:, 3:6] -= mean_cell_centers[0:3]
742
- cells[:, 6:9] -= mean_cell_centers[0:3]
743
-
744
- v1 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
745
- v2 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
746
- v1[:, 0] = cells[:, 0] - cells[:, 3]
747
- v1[:, 1] = cells[:, 1] - cells[:, 4]
748
- v1[:, 2] = cells[:, 2] - cells[:, 5]
749
- v2[:, 0] = cells[:, 3] - cells[:, 6]
750
- v2[:, 1] = cells[:, 4] - cells[:, 7]
751
- v2[:, 2] = cells[:, 5] - cells[:, 8]
752
- mesh_normals = np.cross(v1, v2)
753
- mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
754
- mesh_normals[:, 0] /= mesh_normal_length[:]
755
- mesh_normals[:, 1] /= mesh_normal_length[:]
756
- mesh_normals[:, 2] /= mesh_normal_length[:]
757
-
758
- # prepare input
759
- points = vertices_points.copy()
760
- points[:, 0:3] -= mean_cell_centers[0:3]
761
- normals = np.nan_to_num(mesh_normals).copy()
762
- barycenters = np.zeros((triangles_points.shape[0], 3))
763
- s = np.sum(vertices_points[triangles_points], 1)
764
- barycenters = 1 / 3 * s
765
- center_points = barycenters.copy()
766
- barycenters -= mean_cell_centers[0:3]
767
-
768
- # normalized data
769
- maxs = points.max(axis=0)
770
- mins = points.min(axis=0)
771
- means = points.mean(axis=0)
772
- stds = points.std(axis=0)
773
- nmeans = normals.mean(axis=0)
774
- nstds = normals.std(axis=0)
775
-
776
- for i in range(3):
777
- cells[:, i] = (cells[:, i] - means[i]) / stds[i] # point 1
778
- cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i] # point 2
779
- cells[:, i + 6] = (cells[:, i + 6] - means[i]) / stds[i] # point 3
780
- barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
781
- normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
782
-
783
- X = np.column_stack((cells, barycenters, normals))
784
-
785
- # computing A_S and A_L
786
- A_S = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
787
- A_L = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
788
- D = distance_matrix(X[:, 9:12], X[:, 9:12])
789
- A_S[D < 0.1] = 1.0
790
- A_S = A_S / np.dot(np.sum(A_S, axis=1, keepdims=True), np.ones((1, X.shape[0])))
791
-
792
- A_L[D < 0.2] = 1.0
793
- A_L = A_L / np.dot(np.sum(A_L, axis=1, keepdims=True), np.ones((1, X.shape[0])))
794
-
795
- # numpy -> torch.tensor
796
- X = X.transpose(1, 0)
797
- X = X.reshape([1, X.shape[0], X.shape[1]])
798
- X = torch.from_numpy(X).to(device, dtype=torch.float)
799
- A_S = A_S.reshape([1, A_S.shape[0], A_S.shape[1]])
800
- A_L = A_L.reshape([1, A_L.shape[0], A_L.shape[1]])
801
- A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
802
- A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)
803
-
804
- tensor_prob_output = model(X, A_S, A_L).to(device, dtype=torch.float)
805
- patch_prob_output = tensor_prob_output.cpu().numpy()
806
-
807
- # refinement
808
- with st.spinner("Refining..."):
809
- round_factor = 100
810
- patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6
811
-
812
- # unaries
813
- unaries = -round_factor * np.log10(patch_prob_output)
814
- unaries = unaries.astype(np.int32)
815
- unaries = unaries.reshape(-1, num_classes)
816
-
817
- # parawisex
818
- pairwise = (1 - np.eye(num_classes, dtype=np.int32))
819
-
820
- cells = cells.copy()
821
-
822
- cell_ids = np.asarray(triangles_points)
823
- lambda_c = 20
824
- edges = np.empty([1, 3], order='C')
825
- for i_node in range(cells.shape[0]):
826
- # Find neighbors
827
- nei = np.sum(np.isin(cell_ids, cell_ids[i_node, :]), axis=1)
828
- nei_id = np.where(nei == 2)
829
- for i_nei in nei_id[0][:]:
830
- if i_node < i_nei:
831
- cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
832
- normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
833
- if cos_theta >= 1.0:
834
- cos_theta = 0.9999
835
- theta = np.arccos(cos_theta)
836
- phi = np.linalg.norm(barycenters[i_node, :] - barycenters[i_nei, :])
837
- if theta > np.pi / 2.0:
838
- edges = np.concatenate(
839
- (edges, np.array([i_node, i_nei, -np.log10(theta / np.pi) * phi]).reshape(1, 3)), axis=0)
840
- else:
841
- beta = 1 + np.linalg.norm(np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]))
842
- edges = np.concatenate(
843
- (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
844
- axis=0)
845
- edges = np.delete(edges, 0, 0)
846
- edges[:, 2] *= lambda_c * round_factor
847
- edges = edges.astype(np.int32)
848
-
849
- refine_labels = cut_from_graph(edges, unaries, pairwise)
850
- refine_labels = refine_labels.reshape([-1, 1])
851
-
852
- predicted_labels_3 = refine_labels.reshape(refine_labels.shape[0])
853
- mesh_to_points_main(jaw, pcd_points, center_points, predicted_labels_3)
854
-
855
- import pyvista as pv
856
-
857
- with st.spinner("Rendering..."):
858
- # Load the .obj file
859
- mesh = pv.read('file.obj')
860
-
861
- # Load the JSON file
862
- with open('dental-labels4.json', 'r') as file:
863
- labels_data = json.load(file)
864
-
865
- # Assuming labels_data['labels'] is a list of labels
866
- labels = labels_data['labels']
867
-
868
- # Make sure the number of labels matches the number of vertices or faces
869
- assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
870
-
871
- # If labels correspond to vertices
872
- if len(labels) == mesh.n_points:
873
- mesh.point_data['Labels'] = labels
874
- # If labels correspond to faces
875
- elif len(labels) == mesh.n_cells:
876
- mesh.cell_data['Labels'] = labels
877
-
878
- # Create a pyvista plotter
879
- plotter = pv.Plotter()
880
-
881
- cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
882
 
883
- colors = cmap(np.linspace(0, 1, 27)) # Generate colors
 
884
 
885
- # Convert colors to a format acceptable by PyVista
886
- colormap = mcolors.ListedColormap(colors)
 
 
 
887
 
888
- # Add the mesh to the plotter with labels as a scalar field
889
- #plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
890
- plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
891
 
892
- # Show the plot
893
- #plotter.show()
894
- ## Send to streamlit
895
- stpyvista(plotter)
896
 
897
  if __name__ == "__main__":
898
  app = Segment()
 
9
  import open3d as o3d
10
  import matplotlib.pyplot as plt
11
  import matplotlib.colors as mcolors
12
+ from stqdm import stqdm
13
  import json
14
  from stpyvista import stpyvista
15
  import torch
 
663
  pcd_points = np.array(pcd_list).astype(np.float64)
664
  return pcd_points, jaw
665
 
666
+
667
+ def segmentation_main(obj_path):
668
+ upsampling_method = 'KNN'
669
+
670
+ model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
671
+ num_classes = 17
672
+ num_channels = 15
673
+
674
+ # set model
675
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
676
+ model = MeshSegNet(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
677
+
678
+ # load trained model
679
+ # checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
680
+ checkpoint = torch.load(model_path, map_location='cpu')
681
+ model.load_state_dict(checkpoint['model_state_dict'])
682
+ del checkpoint
683
+ model = model.to(device, dtype=torch.float)
684
+
685
+ # cudnn
686
+ torch.backends.cudnn.benchmark = True
687
+ torch.backends.cudnn.enabled = True
688
+
689
+ # Predicting
690
+ model.eval()
691
+ with torch.no_grad():
692
+ pcd_points, jaw = obj2pcd(obj_path)
693
+ mesh = mesh_grid(pcd_points)
694
+
695
+ # move mesh to origin
696
+ with st.spinner("Patience please, AI at work. Grab a coffee while you wait ☕."):
697
+ vertices_points = np.asarray(mesh.vertices)
698
+ triangles_points = np.asarray(mesh.triangles)
699
+ N = triangles_points.shape[0]
700
+ cells = np.zeros((triangles_points.shape[0], 9))
701
+ cells = vertices_points[triangles_points].reshape(triangles_points.shape[0], 9)
702
+
703
+ mean_cell_centers = mesh.get_center()
704
+ cells[:, 0:3] -= mean_cell_centers[0:3]
705
+ cells[:, 3:6] -= mean_cell_centers[0:3]
706
+ cells[:, 6:9] -= mean_cell_centers[0:3]
707
+
708
+ v1 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
709
+ v2 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
710
+ v1[:, 0] = cells[:, 0] - cells[:, 3]
711
+ v1[:, 1] = cells[:, 1] - cells[:, 4]
712
+ v1[:, 2] = cells[:, 2] - cells[:, 5]
713
+ v2[:, 0] = cells[:, 3] - cells[:, 6]
714
+ v2[:, 1] = cells[:, 4] - cells[:, 7]
715
+ v2[:, 2] = cells[:, 5] - cells[:, 8]
716
+ mesh_normals = np.cross(v1, v2)
717
+ mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
718
+ mesh_normals[:, 0] /= mesh_normal_length[:]
719
+ mesh_normals[:, 1] /= mesh_normal_length[:]
720
+ mesh_normals[:, 2] /= mesh_normal_length[:]
721
+
722
+ # prepare input
723
+ points = vertices_points.copy()
724
+ points[:, 0:3] -= mean_cell_centers[0:3]
725
+ normals = np.nan_to_num(mesh_normals).copy()
726
+ barycenters = np.zeros((triangles_points.shape[0], 3))
727
+ s = np.sum(vertices_points[triangles_points], 1)
728
+ barycenters = 1 / 3 * s
729
+ center_points = barycenters.copy()
730
+ barycenters -= mean_cell_centers[0:3]
731
+
732
+ # normalized data
733
+ maxs = points.max(axis=0)
734
+ mins = points.min(axis=0)
735
+ means = points.mean(axis=0)
736
+ stds = points.std(axis=0)
737
+ nmeans = normals.mean(axis=0)
738
+ nstds = normals.std(axis=0)
739
+
740
+ for i in range(3):
741
+ cells[:, i] = (cells[:, i] - means[i]) / stds[i] # point 1
742
+ cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i] # point 2
743
+ cells[:, i + 6] = (cells[:, i + 6] - means[i]) / stds[i] # point 3
744
+ barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
745
+ normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
746
+
747
+ X = np.column_stack((cells, barycenters, normals))
748
+
749
+ # computing A_S and A_L
750
+ A_S = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
751
+ A_L = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
752
+ D = distance_matrix(X[:, 9:12], X[:, 9:12])
753
+ A_S[D < 0.1] = 1.0
754
+ A_S = A_S / np.dot(np.sum(A_S, axis=1, keepdims=True), np.ones((1, X.shape[0])))
755
+
756
+ A_L[D < 0.2] = 1.0
757
+ A_L = A_L / np.dot(np.sum(A_L, axis=1, keepdims=True), np.ones((1, X.shape[0])))
758
+
759
+ # numpy -> torch.tensor
760
+ X = X.transpose(1, 0)
761
+ X = X.reshape([1, X.shape[0], X.shape[1]])
762
+ X = torch.from_numpy(X).to(device, dtype=torch.float)
763
+ A_S = A_S.reshape([1, A_S.shape[0], A_S.shape[1]])
764
+ A_L = A_L.reshape([1, A_L.shape[0], A_L.shape[1]])
765
+ A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
766
+ A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)
767
+
768
+ tensor_prob_output = model(X, A_S, A_L).to(device, dtype=torch.float)
769
+ patch_prob_output = tensor_prob_output.cpu().numpy()
770
+
771
+ # refinement
772
+ with st.spinner("Refining..."):
773
+ round_factor = 100
774
+ patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6
775
+
776
+ # unaries
777
+ unaries = -round_factor * np.log10(patch_prob_output)
778
+ unaries = unaries.astype(np.int32)
779
+ unaries = unaries.reshape(-1, num_classes)
780
+
781
+ # parawisex
782
+ pairwise = (1 - np.eye(num_classes, dtype=np.int32))
783
+
784
+ cells = cells.copy()
785
+
786
+ cell_ids = np.asarray(triangles_points)
787
+ lambda_c = 20
788
+ edges = np.empty([1, 3], order='C')
789
+ for i_node in stqdm(range(cells.shape[0])):
790
+ # Find neighbors
791
+ nei = np.sum(np.isin(cell_ids, cell_ids[i_node, :]), axis=1)
792
+ nei_id = np.where(nei == 2)
793
+ for i_nei in nei_id[0][:]:
794
+ if i_node < i_nei:
795
+ cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
796
+ normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
797
+ if cos_theta >= 1.0:
798
+ cos_theta = 0.9999
799
+ theta = np.arccos(cos_theta)
800
+ phi = np.linalg.norm(barycenters[i_node, :] - barycenters[i_nei, :])
801
+ if theta > np.pi / 2.0:
802
+ edges = np.concatenate(
803
+ (edges, np.array([i_node, i_nei, -np.log10(theta / np.pi) * phi]).reshape(1, 3)), axis=0)
804
+ else:
805
+ beta = 1 + np.linalg.norm(np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]))
806
+ edges = np.concatenate(
807
+ (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
808
+ axis=0)
809
+ edges = np.delete(edges, 0, 0)
810
+ edges[:, 2] *= lambda_c * round_factor
811
+ edges = edges.astype(np.int32)
812
+
813
+ refine_labels = cut_from_graph(edges, unaries, pairwise)
814
+ refine_labels = refine_labels.reshape([-1, 1])
815
+
816
+ predicted_labels_3 = refine_labels.reshape(refine_labels.shape[0])
817
+ mesh_to_points_main(jaw, pcd_points, center_points, predicted_labels_3)
818
+
819
+ import pyvista as pv
820
+ with st.spinner("Rendering..."):
821
+ # Load the .obj file
822
+ mesh = pv.read('file.obj')
823
+
824
+ # Load the JSON file
825
+ with open('dental-labels4.json', 'r') as file:
826
+ labels_data = json.load(file)
827
+
828
+ # Assuming labels_data['labels'] is a list of labels
829
+ labels = labels_data['labels']
830
+
831
+ # Make sure the number of labels matches the number of vertices or faces
832
+ assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
833
+
834
+ # If labels correspond to vertices
835
+ if len(labels) == mesh.n_points:
836
+ mesh.point_data['Labels'] = labels
837
+ # If labels correspond to faces
838
+ elif len(labels) == mesh.n_cells:
839
+ mesh.cell_data['Labels'] = labels
840
+
841
+ # Create a pyvista plotter
842
+ plotter = pv.Plotter()
843
+
844
+ cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
845
+
846
+ colors = cmap(np.linspace(0, 1, 27)) # Generate colors
847
+
848
+ # Convert colors to a format acceptable by PyVista
849
+ colormap = mcolors.ListedColormap(colors)
850
+
851
+ # Add the mesh to the plotter with labels as a scalar field
852
+ #plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
853
+ plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
854
+
855
+ # Show the plot
856
+ #plotter.show()
857
+ ## Send to streamlit
858
+ with st.expander("View Segmentation Result", expanded=False):
859
+ stpyvista(plotter)
860
+
861
  # Configure Streamlit page
862
  st.set_page_config(page_title="Teeth Segmentation", page_icon="🦷")
863
 
 
877
  )
878
  import pyvista as pv
879
  if inputs == "Example Scan":
880
+ st.markdown("Expected time per prediction: 7-10 min.")
881
  mesh = pv.read("ZOUIF2W4_upper.obj")
882
  plotter = pv.Plotter()
883
 
884
  # Add the mesh to the plotter
885
+ plotter.add_mesh(mesh, color='white', show_edges=True)
886
+ segment = st.button("Apply Segmentation")
887
+ with st.expander("View Scan", expanded=False):
888
  stpyvista(plotter)
889
 
890
+ if segment:
891
+ segmentation_main("ZOUIF2W4_upper.obj")
892
+
893
+
894
+
895
  elif inputs == "Upload Scan":
896
  file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
897
+ st.markdown("Expected time per prediction: 7-10 min.")
898
  if file is not None:
899
  # save the uploaded file to disk
900
  with open("file.obj", "wb") as buffer:
901
  shutil.copyfileobj(file, buffer)
902
  # 复制数据
 
 
903
  obj_path = "file.obj"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904
 
905
+ mesh = pv.read(obj_path)
906
+ plotter = pv.Plotter()
907
 
908
+ # Add the mesh to the plotter
909
+ plotter.add_mesh(mesh, color='white', show_edges=True)
910
+ segment = st.button("Apply Segmentation")
911
+ with st.expander("View Scan", expanded=False):
912
+ stpyvista(plotter)
913
 
914
+ if segment:
915
+ segmentation_main(obj_path)
 
916
 
917
+
918
+
 
 
919
 
920
  if __name__ == "__main__":
921
  app = Segment()