File size: 5,893 Bytes
283999e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff8277d
 
 
 
283999e
 
 
 
 
 
 
 
 
ff8277d
 
 
 
 
 
af135e8
ff8277d
 
af135e8
 
 
 
ff8277d
 
dab2261
 
 
 
 
 
ff8277d
 
af135e8
283999e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff8277d
dab2261
283999e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import time
from sklearn.cluster import KMeans

device = "cuda" if torch.cuda.is_available() else "cpu"

# Função para ordenar e pré-processar a imagem de entrada
def ordenar_arquivos(img, modelo):
    ori = img.copy()
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    h, w = img.shape
    img_out = preprocessamento(img, modelo)
    return img_out, h, w, img, ori

# Função para pré-processar a imagem com base no modelo selecionado
def preprocessamento(img, modelo='SE-RegUNet 4GF'):
    img = cv2.resize(img, (512, 512))
    img = unsharp_masking(img).astype(np.uint8)
    if modelo == 'AngioNet' or modelo == 'UNet3+':
        img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
        img_out = np.expand_dims(img, axis=0)
    elif modelo == 'SE-RegUNet 4GF':
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
        image1 = clahe1.apply(img)
        image2 = clahe2.apply(img)
        img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
        image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
        image2 = np.float32((image2 - image2.min()) / (image2.max() - image2.min() + 1e-6))
        img_out = np.stack((img, image1, image2), axis=0)
    else:
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        image1 = clahe1.apply(img)
        image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
        img_out = np.stack((image1,) * 3, axis=0)
    return img_out

# Função para processar a imagem de entrada
def processar_imagem_de_entrada(img, modelo, pipe):
    img = img.copy()
    pipe = pipe.to(device).eval()
    start = time.time()
    img, h, w, ori_gray, ori = ordenar_arquivos(img, modelo)
    img = torch.FloatTensor(img).unsqueeze(0).to(device)
    with torch.no_grad():
        if modelo == 'AngioNet':
            img = torch.cat([img, img], dim=0)
        logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
    spent = time.time() - start
    spent = f"{spent:.3f} segundos"

    if h != 512 or w != 512:
        logit = cv2.resize(logit, (h, w))

    logit = logit.astype(bool)
    img_out = ori.copy()
    img_out[logit, 0] = 255
    return spent, img_out

# Carregar modelos pré-treinados
models = {
    'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
    'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
    'AngioNet': torch.jit.load('./model/AngioNet.pt'),
    'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
    'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
    'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
}

from scipy.spatial import distance
from scipy.ndimage import label
import numpy as np

def processar_imagem_de_entrada_wrapper(img, modelo):
    model = models[modelo]
    spent, img_out = processar_imagem_de_entrada(img, modelo, model)
    
    # Verificar se há doença usando K-Means
    kmeans = KMeans(n_clusters=2, random_state=0)
    flattened_img = img_out[:, :, 0].reshape((-1, 1))  # Use the intensity channel
    kmeans.fit(flattened_img)
    labels = kmeans.labels_
    cluster_centers = kmeans.cluster_centers_
    
    # Extração de características dos clusters
    num_clusters = len(cluster_centers)
    cluster_features = []
    for i in range(num_clusters):
        cluster_mask = labels == i  # Create a boolean mask for the cluster
        
        # Calcular área do cluster
        area = np.sum(cluster_mask)
        
        if area == 0:  # Skip empty clusters
            continue
        
        # Calcular forma do cluster usando a relação entre área e perímetro
        contours, _ = cv2.findContours(np.uint8(cluster_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) > 0:
            perimeter = cv2.arcLength(contours[0], True)
            compactness = 4 * np.pi * area / (perimeter ** 2)
            
            cluster_features.append({'area': area, 'compactness': compactness})
    
    # Decidir se há doença com base nas características dos clusters
    has_disease_flag = any(feature['area'] >= 200 and feature['compactness'] < 0.3 for feature in cluster_features)
    
    # Formatar o indicador de doença como uma string
    if has_disease_flag:
        status_doenca = "Sim"
    else:
        status_doenca = "Não"
    
    # Adicionar a explicação com base no status de doença
    if has_disease_flag:
        explanation = "A máquina detectou uma possível doença nos vasos sanguíneos."
    else:
        explanation = "A máquina não detectou nenhuma doença nos vasos sanguíneos."
    
    # ... (resto do seu código, se houver mais)
    
    return spent, img_out, status_doenca, explanation



# Criar a interface Gradio
my_app = gr.Interface(
    fn=processar_imagem_de_entrada_wrapper,
    inputs=[
        gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
        gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
    ],
    outputs=[
        gr.outputs.Label(label="Tempo decorrido"),
        gr.outputs.Image(type="numpy", label="Imagem de Saída"),
        gr.outputs.Label(label="Possui Doença?"),
        gr.outputs.Label(label="Explicação"),
    ],
    title="Segmentação de Angiograma Coronariano",
    description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados.",
    theme="default",
    layout="vertical",
    allow_flagging=False,
)

# Iniciar a interface Gradio
my_app.launch()