DHEIVER commited on
Commit
5dc06ac
1 Parent(s): 6423efc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -114
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3
-
4
  import gradio as gr
5
  import torch
6
  import cv2
@@ -8,109 +6,68 @@ import numpy as np
8
  from preprocess import unsharp_masking
9
  import time
10
  from sklearn.cluster import KMeans
11
- import os
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
- # Função para ordenar e pré-processar a imagem de entrada
16
- def ordenar_e_preprocessar_imagem(img, modelo):
17
- ori = img.copy()
18
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
19
- h, w = img.shape
20
- img_out = preprocessar_imagem(img, modelo)
21
- return img_out, h, w, img, ori
22
-
23
- # Função para pré-processar a imagem com base no modelo selecionado
24
- def preprocessar_imagem(img, modelo='SE-RegUNet 4GF'):
25
- # Redimensionar a imagem para 512x512
26
  img = cv2.resize(img, (512, 512))
27
-
28
- # Aplicar a máscara de nitidez à imagem
29
  img = unsharp_masking(img).astype(np.uint8)
30
-
31
- # Função auxiliar para normalizar a imagem
32
- def normalizar_imagem(img):
33
  return np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
34
-
35
- if modelo == 'AngioNet' or modelo == 'UNet3+':
36
- img = normalizar_imagem(img)
37
  img_out = np.expand_dims(img, axis=0)
38
- elif modelo == 'SE-RegUNet 4GF':
39
  clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
40
  clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
41
  image1 = clahe1.apply(img)
42
  image2 = clahe2.apply(img)
43
- img = normalizar_imagem(img)
44
- image1 = normalizar_imagem(image1)
45
- image2 = normalizar_imagem(image2)
46
  img_out = np.stack((img, image1, image2), axis=0)
47
  else:
48
  clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
49
  image1 = clahe1.apply(img)
50
- image1 = normalizar_imagem(image1)
51
  img_out = np.stack((image1,) * 3, axis=0)
52
-
53
- return img_out
54
 
 
55
 
56
- import os
57
-
58
- # Caminho absoluto para a pasta de salvamento
59
- caminho_salvar_resultado = "/Segmento_de_Angio_Coronariana_v5/Salvar Resultado"
60
-
61
- # Função para processar a imagem de entrada
62
- def processar_imagem_de_entrada(img, modelo, pipe, salvar_resultado=False):
63
  try:
64
- # Faça uma cópia da imagem original
65
  img = img.copy()
66
-
67
- # Coloque o modelo na GPU (se disponível) e configure-o para modo de avaliação
68
- pipe = pipe.to(device).eval()
69
-
70
- # Registre o tempo de início
71
  start = time.time()
72
-
73
- # Pré-processe a imagem e obtenha informações de dimensão
74
- img, h, w, ori_gray, ori = ordenar_e_preprocessar_imagem(img, modelo)
75
-
76
- # Converta a imagem para o formato esperado pelo modelo e coloque-a na GPU
77
  img = torch.FloatTensor(img).unsqueeze(0).to(device)
78
-
79
- # Realize a inferência do modelo sem gradientes
80
  with torch.no_grad():
81
- if modelo == 'AngioNet':
82
  img = torch.cat([img, img], dim=0)
83
  logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
84
-
85
- # Calcule o tempo decorrido
86
- spent = time.time() - start
87
- spent = f"{spent:.3f} segundos"
88
 
89
- # Redimensione o resultado, se necessário
 
90
  if h != 512 or w != 512:
91
  logit = cv2.resize(logit, (h, w))
92
 
93
- # Converta o resultado para um formato booleano
94
  logit = logit.astype(bool)
95
-
96
- # Crie uma cópia da imagem original para saída e aplique a máscara
97
  img_out = ori.copy()
98
  img_out[logit, 0] = 255
99
 
100
- # Salve o resultado em um arquivo se a opção estiver ativada
101
- if salvar_resultado:
102
- nome_arquivo = os.path.join(caminho_salvar_resultado, f'resultado_{int(time.time())}.png')
103
- cv2.imwrite(nome_arquivo, img_out)
104
-
105
  return spent, img_out
106
-
107
  except Exception as e:
108
- # Em caso de erro, retorne uma mensagem de erro
109
  return str(e), None
110
 
111
-
112
-
113
- # Carregar modelos pré-treinados
114
  models = {
115
  'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
116
  'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
@@ -120,72 +77,45 @@ models = {
120
  'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
121
  }
122
 
123
- from scipy.spatial import distance
124
- from scipy.ndimage import label
125
- import numpy as np
126
 
127
- # Adicionar a opção de salvar o resultado em um arquivo
128
- def processar_imagem_de_entrada_wrapper(img, modelo, salvar_resultado=False):
129
- model = models[modelo]
130
- resultado, img_out = processar_imagem_de_entrada(img, modelo, model, salvar_resultado)
131
-
132
- # Resto do código permanece inalterado
133
  kmeans = KMeans(n_clusters=2, random_state=0)
134
  flattened_img = img_out[:, :, 0].reshape((-1, 1))
135
  kmeans.fit(flattened_img)
136
  labels = kmeans.labels_
137
  cluster_centers = kmeans.cluster_centers_
138
-
139
- # Detecção de doenças usando K-Means
140
- kmeans = KMeans(n_clusters=2, random_state=0)
141
- flattened_img = img_out[:, :, 0].reshape((-1, 1)) # Use o canal de intensidade
142
- kmeans.fit(flattened_img)
143
- labels = kmeans.labels_
144
- cluster_centers = kmeans.cluster_centers_
145
-
146
- # Resto do código permanece inalterado
147
-
148
- # Extração de características dos clusters
149
  num_clusters = len(cluster_centers)
150
  cluster_features = []
 
151
  for i in range(num_clusters):
152
- cluster_mask = labels == i # Create a boolean mask for the cluster
153
-
154
- # Calcular área do cluster
155
  area = np.sum(cluster_mask)
156
-
157
- if area == 0: # Skip empty clusters
158
  continue
159
-
160
- # Calcular forma do cluster usando a relação entre área e perímetro
161
  contours, _ = cv2.findContours(np.uint8(cluster_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
162
  if len(contours) > 0:
163
  perimeter = cv2.arcLength(contours[0], True)
164
  compactness = 4 * np.pi * area / (perimeter ** 2)
165
-
166
  cluster_features.append({'area': area, 'compactness': compactness})
167
-
168
- # Decidir se há doença com base nas características dos clusters
169
  has_disease_flag = any(feature['area'] >= 200 and feature['compactness'] < 0.3 for feature in cluster_features)
170
-
171
- # Formatar o indicador de doença como uma string
172
- if has_disease_flag:
173
- status_doenca = "Sim"
174
- explanation = "A máquina detectou uma possível doença nos vasos sanguíneos."
175
- else:
176
- status_doenca = "Não"
177
- explanation = "A máquina não detectou nenhuma doença nos vasos sanguíneos."
178
-
179
- # Resto do código permanece inalterado
180
-
181
  return resultado, img_out, status_doenca, explanation, f"{num_analises} análises realizadas"
182
 
183
- # Inicializar a contagem de análises
184
  num_analises = 0
185
 
186
- # Criar a interface Gradio
187
  my_app = gr.Interface(
188
- fn=processar_imagem_de_entrada_wrapper,
189
  inputs=[
190
  gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
191
  gr.inputs.Dropdown(['SE-RegUNet 4GF','SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
@@ -205,5 +135,4 @@ my_app = gr.Interface(
205
  allow_flagging=False,
206
  )
207
 
208
- # Iniciar a interface Gradio
209
- my_app.launch()
 
1
  import os
 
 
2
  import gradio as gr
3
  import torch
4
  import cv2
 
6
  from preprocess import unsharp_masking
7
  import time
8
  from sklearn.cluster import KMeans
 
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ def preprocess_image(img, model='SE-RegUNet 4GF'):
 
 
 
 
 
 
 
 
 
 
13
  img = cv2.resize(img, (512, 512))
 
 
14
  img = unsharp_masking(img).astype(np.uint8)
15
+
16
+ def normalize_image(img):
 
17
  return np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
18
+
19
+ if model in ('AngioNet', 'UNet3+'):
20
+ img = normalize_image(img)
21
  img_out = np.expand_dims(img, axis=0)
22
+ elif model == 'SE-RegUNet 4GF':
23
  clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
24
  clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
25
  image1 = clahe1.apply(img)
26
  image2 = clahe2.apply(img)
27
+ img = normalize_image(img)
28
+ image1 = normalize_image(image1)
29
+ image2 = normalize_image(image2)
30
  img_out = np.stack((img, image1, image2), axis=0)
31
  else:
32
  clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
33
  image1 = clahe1.apply(img)
34
+ image1 = normalize_image(image1)
35
  img_out = np.stack((image1,) * 3, axis=0)
 
 
36
 
37
+ return img_out
38
 
39
+ def process_input_image(img, model, save_result=False):
 
 
 
 
 
 
40
  try:
 
41
  img = img.copy()
42
+ pipe = models[model].to(device).eval()
 
 
 
 
43
  start = time.time()
44
+ img, h, w, ori_gray, ori = preprocess_image(img, model)
 
 
 
 
45
  img = torch.FloatTensor(img).unsqueeze(0).to(device)
46
+
 
47
  with torch.no_grad():
48
+ if model == 'AngioNet':
49
  img = torch.cat([img, img], dim=0)
50
  logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
 
 
 
 
51
 
52
+ spent = f"{time.time() - start:.3f} segundos"
53
+
54
  if h != 512 or w != 512:
55
  logit = cv2.resize(logit, (h, w))
56
 
 
57
  logit = logit.astype(bool)
58
+
 
59
  img_out = ori.copy()
60
  img_out[logit, 0] = 255
61
 
62
+ if save_result:
63
+ file_name = os.path.join(caminho_salvar_resultado, f'resultado_{int(time.time())}.png')
64
+ cv2.imwrite(file_name, img_out)
65
+
 
66
  return spent, img_out
67
+
68
  except Exception as e:
 
69
  return str(e), None
70
 
 
 
 
71
  models = {
72
  'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
73
  'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
 
77
  'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
78
  }
79
 
80
+ def process_input_image_wrapper(img, model, save_result=False):
81
+ resultado, img_out = process_input_image(img, model, save_result)
 
82
 
 
 
 
 
 
 
83
  kmeans = KMeans(n_clusters=2, random_state=0)
84
  flattened_img = img_out[:, :, 0].reshape((-1, 1))
85
  kmeans.fit(flattened_img)
86
  labels = kmeans.labels_
87
  cluster_centers = kmeans.cluster_centers_
88
+
 
 
 
 
 
 
 
 
 
 
89
  num_clusters = len(cluster_centers)
90
  cluster_features = []
91
+
92
  for i in range(num_clusters):
93
+ cluster_mask = labels == i
 
 
94
  area = np.sum(cluster_mask)
95
+
96
+ if area == 0:
97
  continue
98
+
 
99
  contours, _ = cv2.findContours(np.uint8(cluster_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
100
+
101
  if len(contours) > 0:
102
  perimeter = cv2.arcLength(contours[0], True)
103
  compactness = 4 * np.pi * area / (perimeter ** 2)
104
+
105
  cluster_features.append({'area': area, 'compactness': compactness})
106
+
 
107
  has_disease_flag = any(feature['area'] >= 200 and feature['compactness'] < 0.3 for feature in cluster_features)
108
+
109
+ status_doenca = "Sim" if has_disease_flag else "Não"
110
+ explanation = "A máquina detectou uma possível doença nos vasos sanguíneos." if has_disease_flag else "A máquina não detectou nenhuma doença nos vasos sanguíneos."
111
+
 
 
 
 
 
 
 
112
  return resultado, img_out, status_doenca, explanation, f"{num_analises} análises realizadas"
113
 
114
+ caminho_salvar_resultado = "/Segmento_de_Angio_Coronariana_v5/Salvar Resultado"
115
  num_analises = 0
116
 
 
117
  my_app = gr.Interface(
118
+ fn=process_input_image_wrapper,
119
  inputs=[
120
  gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
121
  gr.inputs.Dropdown(['SE-RegUNet 4GF','SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
 
135
  allow_flagging=False,
136
  )
137
 
138
+ my_app.launch()