File size: 5,893 Bytes
283999e ff8277d 283999e ff8277d af135e8 ff8277d af135e8 ff8277d dab2261 ff8277d af135e8 283999e ff8277d dab2261 283999e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import time
from sklearn.cluster import KMeans
device = "cuda" if torch.cuda.is_available() else "cpu"
# Função para ordenar e pré-processar a imagem de entrada
def ordenar_arquivos(img, modelo):
ori = img.copy()
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
h, w = img.shape
img_out = preprocessamento(img, modelo)
return img_out, h, w, img, ori
# Função para pré-processar a imagem com base no modelo selecionado
def preprocessamento(img, modelo='SE-RegUNet 4GF'):
img = cv2.resize(img, (512, 512))
img = unsharp_masking(img).astype(np.uint8)
if modelo == 'AngioNet' or modelo == 'UNet3+':
img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
img_out = np.expand_dims(img, axis=0)
elif modelo == 'SE-RegUNet 4GF':
clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
image1 = clahe1.apply(img)
image2 = clahe2.apply(img)
img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
image2 = np.float32((image2 - image2.min()) / (image2.max() - image2.min() + 1e-6))
img_out = np.stack((img, image1, image2), axis=0)
else:
clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
image1 = clahe1.apply(img)
image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
img_out = np.stack((image1,) * 3, axis=0)
return img_out
# Função para processar a imagem de entrada
def processar_imagem_de_entrada(img, modelo, pipe):
img = img.copy()
pipe = pipe.to(device).eval()
start = time.time()
img, h, w, ori_gray, ori = ordenar_arquivos(img, modelo)
img = torch.FloatTensor(img).unsqueeze(0).to(device)
with torch.no_grad():
if modelo == 'AngioNet':
img = torch.cat([img, img], dim=0)
logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
spent = time.time() - start
spent = f"{spent:.3f} segundos"
if h != 512 or w != 512:
logit = cv2.resize(logit, (h, w))
logit = logit.astype(bool)
img_out = ori.copy()
img_out[logit, 0] = 255
return spent, img_out
# Carregar modelos pré-treinados
models = {
'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
'AngioNet': torch.jit.load('./model/AngioNet.pt'),
'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
}
from scipy.spatial import distance
from scipy.ndimage import label
import numpy as np
def processar_imagem_de_entrada_wrapper(img, modelo):
model = models[modelo]
spent, img_out = processar_imagem_de_entrada(img, modelo, model)
# Verificar se há doença usando K-Means
kmeans = KMeans(n_clusters=2, random_state=0)
flattened_img = img_out[:, :, 0].reshape((-1, 1)) # Use the intensity channel
kmeans.fit(flattened_img)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_
# Extração de características dos clusters
num_clusters = len(cluster_centers)
cluster_features = []
for i in range(num_clusters):
cluster_mask = labels == i # Create a boolean mask for the cluster
# Calcular área do cluster
area = np.sum(cluster_mask)
if area == 0: # Skip empty clusters
continue
# Calcular forma do cluster usando a relação entre área e perímetro
contours, _ = cv2.findContours(np.uint8(cluster_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if len(contours) > 0:
perimeter = cv2.arcLength(contours[0], True)
compactness = 4 * np.pi * area / (perimeter ** 2)
cluster_features.append({'area': area, 'compactness': compactness})
# Decidir se há doença com base nas características dos clusters
has_disease_flag = any(feature['area'] >= 200 and feature['compactness'] < 0.3 for feature in cluster_features)
# Formatar o indicador de doença como uma string
if has_disease_flag:
status_doenca = "Sim"
else:
status_doenca = "Não"
# Adicionar a explicação com base no status de doença
if has_disease_flag:
explanation = "A máquina detectou uma possível doença nos vasos sanguíneos."
else:
explanation = "A máquina não detectou nenhuma doença nos vasos sanguíneos."
# ... (resto do seu código, se houver mais)
return spent, img_out, status_doenca, explanation
# Criar a interface Gradio
my_app = gr.Interface(
fn=processar_imagem_de_entrada_wrapper,
inputs=[
gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
],
outputs=[
gr.outputs.Label(label="Tempo decorrido"),
gr.outputs.Image(type="numpy", label="Imagem de Saída"),
gr.outputs.Label(label="Possui Doença?"),
gr.outputs.Label(label="Explicação"),
],
title="Segmentação de Angiograma Coronariano",
description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados.",
theme="default",
layout="vertical",
allow_flagging=False,
)
# Iniciar a interface Gradio
my_app.launch()
|