huathedev committed on
Commit cbb3c7c · 1 Parent(s): 759487d

Delete apps

apps/.DS_Store DELETED
Binary file (6.15 kB)
 
apps/demo/.DS_Store DELETED
Binary file (6.15 kB)
 
apps/demo/Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:3d2e44db8865ff3968803e86dadcf73cf9c4b738ddc35bfb3bc42c02347d7a0c
- size 28825987
 
 
 
 
apps/demo/ZOUIF2W4_upper.obj DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:581b9a026e2ce734f6335f34aa900e8114dc33e2a83541ebd6bb26536382545e
- size 18769177
 
 
 
 
apps/demo/file.obj DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:581b9a026e2ce734f6335f34aa900e8114dc33e2a83541ebd6bb26536382545e
- size 18769177
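The three entries above are Git LFS pointer files, not the binaries themselves: each is a three-line key-value stub recording the pointer spec version, the SHA-256 of the real object, and its size in bytes. A minimal parsing sketch (the parse_lfs_pointer helper is illustrative, not part of this repo):

    def parse_lfs_pointer(text):
        # Each line is "key value"; the oid field is "sha256:<hex digest>".
        fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
        algo, digest = fields["oid"].split(":", 1)
        return {"version": fields["version"], "algo": algo,
                "digest": digest, "size": int(fields["size"])}

    pointer = (
        "version https://git-lfs.github.com/spec/v1\n"
        "oid sha256:581b9a026e2ce734f6335f34aa900e8114dc33e2a83541ebd6bb26536382545e\n"
        "size 18769177\n"
    )
    print(parse_lfs_pointer(pointer)["size"])  # 18769177

Note that ZOUIF2W4_upper.obj and file.obj carry the same oid and size: the demo's file.obj was a byte-for-byte copy of the example scan.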
 
 
 
 
apps/demo/illu.png DELETED

Git LFS Details

  • SHA256: bd252730ed2050dc7ab0cce85546a3db5fc8d3162d5d7512970f93aa63fb2c74
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
apps/demo/pages/01_🦷 Segment.py DELETED
@@ -1,898 +0,0 @@
- from streamlit import session_state as session
- import shutil
-
- import os
- import numpy as np
- from sklearn import neighbors
- from scipy.spatial import distance_matrix
- from pygco import cut_from_graph
- import open3d as o3d
- import matplotlib.pyplot as plt
- import matplotlib.colors as mcolors
- import json
- from stpyvista import stpyvista
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
- import torch.nn.functional as F
- import streamlit as st
- import pyvista as pv
-
- from PIL import Image
-
- class TeethApp:
-     def __init__(self):
-         # Font
-         with open("utils/style.css") as css:
-             st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
-
-         # Logo
-         self.image_path = "utils/teeth-295404_1280.png"
-         self.image = Image.open(self.image_path)
-         width, height = self.image.size
-         scale = 12
-         new_width, new_height = width / scale, height / scale
-         self.image = self.image.resize((int(new_width), int(new_height)))
-
-         # Streamlit side navigation bar
-         st.sidebar.markdown("# AI ToothSeg")
-         st.sidebar.markdown("Automatic teeth segmentation with Deep Learning")
-         st.sidebar.markdown(" ")
-         st.sidebar.image(self.image, use_column_width=False)
-         st.markdown(
-             """
-             <style>
-             .css-1bxukto {
-             background-color: rgb(255, 255, 255) ;""",
-             unsafe_allow_html=True,
-         )
-
-
- class STN3d(nn.Module):
-     def __init__(self, channel):
-         super(STN3d, self).__init__()
-         self.conv1 = torch.nn.Conv1d(channel, 64, 1)
-         self.conv2 = torch.nn.Conv1d(64, 128, 1)
-         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
-         self.fc1 = nn.Linear(1024, 512)
-         self.fc2 = nn.Linear(512, 256)
-         self.fc3 = nn.Linear(256, 9)
-         self.relu = nn.ReLU()
-
-         self.bn1 = nn.BatchNorm1d(64)
-         self.bn2 = nn.BatchNorm1d(128)
-         self.bn3 = nn.BatchNorm1d(1024)
-         self.bn4 = nn.BatchNorm1d(512)
-         self.bn5 = nn.BatchNorm1d(256)
-
-     def forward(self, x):
-         batchsize = x.size()[0]
-         x = F.relu(self.bn1(self.conv1(x)))
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = F.relu(self.bn3(self.conv3(x)))
-         x = torch.max(x, 2, keepdim=True)[0]
-         x = x.view(-1, 1024)
-
-         x = F.relu(self.bn4(self.fc1(x)))
-         x = F.relu(self.bn5(self.fc2(x)))
-         x = self.fc3(x)
-
-         iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
-             batchsize, 1)
-         if x.is_cuda:
-             iden = iden.to(x.get_device())
-         x = x + iden
-         x = x.view(-1, 3, 3)
-         return x
-
- class STNkd(nn.Module):
-     def __init__(self, k=64):
-         super(STNkd, self).__init__()
-         self.conv1 = torch.nn.Conv1d(k, 64, 1)
-         self.conv2 = torch.nn.Conv1d(64, 128, 1)
-         self.conv3 = torch.nn.Conv1d(128, 512, 1)
-         self.fc1 = nn.Linear(512, 256)
-         self.fc2 = nn.Linear(256, 128)
-         self.fc3 = nn.Linear(128, k * k)
-         self.relu = nn.ReLU()
-
-         self.bn1 = nn.BatchNorm1d(64)
-         self.bn2 = nn.BatchNorm1d(128)
-         self.bn3 = nn.BatchNorm1d(512)
-         self.bn4 = nn.BatchNorm1d(256)
-         self.bn5 = nn.BatchNorm1d(128)
-
-         self.k = k
-
-     def forward(self, x):
-         batchsize = x.size()[0]
-         x = F.relu(self.bn1(self.conv1(x)))
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = F.relu(self.bn3(self.conv3(x)))
-         x = torch.max(x, 2, keepdim=True)[0]
-         x = x.view(-1, 512)
-
-         x = F.relu(self.bn4(self.fc1(x)))
-         x = F.relu(self.bn5(self.fc2(x)))
-         x = self.fc3(x)
-
-         iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
-             batchsize, 1)
-         if x.is_cuda:
-             iden = iden.to(x.get_device())
-         x = x + iden
-         x = x.view(-1, self.k, self.k)
-         return x
-
- class MeshSegNet(nn.Module):
-     def __init__(self, num_classes=17, num_channels=15, with_dropout=True, dropout_p=0.5):
-         super(MeshSegNet, self).__init__()
-         self.num_classes = num_classes
-         self.num_channels = num_channels
-         self.with_dropout = with_dropout
-         self.dropout_p = dropout_p
-
-         # MLP-1 [64, 64]
-         self.mlp1_conv1 = torch.nn.Conv1d(self.num_channels, 64, 1)
-         self.mlp1_conv2 = torch.nn.Conv1d(64, 64, 1)
-         self.mlp1_bn1 = nn.BatchNorm1d(64)
-         self.mlp1_bn2 = nn.BatchNorm1d(64)
-         # FTM (feature-transformer module)
-         self.fstn = STNkd(k=64)
-         # GLM-1 (graph-constrained learning module)
-         self.glm1_conv1_1 = torch.nn.Conv1d(64, 32, 1)
-         self.glm1_conv1_2 = torch.nn.Conv1d(64, 32, 1)
-         self.glm1_bn1_1 = nn.BatchNorm1d(32)
-         self.glm1_bn1_2 = nn.BatchNorm1d(32)
-         self.glm1_conv2 = torch.nn.Conv1d(32+32, 64, 1)
-         self.glm1_bn2 = nn.BatchNorm1d(64)
-         # MLP-2
-         self.mlp2_conv1 = torch.nn.Conv1d(64, 64, 1)
-         self.mlp2_bn1 = nn.BatchNorm1d(64)
-         self.mlp2_conv2 = torch.nn.Conv1d(64, 128, 1)
-         self.mlp2_bn2 = nn.BatchNorm1d(128)
-         self.mlp2_conv3 = torch.nn.Conv1d(128, 512, 1)
-         self.mlp2_bn3 = nn.BatchNorm1d(512)
-         # GLM-2 (graph-constrained learning module)
-         self.glm2_conv1_1 = torch.nn.Conv1d(512, 128, 1)
-         self.glm2_conv1_2 = torch.nn.Conv1d(512, 128, 1)
-         self.glm2_conv1_3 = torch.nn.Conv1d(512, 128, 1)
-         self.glm2_bn1_1 = nn.BatchNorm1d(128)
-         self.glm2_bn1_2 = nn.BatchNorm1d(128)
-         self.glm2_bn1_3 = nn.BatchNorm1d(128)
-         self.glm2_conv2 = torch.nn.Conv1d(128*3, 512, 1)
-         self.glm2_bn2 = nn.BatchNorm1d(512)
-         # MLP-3
-         self.mlp3_conv1 = torch.nn.Conv1d(64+512+512+512, 256, 1)
-         self.mlp3_conv2 = torch.nn.Conv1d(256, 256, 1)
-         self.mlp3_bn1_1 = nn.BatchNorm1d(256)
-         self.mlp3_bn1_2 = nn.BatchNorm1d(256)
-         self.mlp3_conv3 = torch.nn.Conv1d(256, 128, 1)
-         self.mlp3_conv4 = torch.nn.Conv1d(128, 128, 1)
-         self.mlp3_bn2_1 = nn.BatchNorm1d(128)
-         self.mlp3_bn2_2 = nn.BatchNorm1d(128)
-         # output
-         self.output_conv = torch.nn.Conv1d(128, self.num_classes, 1)
-         if self.with_dropout:
-             self.dropout = nn.Dropout(p=self.dropout_p)
-
-     def forward(self, x, a_s, a_l):
-         batchsize = x.size()[0]
-         n_pts = x.size()[2]
-         # MLP-1
-         x = F.relu(self.mlp1_bn1(self.mlp1_conv1(x)))
-         x = F.relu(self.mlp1_bn2(self.mlp1_conv2(x)))
-         # FTM
-         trans_feat = self.fstn(x)
-         x = x.transpose(2, 1)
-         x_ftm = torch.bmm(x, trans_feat)
-         # GLM-1
-         sap = torch.bmm(a_s, x_ftm)
-         sap = sap.transpose(2, 1)
-         x_ftm = x_ftm.transpose(2, 1)
-         x = F.relu(self.glm1_bn1_1(self.glm1_conv1_1(x_ftm)))
-         glm_1_sap = F.relu(self.glm1_bn1_2(self.glm1_conv1_2(sap)))
-         x = torch.cat([x, glm_1_sap], dim=1)
-         x = F.relu(self.glm1_bn2(self.glm1_conv2(x)))
-         # MLP-2
-         x = F.relu(self.mlp2_bn1(self.mlp2_conv1(x)))
-         x = F.relu(self.mlp2_bn2(self.mlp2_conv2(x)))
-         x_mlp2 = F.relu(self.mlp2_bn3(self.mlp2_conv3(x)))
-         if self.with_dropout:
-             x_mlp2 = self.dropout(x_mlp2)
-         # GLM-2
-         x_mlp2 = x_mlp2.transpose(2, 1)
-         sap_1 = torch.bmm(a_s, x_mlp2)
-         sap_2 = torch.bmm(a_l, x_mlp2)
-         x_mlp2 = x_mlp2.transpose(2, 1)
-         sap_1 = sap_1.transpose(2, 1)
-         sap_2 = sap_2.transpose(2, 1)
-         x = F.relu(self.glm2_bn1_1(self.glm2_conv1_1(x_mlp2)))
-         glm_2_sap_1 = F.relu(self.glm2_bn1_2(self.glm2_conv1_2(sap_1)))
-         glm_2_sap_2 = F.relu(self.glm2_bn1_3(self.glm2_conv1_3(sap_2)))
-         x = torch.cat([x, glm_2_sap_1, glm_2_sap_2], dim=1)
-         x_glm2 = F.relu(self.glm2_bn2(self.glm2_conv2(x)))
-         # GMP
-         x = torch.max(x_glm2, 2, keepdim=True)[0]
-         # Upsample
-         x = torch.nn.Upsample(n_pts)(x)
-         # Dense fusion
-         x = torch.cat([x, x_ftm, x_mlp2, x_glm2], dim=1)
-         # MLP-3
-         x = F.relu(self.mlp3_bn1_1(self.mlp3_conv1(x)))
-         x = F.relu(self.mlp3_bn1_2(self.mlp3_conv2(x)))
-         x = F.relu(self.mlp3_bn2_1(self.mlp3_conv3(x)))
-         if self.with_dropout:
-             x = self.dropout(x)
-         x = F.relu(self.mlp3_bn2_2(self.mlp3_conv4(x)))
-         # output
-         x = self.output_conv(x)
-         x = x.transpose(2, 1).contiguous()
-         x = torch.nn.Softmax(dim=-1)(x.view(-1, self.num_classes))
-         x = x.view(batchsize, n_pts, self.num_classes)
-
-         return x
-
- def clone_runoob(li1):
-     li_copy = li1[:]
-     return li_copy
-
- # Re-classify the outlier points
- def class_inlier_outlier(label_list, mean_points, cloud, ind, label_index, points, labels):
-     label_change = clone_runoob(labels)
-     outlier_index = clone_runoob(label_index)
-     ind_reverse = clone_runoob(ind)
-     # Get the label indices of the outlier points
-     ind_reverse.reverse()
-     for i in ind_reverse:
-         outlier_index.pop(i)
-
-     # Get the outlier points
-     inlier_cloud = cloud.select_by_index(ind)
-     outlier_cloud = cloud.select_by_index(ind, invert=True)
-     outlier_points = np.array(outlier_cloud.points)
-
-     for i in range(len(outlier_points)):
-         distance = []
-         for j in range(len(mean_points)):
-             dis = np.linalg.norm(outlier_points[i] - mean_points[j], ord=2)  # distance between the tooth point and each GT centroid
-             distance.append(dis)
-         min_index = distance.index(min(distance))  # index of the label whose centroid is closest to the outlier
-         outlier_label = label_list[min_index]  # the label the outlier should take
-         index = outlier_index[i]
-         label_change[index] = outlier_label
-
-     return label_change
-
- # Remove outliers using the KNN algorithm
- def remove_outlier(points, labels):
-     # points = np.array(point_cloud_o3d_orign.points)
-     # global label_list
-     same_label_points = {}
-
-     same_label_index = {}
-
-     mean_points = []  # centroid coordinates of the point cloud for every label
-
-     label_list = []
-     for i in range(len(labels)):
-         label_list.append(labels[i])
-     label_list = list(set(label_list))  # deduplicate and sort ascending, e.g. GT_label = [0, 11, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26, 27]
-     label_list.sort()
-     label_list = label_list[1:]
-
-     for i in label_list:
-         key = i
-         points_list = []
-         all_label_index = []
-         for j in range(len(labels)):
-             if labels[j] == i:
-                 points_list.append(points[j].tolist())
-                 all_label_index.append(j)  # indices of the points whose label is i
-         same_label_points[key] = points_list
-         same_label_index[key] = all_label_index
-
-         tooth_mean = np.mean(points_list, axis=0)
-         mean_points.append(tooth_mean)
-         # print(mean_points)
-
-     for i in label_list:
-         points_array = same_label_points[i]
-         # Build an Open3D point cloud object
-         pcd = o3d.geometry.PointCloud()
-         # Convert with the Vector3dVector helper
-         pcd.points = o3d.utility.Vector3dVector(points_array)
-
-         # Statistical outlier removal on the points of label i: find and flag the outliers
-         # Statistics-based outlier removal
-         cl, ind = pcd.remove_statistical_outlier(nb_neighbors=200, std_ratio=2.0)  # cl: selected points, ind: their indices
-         # Visualization
-         # display_inlier_outlier(pcd, ind)
-
-         # Re-classify the separated outliers
-         label_index = same_label_index[i]
-         labels = class_inlier_outlier(label_list, mean_points, pcd, ind, label_index, points, labels)
-         # print(f"label_change{labels[4400]}")
-
-     return labels
-
-
- # Remove outliers and save the final output
- def remove_outlier_main(jaw, pcd_points, labels, instances_labels):
-     # point_cloud_o3d_orign = o3d.io.read_point_cloud('E:/tooth/data/MeshSegNet-master/test_upsample_15/upsample_01K17AN8_upper_refined.pcd')
-     # Original points
-     points = pcd_points.copy()
-     label = remove_outlier(points, labels)
-
-     # Save the JSON file
-     label_dict = {}
-     label_dict["id_patient"] = ""
-     label_dict["jaw"] = jaw
-     label_dict["labels"] = label.tolist()
-     label_dict["instances"] = instances_labels.tolist()
-     b = json.dumps(label_dict)
-     with open('dental-labels4' + '.json', 'w') as f_obj:
-         f_obj.write(b)
-     f_obj.close()
-
-
- same_points_list = {}
-
-
- # Voxel downsampling
- def voxel_filter(point_cloud, leaf_size):
-     same_points_list = {}
-     filtered_points = []
-     # step 1: compute the bounding box
-     x_max, y_max, z_max = np.amax(point_cloud, axis=0)  # min/max along x, y, z
-     x_min, y_min, z_min = np.amin(point_cloud, axis=0)
-
-     # step 2: set the voxel size
-     size_r = leaf_size
-
-     # step 3: compute the dimensions of the voxel grid
-     Dx = (x_max - x_min) // size_r + 1
-     Dy = (y_max - y_min) // size_r + 1
-     Dz = (z_max - z_min) // size_r + 1
-
-     # print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
-
-     # step 4: compute, for every point, its cell index along each grid dimension
-     h = list()  # h holds the flattened voxel indices
-     for i in range(len(point_cloud)):
-         hx = np.floor((point_cloud[i][0] - x_min) // size_r)
-         hy = np.floor((point_cloud[i][1] - y_min) // size_r)
-         hz = np.floor((point_cloud[i][2] - z_min) // size_r)
-         h.append(hx + hy * Dx + hz * Dx * Dy)
-     # print(h[60581])
-
-     # step 5: sort the h values
-     h = np.array(h)
-     h_indice = np.argsort(h)  # indices that sort h in ascending order
-     h_sorted = h[h_indice]  # ascending
-     count = 0  # running offset into h_sorted
-     step = 20
-     # Points with the same h value fall into the same grid cell; subsample each cell
-     for i in range(1, len(h_sorted)):  # data points 0-19999
-         # if i == len(h_sorted)-1:
-         #     print("aaa")
-         if h_sorted[i] == h_sorted[i - 1] and (i != len(h_sorted) - 1):
-             continue
-         elif h_sorted[i] == h_sorted[i - 1] and (i == len(h_sorted) - 1):
-             point_idx = h_indice[count:]
-             key = h_sorted[i - 1]
-             same_points_list[key] = point_idx
-             _G = np.mean(point_cloud[point_idx], axis=0)  # centroid of the cell's points
-             _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2)  # distances to the centroid
-             _d.sort()
-             inx = [j for j in range(0, len(_d), step)]  # element indices at the given interval
-             for j in inx:
-                 index = point_idx[j]
-                 filtered_points.append(point_cloud[index])
-             count = i
-         elif h_sorted[i] != h_sorted[i - 1] and (i == len(h_sorted) - 1):
-             point_idx1 = h_indice[count:i]
-             key1 = h_sorted[i - 1]
-             same_points_list[key1] = point_idx1
-             _G = np.mean(point_cloud[point_idx1], axis=0)  # centroid of the cell's points
-             _d = np.linalg.norm(point_cloud[point_idx1] - _G, axis=1, ord=2)  # distances to the centroid
-             _d.sort()
-             inx = [j for j in range(0, len(_d), step)]  # element indices at the given interval
-             for j in inx:
-                 index = point_idx1[j]
-                 filtered_points.append(point_cloud[index])
-
-             point_idx2 = h_indice[i:]
-             key2 = h_sorted[i]
-             same_points_list[key2] = point_idx2
-             _G = np.mean(point_cloud[point_idx2], axis=0)  # centroid of the cell's points
-             _d = np.linalg.norm(point_cloud[point_idx2] - _G, axis=1, ord=2)  # distances to the centroid
-             _d.sort()
-             inx = [j for j in range(0, len(_d), step)]  # element indices at the given interval
-             for j in inx:
-                 index = point_idx2[j]
-                 filtered_points.append(point_cloud[index])
-             count = i
-
-         else:
-             point_idx = h_indice[count: i]
-             key = h_sorted[i - 1]
-             same_points_list[key] = point_idx
-             _G = np.mean(point_cloud[point_idx], axis=0)  # centroid of the cell's points
-             _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2)  # distances to the centroid
-             _d.sort()
-             inx = [j for j in range(0, len(_d), step)]  # element indices at the given interval
-             for j in inx:
-                 index = point_idx[j]
-                 filtered_points.append(point_cloud[index])
-             count = i
-
-     # Convert the point cloud to an array and return it
-     # print(f'filtered_points[0] is {filtered_points[0]}')
-     filtered_points = np.array(filtered_points, dtype=np.float64)
-     return filtered_points, same_points_list
-
-
- # Voxel upsampling
- def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels, leaf_size):
-     upsample_label = []
-     upsample_point = []
-     upsample_index = []
-     # step 1: compute the bounding box
-     x_max, y_max, z_max = np.amax(point_cloud, axis=0)  # min/max along x, y, z
-     x_min, y_min, z_min = np.amin(point_cloud, axis=0)
-     # step 2: set the voxel size
-     size_r = leaf_size
-     # step 3: compute the dimensions of the voxel grid
-     Dx = (x_max - x_min) // size_r + 1
-     Dy = (y_max - y_min) // size_r + 1
-     Dz = (z_max - z_min) // size_r + 1
-     print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
-
-     # step 4: compute, for every (downsampled) point, its cell index along each grid dimension
-     h = list()
-     for i in range(len(filtered_points)):
-         hx = np.floor((filtered_points[i][0] - x_min) // size_r)
-         hy = np.floor((filtered_points[i][1] - y_min) // size_r)
-         hz = np.floor((filtered_points[i][2] - z_min) // size_r)
-         h.append(hx + hy * Dx + hz * Dx * Dy)
-
-     # step 5: look up the same_points_list dictionary by h value
-     h = np.array(h)
-     count = 0
-     for i in range(1, len(h)):
-         if h[i] == h[i - 1] and i != (len(h) - 1):
-             continue
-         elif h[i] == h[i - 1] and i == (len(h) - 1):
-             label = filter_labels[count:]
-             key = h[i - 1]
-             count = i
-             # Tally the label votes, e.g. classcount = {'A': 2, 'B': 1}
-             classcount = {}
-             for i in range(len(label)):
-                 vote = label[i]
-                 classcount[vote] = classcount.get(vote, 0) + 1
-             # Sort the map by value
-             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
-             # key = h[i-1]
-             point_index = same_points_list[key]  # list of point indices for this h value
-             for j in range(len(point_index)):
-                 upsample_label.append(sortedclass[0][0])
-                 index = point_index[j]
-                 upsample_point.append(point_cloud[index])
-                 upsample_index.append(index)
-         elif h[i] != h[i - 1] and (i == len(h) - 1):
-             label1 = filter_labels[count:i]
-             key1 = h[i - 1]
-             label2 = filter_labels[i:]
-             key2 = h[i]
-             count = i
-
-             classcount = {}
-             for i in range(len(label1)):
-                 vote = label1[i]
-                 classcount[vote] = classcount.get(vote, 0) + 1
-             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
-             # key1 = h[i-1]
-             point_index = same_points_list[key1]
-             for j in range(len(point_index)):
-                 upsample_label.append(sortedclass[0][0])
-                 index = point_index[j]
-                 upsample_point.append(point_cloud[index])
-                 upsample_index.append(index)
-
-             # label2 = filter_labels[i:]
-             classcount = {}
-             for i in range(len(label2)):
-                 vote = label2[i]
-                 classcount[vote] = classcount.get(vote, 0) + 1
-             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
-             # key2 = h[i]
-             point_index = same_points_list[key2]
-             for j in range(len(point_index)):
-                 upsample_label.append(sortedclass[0][0])
-                 index = point_index[j]
-                 upsample_point.append(point_cloud[index])
-                 upsample_index.append(index)
-         else:
-             label = filter_labels[count:i]
-             key = h[i - 1]
-             count = i
-             classcount = {}
-             for i in range(len(label)):
-                 vote = label[i]
-                 classcount[vote] = classcount.get(vote, 0) + 1
-             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
-             # key = h[i-1]
-             point_index = same_points_list[key]  # list of point indices for this h value
-             for j in range(len(point_index)):
-                 upsample_label.append(sortedclass[0][0])
-                 index = point_index[j]
-                 upsample_point.append(point_cloud[index])
-                 upsample_index.append(index)
-             # count = i
-
-     # Restore the original order
-     # print(f'upsample_index[0] is {upsample_index[0]}')
-     # print(f'total length of upsample_index is {len(upsample_index)}')
-
-     # Restore the original order of the indices
-     upsample_index = np.array(upsample_index)
-     upsample_index_indice = np.argsort(upsample_index)  # indices that sort upsample_index ascending
-     upsample_index_sorted = upsample_index[upsample_index_indice]
-
-     upsample_point = np.array(upsample_point)
-     upsample_label = np.array(upsample_label)
-     # Restore the original order of the points and labels
-     upsample_point_sorted = upsample_point[upsample_index_indice]
-     upsample_label_sorted = upsample_label[upsample_index_indice]
-
-     return upsample_point_sorted, upsample_label_sorted
-
-
- # Upsample with the KNN algorithm
- def KNN_sklearn_Load_data(voxel_points, center_points, labels):
-     # Load the data
-     # x_train, x_test, y_train, y_test = train_test_split(center_points, labels, test_size=0.1)
-     # Build the model
-     model = neighbors.KNeighborsClassifier(n_neighbors=3)
-     model.fit(center_points, labels)
-     prediction = model.predict(voxel_points.reshape(1, -1))
-     # meshtopoints_labels = classification_report(voxel_points, prediction)
-     return prediction[0]
-
-
- # Load the points for KNN upsampling
- def Load_data(voxel_points, center_points, labels):
-     meshtopoints_labels = []
-     # meshtopoints_labels.append(SVC_sklearn_Load_data(voxel_points[i], center_points, labels))
-     for i in range(0, voxel_points.shape[0]):
-         meshtopoints_labels.append(KNN_sklearn_Load_data(voxel_points[i], center_points, labels))
-     return np.array(meshtopoints_labels)
-
- # Upsample the triangle-mesh labels back onto the original point cloud
- def mesh_to_points_main(jaw, pcd_points, center_points, labels):
-     points = pcd_points.copy()
-     # Downsample
-     voxel_points, same_points_list = voxel_filter(points, 0.6)
-
-     after_labels = Load_data(voxel_points, center_points, labels)
-
-     upsample_point, upsample_label = voxel_upsample(same_points_list, points, voxel_points, after_labels, 0.6)
-
-     new_pcd = o3d.geometry.PointCloud()
-     new_pcd.points = o3d.utility.Vector3dVector(upsample_point)
-     instances_labels = upsample_label.copy()
-     # '''
-     # o3d.io.write_point_cloud(os.path.join(save_path, 'upsample_' + name + '.pcd'), new_pcd, write_ascii=True)
-     for i in range(0, upsample_label.shape[0]):
-         if jaw == 'upper':
-             if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
-                 upsample_label[i] = upsample_label[i] + 10
-             elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
-                 upsample_label[i] = upsample_label[i] + 12
-         else:
-             if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
-                 upsample_label[i] = upsample_label[i] + 30
-             elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
-                 upsample_label[i] = upsample_label[i] + 32
-     remove_outlier_main(jaw, pcd_points, upsample_label, instances_labels)
-
-
- # Convert the original point cloud to a triangle mesh
- def mesh_grid(pcd_points):
-     new_pcd, _ = voxel_filter(pcd_points, 0.6)
-     # The point cloud needs normals
-
-     # estimate radius for rolling ball
-     pcd_new = o3d.geometry.PointCloud()
-     pcd_new.points = o3d.utility.Vector3dVector(new_pcd)
-     pcd_new.estimate_normals()
-     distances = pcd_new.compute_nearest_neighbor_distance()
-     avg_dist = np.mean(distances)
-     radius = 6 * avg_dist
-     mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
-         pcd_new,
-         o3d.utility.DoubleVector([radius, radius * 2]))
-     # o3d.io.write_triangle_mesh("./tooth date/test.ply", mesh)
-
-     return mesh
-
-
- # Read the contents of an OBJ file
- def read_obj(obj_path):
-     jaw = None
-     with open(obj_path) as file:
-         points = []
-         faces = []
-         while 1:
-             line = file.readline()
-             if not line:
-                 break
-             strs = line.split(" ")
-             if strs[0] == "v":
-                 points.append((float(strs[1]), float(strs[2]), float(strs[3])))
-             elif strs[0] == "f":
-                 faces.append((int(strs[1]), int(strs[2]), int(strs[3])))
-             elif strs[1][0:5] == 'lower':
-                 jaw = 'lower'
-             elif strs[1][0:5] == 'upper':
-                 jaw = 'upper'
-
-     points = np.array(points)
-     faces = np.array(faces)
-
-     if jaw is None:
-         raise ValueError("Jaw type not found in OBJ file")
-
-     return points, faces, jaw
-
-
- # Convert an OBJ file to PCD-style points
- def obj2pcd(obj_path):
-     if os.path.exists(obj_path):
-         print('yes')
-     points, _, jaw = read_obj(obj_path)
-     pcd_list = []
-     num_points = np.shape(points)[0]
-     for i in range(num_points):
-         new_line = str(points[i, 0]) + ' ' + str(points[i, 1]) + ' ' + str(points[i, 2])
-         pcd_list.append(new_line.split())
-
-     pcd_points = np.array(pcd_list).astype(np.float64)
-     return pcd_points, jaw
-
- # Configure Streamlit page
- st.set_page_config(page_title="Teeth Segmentation", page_icon="🦷")
-
- class Segment(TeethApp):
-     def __init__(self):
-         TeethApp.__init__(self)
-         self.build_app()
-
-     def build_app(self):
-
-         st.title("Segment Intra-oral Scans")
-         st.markdown("Select scan for segmentation")
-
-         inputs = st.radio(
-             "Select scan for segmentation:",
-             ("Upload Scan", "Example Scan"),
-         )
-         import pyvista as pv
-         if inputs == "Example Scan":
-             mesh = pv.read("ZOUIF2W4_upper.obj")
-             plotter = pv.Plotter()
-
-             # Add the mesh to the plotter
-             plotter.add_mesh(mesh, color='black', show_edges=True)
-             visualize = st.button("Segment")
-             if visualize:
-                 stpyvista(plotter)
-
-         elif inputs == "Upload Scan":
-             file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
-
-             if file is not None:
-                 # save the uploaded file to disk
-                 with open("file.obj", "wb") as buffer:
-                     shutil.copyfileobj(file, buffer)
-                 # Copy the data
-
-
-                 obj_path = "file.obj"
-                 upsampling_method = 'KNN'
-
-                 model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
-                 num_classes = 17
-                 num_channels = 15
-
-                 # set model
-                 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-                 model = MeshSegNet(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
-
-                 # load trained model
-                 # checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
-                 checkpoint = torch.load(model_path, map_location='cpu')
-                 model.load_state_dict(checkpoint['model_state_dict'])
-                 del checkpoint
-                 model = model.to(device, dtype=torch.float)
-
-                 # cudnn
-                 torch.backends.cudnn.benchmark = True
-                 torch.backends.cudnn.enabled = True
-
-                 # Predicting
-                 model.eval()
-                 with torch.no_grad():
-                     pcd_points, jaw = obj2pcd(obj_path)
-                     mesh = mesh_grid(pcd_points)
-
-                     # move mesh to origin
-                     with st.spinner("Patience please, AI at work. Grab a coffee while you wait☕!"):
-                         vertices_points = np.asarray(mesh.vertices)
-                         triangles_points = np.asarray(mesh.triangles)
-                         N = triangles_points.shape[0]
-                         cells = np.zeros((triangles_points.shape[0], 9))
-                         cells = vertices_points[triangles_points].reshape(triangles_points.shape[0], 9)
-
-                         mean_cell_centers = mesh.get_center()
-                         cells[:, 0:3] -= mean_cell_centers[0:3]
-                         cells[:, 3:6] -= mean_cell_centers[0:3]
-                         cells[:, 6:9] -= mean_cell_centers[0:3]
-
-                         v1 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
-                         v2 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
-                         v1[:, 0] = cells[:, 0] - cells[:, 3]
-                         v1[:, 1] = cells[:, 1] - cells[:, 4]
-                         v1[:, 2] = cells[:, 2] - cells[:, 5]
-                         v2[:, 0] = cells[:, 3] - cells[:, 6]
-                         v2[:, 1] = cells[:, 4] - cells[:, 7]
-                         v2[:, 2] = cells[:, 5] - cells[:, 8]
-                         mesh_normals = np.cross(v1, v2)
-                         mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
-                         mesh_normals[:, 0] /= mesh_normal_length[:]
-                         mesh_normals[:, 1] /= mesh_normal_length[:]
-                         mesh_normals[:, 2] /= mesh_normal_length[:]
-
-                         # prepare input
-                         points = vertices_points.copy()
-                         points[:, 0:3] -= mean_cell_centers[0:3]
-                         normals = np.nan_to_num(mesh_normals).copy()
-                         barycenters = np.zeros((triangles_points.shape[0], 3))
-                         s = np.sum(vertices_points[triangles_points], 1)
-                         barycenters = 1 / 3 * s
-                         center_points = barycenters.copy()
-                         barycenters -= mean_cell_centers[0:3]
-
-                         # normalized data
-                         maxs = points.max(axis=0)
-                         mins = points.min(axis=0)
-                         means = points.mean(axis=0)
-                         stds = points.std(axis=0)
-                         nmeans = normals.mean(axis=0)
-                         nstds = normals.std(axis=0)
-
-                         for i in range(3):
-                             cells[:, i] = (cells[:, i] - means[i]) / stds[i]  # point 1
-                             cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i]  # point 2
-                             cells[:, i + 6] = (cells[:, i + 6] - means[i]) / stds[i]  # point 3
-                             barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
-                             normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
-
-                         X = np.column_stack((cells, barycenters, normals))
-
-                         # computing A_S and A_L
-                         A_S = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
-                         A_L = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
-                         D = distance_matrix(X[:, 9:12], X[:, 9:12])
-                         A_S[D < 0.1] = 1.0
-                         A_S = A_S / np.dot(np.sum(A_S, axis=1, keepdims=True), np.ones((1, X.shape[0])))
-
-                         A_L[D < 0.2] = 1.0
-                         A_L = A_L / np.dot(np.sum(A_L, axis=1, keepdims=True), np.ones((1, X.shape[0])))
-
-                         # numpy -> torch.tensor
-                         X = X.transpose(1, 0)
-                         X = X.reshape([1, X.shape[0], X.shape[1]])
-                         X = torch.from_numpy(X).to(device, dtype=torch.float)
-                         A_S = A_S.reshape([1, A_S.shape[0], A_S.shape[1]])
-                         A_L = A_L.reshape([1, A_L.shape[0], A_L.shape[1]])
-                         A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
-                         A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)
-
-                         tensor_prob_output = model(X, A_S, A_L).to(device, dtype=torch.float)
-                         patch_prob_output = tensor_prob_output.cpu().numpy()
-
-                     # refinement
-                     with st.spinner("Refining..."):
-                         round_factor = 100
-                         patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6
-
-                         # unaries
-                         unaries = -round_factor * np.log10(patch_prob_output)
-                         unaries = unaries.astype(np.int32)
-                         unaries = unaries.reshape(-1, num_classes)
-
-                         # pairwise
-                         pairwise = (1 - np.eye(num_classes, dtype=np.int32))
-
-                         cells = cells.copy()
-
-                         cell_ids = np.asarray(triangles_points)
-                         lambda_c = 20
-                         edges = np.empty([1, 3], order='C')
-                         for i_node in range(cells.shape[0]):
-                             # Find neighbors
-                             nei = np.sum(np.isin(cell_ids, cell_ids[i_node, :]), axis=1)
-                             nei_id = np.where(nei == 2)
-                             for i_nei in nei_id[0][:]:
-                                 if i_node < i_nei:
-                                     cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
-                                         normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
-                                     if cos_theta >= 1.0:
-                                         cos_theta = 0.9999
-                                     theta = np.arccos(cos_theta)
-                                     phi = np.linalg.norm(barycenters[i_node, :] - barycenters[i_nei, :])
-                                     if theta > np.pi / 2.0:
-                                         edges = np.concatenate(
-                                             (edges, np.array([i_node, i_nei, -np.log10(theta / np.pi) * phi]).reshape(1, 3)), axis=0)
-                                     else:
-                                         beta = 1 + np.linalg.norm(np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]))
-                                         edges = np.concatenate(
-                                             (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
-                                             axis=0)
-                         edges = np.delete(edges, 0, 0)
-                         edges[:, 2] *= lambda_c * round_factor
-                         edges = edges.astype(np.int32)
-
-                         refine_labels = cut_from_graph(edges, unaries, pairwise)
-                         refine_labels = refine_labels.reshape([-1, 1])
-
-                         predicted_labels_3 = refine_labels.reshape(refine_labels.shape[0])
-                         mesh_to_points_main(jaw, pcd_points, center_points, predicted_labels_3)
-
-                 import pyvista as pv
-
-                 with st.spinner("Rendering..."):
-                     # Load the .obj file
-                     mesh = pv.read('file.obj')
-
-                     # Load the JSON file
-                     with open('dental-labels4.json', 'r') as file:
-                         labels_data = json.load(file)
-
-                     # Assuming labels_data['labels'] is a list of labels
-                     labels = labels_data['labels']
-
-                     # Make sure the number of labels matches the number of vertices or faces
-                     assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
-
-                     # If labels correspond to vertices
-                     if len(labels) == mesh.n_points:
-                         mesh.point_data['Labels'] = labels
-                     # If labels correspond to faces
-                     elif len(labels) == mesh.n_cells:
-                         mesh.cell_data['Labels'] = labels
-
-                     # Create a pyvista plotter
-                     plotter = pv.Plotter()
-
-                     cmap = plt.cm.get_cmap('jet', 27)  # Using a colormap with sufficient distinct colors
-
-                     colors = cmap(np.linspace(0, 1, 27))  # Generate colors
-
-                     # Convert colors to a format acceptable by PyVista
-                     colormap = mcolors.ListedColormap(colors)
-
-                     # Add the mesh to the plotter with labels as a scalar field
-                     # plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
-                     plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
-
-                     # Show the plot
-                     # plotter.show()
-                     ## Send to streamlit
-                     stpyvista(plotter)
-
- if __name__ == "__main__":
-     app = Segment()
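The deleted MeshSegNet above takes one 15-channel feature vector per mesh cell plus two row-normalized adjacency matrices (A_S and A_L, thresholded at 0.1 and 0.2 on barycenter distance) and returns per-cell class probabilities. A minimal shape check, assuming the MeshSegNet class above and random stand-in inputs (n_cells and all tensor values are illustrative):

    import torch

    n_cells = 128
    model = MeshSegNet(num_classes=17, num_channels=15)
    model.eval()  # BatchNorm needs eval mode for a single random sample
    x = torch.rand(1, 15, n_cells)         # (batch, channels, cells)
    a_s = torch.rand(1, n_cells, n_cells)  # short-range adjacency
    a_l = torch.rand(1, n_cells, n_cells)  # long-range adjacency
    with torch.no_grad():
        probs = model(x, a_s, a_l)
    print(probs.shape)                     # torch.Size([1, 128, 17])
    print(float(probs.sum(dim=-1).mean()))  # ≈ 1.0: softmax over classes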
 
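voxel_filter and voxel_upsample above both key points to a flattened voxel id, h = hx + hy·Dx + hz·Dx·Dy. The per-point Python loop can be collapsed into one vectorized NumPy step; a sketch under the same leaf_size convention (voxel_ids is a hypothetical helper name):

    import numpy as np

    def voxel_ids(points, leaf_size):
        # points: (N, 3) array; returns one flat voxel id per point,
        # matching the h values computed point-by-point above.
        mins = points.min(axis=0)
        dims = (points.max(axis=0) - mins) // leaf_size + 1  # Dx, Dy, Dz
        hxyz = (points - mins) // leaf_size                  # per-axis cell index
        return hxyz[:, 0] + hxyz[:, 1] * dims[0] + hxyz[:, 2] * dims[0] * dims[1]

    pts = np.random.rand(1000, 3) * 10.0
    print(len(np.unique(voxel_ids(pts, 0.6))), "occupied voxels")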
 
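Load_data above rebuilds and refits a KNeighborsClassifier for every single query point, so the classifier is trained N times to make N predictions. scikit-learn already vectorizes prediction, so an equivalent sketch fits once and predicts the whole array in one call (same n_neighbors=3 as the original; the helper name is hypothetical):

    import numpy as np
    from sklearn import neighbors

    def knn_upsample_labels(voxel_points, center_points, labels):
        # Fit once on the mesh-cell barycenters, then classify all
        # downsampled points in a single vectorized predict call.
        model = neighbors.KNeighborsClassifier(n_neighbors=3)
        model.fit(center_points, labels)
        return model.predict(voxel_points)

    centers = np.random.rand(500, 3)              # illustrative shapes only
    labs = np.random.randint(0, 17, 500)
    voxels = np.random.rand(200, 3)
    print(knn_upsample_labels(voxels, centers, labs).shape)  # (200,)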
 
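The +10/+12 and +30/+32 offsets in mesh_to_points_main convert the network's class indices 1-16 into two-digit FDI tooth codes: 11-18/21-28 for the upper jaw, 31-38/41-48 for the lower. The same mapping as a stand-alone sketch (to_fdi is a hypothetical helper; 0 is assumed to be gingiva/background and passes through unchanged, as in the loop above):

    def to_fdi(label, jaw):
        # Classes 1-8 fall in quadrant 1 (upper) or 3 (lower);
        # classes 9-16 fall in quadrant 2 or 4, hence the extra +2.
        if label == 0:
            return 0
        base = 10 if jaw == 'upper' else 30
        return label + base if label <= 8 else label + base + 2

    assert to_fdi(1, 'upper') == 11 and to_fdi(16, 'upper') == 28
    assert to_fdi(8, 'lower') == 38 and to_fdi(9, 'lower') == 41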
 
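The refinement stage encodes the problem for pygco's cut_from_graph exactly as a multi-label graph cut expects: integer unary costs of -100·log10(p), a Potts pairwise matrix (1 - I), and a 3-column edge array (node_i, node_j, weight) whose weights grow with the dihedral angle between neighboring cells. A toy-sized sketch of the same call (the numbers are invented; the labeling shown is what these costs favor):

    import numpy as np
    from pygco import cut_from_graph

    probs = np.array([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8]])  # 3 nodes, 2 classes
    unaries = (-100 * np.log10(np.clip(probs, 1e-6, None))).astype(np.int32)
    pairwise = 1 - np.eye(2, dtype=np.int32)                # Potts model
    edges = np.array([[0, 1, 50], [1, 2, 50]], dtype=np.int32)
    print(cut_from_graph(edges, unaries, pairwise))         # e.g. [0 0 1]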
apps/demo/pages/02_📙How_it_Works.py DELETED
@@ -1,50 +0,0 @@
- import streamlit as st
- from streamlit import session_state as session
-
- from PIL import Image
-
- class TeethApp:
-     def __init__(self):
-         # Font
-         with open("utils/style.css") as css:
-             st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
-
-         # Logo
-         self.image_path = "utils/teeth-295404_1280.png"
-         self.image = Image.open(self.image_path)
-         width, height = self.image.size
-         scale = 12
-         new_width, new_height = width / scale, height / scale
-         self.image = self.image.resize((int(new_width), int(new_height)))
-
-         # Streamlit side navigation bar
-         st.sidebar.markdown("# AI ToothSeg")
-         st.sidebar.markdown("Automatic teeth segmentation with Deep Learning")
-         st.sidebar.markdown(" ")
-         st.sidebar.image(self.image, use_column_width=False)
-         st.markdown(
-             """
-             <style>
-             .css-1bxukto {
-             background-color: rgb(255, 255, 255) ;""",
-             unsafe_allow_html=True,
-         )
-
- # Configure Streamlit page
- st.set_page_config(page_title="Teeth Segmentation", page_icon="ⓘ")
-
-
- class Guide(TeethApp):
-     def __init__(self):
-         TeethApp.__init__(self)
-         self.build_app()
-
-     def build_app(self):
-         st.title("AI-assisted Tooth Segmentation")
-         st.markdown("This app automatically segments intra-oral scans of teeth using machine learning.")
-         st.markdown("Head to the 'Segment' tab to try it out!")
-         st.markdown("**Example:**")
-         st.image("illu.png")
-
- if __name__ == "__main__":
-     app = Guide()
 
apps/demo/requirements.txt DELETED
@@ -1,10 +0,0 @@
- streamlit==1.28.2
- pyvista==0.36.1
- pythreejs==2.4.2
- stpyvista==0.0.5
- open3d==0.15.1
- torch==1.11.0
- scikit-learn==0.23.2
- scipy==1.5.2
- matplotlib==3.3.2
- pillow==10.1.0
 
apps/demo/utils/style.css DELETED
@@ -1,10 +0,0 @@
- @import url('https://fonts.googleapis.com/css2?family=Nunito:wght@400&display=swap');
-
- html,
- body,
- [class*="css"] {
-     font-family: 'Nunito';
-     /* font-size: 16px; */
-     font-weight: 400;
-     color: #091747;
- }
 
apps/demo/utils/teeth-295404_1280.png DELETED
Binary file (149 kB)
 
apps/demo/ⓘ_Introduction.py DELETED
@@ -1,40 +0,0 @@
- import streamlit as st
- from streamlit import session_state as session
-
- from PIL import Image
-
- class TeethApp:
-     def __init__(self):
-         # Font
-         with open("utils/style.css") as css:
-             st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
-
-         # Logo
-         self.image_path = "utils/teeth-295404_1280.png"
-         self.image = Image.open(self.image_path)
-         width, height = self.image.size
-         scale = 12
-         new_width, new_height = width / scale, height / scale
-         self.image = self.image.resize((int(new_width), int(new_height)))
-
-         # Streamlit side navigation bar
-         st.sidebar.markdown("# AI ToothSeg")
-         st.sidebar.markdown("Automatic teeth segmentation with Deep Learning")
-         st.sidebar.markdown(" ")
-         st.sidebar.image(self.image, use_column_width=False)
-         st.markdown(
-             """
-             <style>
-             .css-1bxukto {
-             background-color: rgb(255, 255, 255) ;""",
-             unsafe_allow_html=True,
-         )
-
- # Configure Streamlit page
- st.set_page_config(page_title="Teeth Segmentation", page_icon="ⓘ")
-
- st.title("AI-assisted Tooth Segmentation")
- st.markdown("This app automatically segments intra-oral scans of teeth using machine learning.")
- st.markdown("Head to the 'Segment' tab to try it out!")
- st.markdown("**Example:**")
- st.image("illu.png")
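The TeethApp boilerplate (CSS injection, logo, sidebar) appears verbatim in all three deleted pages. Had the app been kept, one shared module would remove the triplication; a minimal sketch (the utils/base.py module name is hypothetical):

    # utils/base.py (hypothetical shared module)
    import streamlit as st
    from PIL import Image

    class TeethApp:
        def __init__(self):
            with open("utils/style.css") as css:
                st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
            logo = Image.open("utils/teeth-295404_1280.png")
            st.sidebar.markdown("# AI ToothSeg")
            st.sidebar.markdown("Automatic teeth segmentation with Deep Learning")
            st.sidebar.image(logo.resize((logo.width // 12, logo.height // 12)))

    # Each page would then start with:
    # from utils.base import TeethApp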