DHEIVER committed
Commit 5517146
1 parent: 5dc06ac

Update app.py

Files changed (1)
app.py +17 -132
app.py CHANGED
@@ -1,138 +1,23 @@
-import os
 import gradio as gr
-import torch
-import cv2
-import numpy as np
-from preprocess import unsharp_masking
-import time
-from sklearn.cluster import KMeans
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-def preprocess_image(img, model='SE-RegUNet 4GF'):
-    img = cv2.resize(img, (512, 512))
-    img = unsharp_masking(img).astype(np.uint8)
-
-    def normalize_image(img):
-        return np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
-
-    if model in ('AngioNet', 'UNet3+'):
-        img = normalize_image(img)
-        img_out = np.expand_dims(img, axis=0)
-    elif model == 'SE-RegUNet 4GF':
-        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
-        clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
-        image1 = clahe1.apply(img)
-        image2 = clahe2.apply(img)
-        img = normalize_image(img)
-        image1 = normalize_image(image1)
-        image2 = normalize_image(image2)
-        img_out = np.stack((img, image1, image2), axis=0)
-    else:
-        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
-        image1 = clahe1.apply(img)
-        image1 = normalize_image(image1)
-        img_out = np.stack((image1,) * 3, axis=0)
-
-    return img_out
-
-def process_input_image(img, model, save_result=False):
-    try:
-        img = img.copy()
-        pipe = models[model].to(device).eval()
-        start = time.time()
-        img, h, w, ori_gray, ori = preprocess_image(img, model)
-        img = torch.FloatTensor(img).unsqueeze(0).to(device)
-
-        with torch.no_grad():
-            if model == 'AngioNet':
-                img = torch.cat([img, img], dim=0)
-            logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
-
-        spent = f"{time.time() - start:.3f} segundos"
-
-        if h != 512 or w != 512:
-            logit = cv2.resize(logit, (h, w))
-
-        logit = logit.astype(bool)
-
-        img_out = ori.copy()
-        img_out[logit, 0] = 255
-
-        if save_result:
-            file_name = os.path.join(caminho_salvar_resultado, f'resultado_{int(time.time())}.png')
-            cv2.imwrite(file_name, img_out)
-
-        return spent, img_out
-
-    except Exception as e:
-        return str(e), None
-
-models = {
-    'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
-    'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
-    'AngioNet': torch.jit.load('./model/AngioNet.pt'),
-    'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
-    'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
-    'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
-}
-
-def process_input_image_wrapper(img, model, save_result=False):
-    resultado, img_out = process_input_image(img, model, save_result)
-
-    kmeans = KMeans(n_clusters=2, random_state=0)
-    flattened_img = img_out[:, :, 0].reshape((-1, 1))
-    kmeans.fit(flattened_img)
-    labels = kmeans.labels_
-    cluster_centers = kmeans.cluster_centers_
-
-    num_clusters = len(cluster_centers)
-    cluster_features = []
-
-    for i in range(num_clusters):
-        cluster_mask = labels == i
-        area = np.sum(cluster_mask)
-
-        if area == 0:
-            continue
-
-        contours, _ = cv2.findContours(np.uint8(cluster_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-        if len(contours) > 0:
-            perimeter = cv2.arcLength(contours[0], True)
-            compactness = 4 * np.pi * area / (perimeter ** 2)
-
-            cluster_features.append({'area': area, 'compactness': compactness})
-
-    has_disease_flag = any(feature['area'] >= 200 and feature['compactness'] < 0.3 for feature in cluster_features)
-
-    status_doenca = "Sim" if has_disease_flag else "Não"
-    explanation = "A máquina detectou uma possível doença nos vasos sanguíneos." if has_disease_flag else "A máquina não detectou nenhuma doença nos vasos sanguíneos."
-
-    return resultado, img_out, status_doenca, explanation, f"{num_analises} análises realizadas"
-
-caminho_salvar_resultado = "/Segmento_de_Angio_Coronariana_v5/Salvar Resultado"
-num_analises = 0
-
-my_app = gr.Interface(
-    fn=process_input_image_wrapper,
-    inputs=[
-        gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
-        gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
-        gr.inputs.Checkbox(label="Salvar Resultado"),
-    ],
-    outputs=[
-        gr.outputs.Label(label="Tempo decorrido"),
-        gr.outputs.Image(type="numpy", label="Imagem de Saída"),
-        gr.outputs.Label(label="Possui Doença?"),
-        gr.outputs.Label(label="Explicação"),
-        gr.outputs.Label(label="Análises Realizadas"),
-    ],
-    title="Segmentação de Angiograma Coronariano",
-    description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados.",
-    theme="default",
-    layout="vertical",
-    allow_flagging=False,
-)
-
-my_app.launch()
+from PIL import Image
+
+# Import the ObstructionDetector class from your module
+from obstruction_detector import ObstructionDetector
+
+# Create an instance of ObstructionDetector
+detector = ObstructionDetector()
+
+# Define a Gradio function to process the image and return the report
+def process_image(image):
+    # Call the detect_obstruction method of the ObstructionDetector with the PIL image
+    report = detector.detect_obstruction(image)
+
+    return report
+
+# Define the Gradio interface
+iface = gr.Interface(fn=process_image,
+                     inputs=gr.inputs.Image(shape=(224, 224)),  # Adjust shape as needed
+                     outputs="text")
+
+# Launch the Gradio interface
+iface.launch()
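
Note: the new app.py imports an obstruction_detector module that is not part of this commit. The sketch below is hypothetical; it shows only the interface the app assumes (an ObstructionDetector class whose detect_obstruction method takes a PIL image and returns a text report), not the real implementation.

from PIL import Image

class ObstructionDetector:
    """Hypothetical stand-in for the class app.py imports; the real module is not shown here."""

    def detect_obstruction(self, image: Image.Image) -> str:
        # A real implementation would run a detection model at this point;
        # this placeholder only echoes the image dimensions.
        width, height = image.size
        return f"Received a {width}x{height} image; no analysis performed in this sketch."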
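Note also that gr.inputs.Image(shape=...) is the legacy Gradio input API: the gr.inputs namespace was deprecated in Gradio 3 and removed in Gradio 4. On a current Gradio version, the interface would need to be written roughly as follows (a sketch, assuming the detector expects a PIL image; any resizing to 224x224 would then have to happen inside process_image):

import gradio as gr

# Equivalent interface on the current API: type="pil" hands process_image a PIL
# image, matching the comment in the commit; there is no shape argument anymore.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
iface.launch()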