DHEIVER's picture
Update app.py
49b9d12
raw
history blame
No virus
7.57 kB
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import time
from sklearn.cluster import KMeans
import os
device = "cuda" if torch.cuda.is_available() else "cpu"
# Função para ordenar e pré-processar a imagem de entrada
def ordenar_e_preprocessar_imagem(img, modelo):
ori = img.copy()
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
h, w = img.shape
img_out = preprocessar_imagem(img, modelo)
return img_out, h, w, img, ori
# Função para pré-processar a imagem com base no modelo selecionado
def preprocessar_imagem(img, modelo='SE-RegUNet 4GF'):
# Redimensionar a imagem para 512x512
img = cv2.resize(img, (512, 512))
# Aplicar a máscara de nitidez à imagem
img = unsharp_masking(img).astype(np.uint8)
# Função auxiliar para normalizar a imagem
def normalizar_imagem(img):
return np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
if modelo == 'AngioNet' or modelo == 'UNet3+':
img = normalizar_imagem(img)
img_out = np.expand_dims(img, axis=0)
elif modelo == 'SE-RegUNet 4GF':
clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
image1 = clahe1.apply(img)
image2 = clahe2.apply(img)
img = normalizar_imagem(img)
image1 = normalizar_imagem(image1)
image2 = normalizar_imagem(image2)
img_out = np.stack((img, image1, image2), axis=0)
else:
clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
image1 = clahe1.apply(img)
image1 = normalizar_imagem(image1)
img_out = np.stack((image1,) * 3, axis=0)
return img_out
# Função para processar a imagem de entrada
def processar_imagem_de_entrada(img, modelo, pipe, salvar_resultado=False):
try:
# Faça uma cópia da imagem original
img = img.copy()
# Coloque o modelo na GPU (se disponível) e configure-o para modo de avaliação
pipe = pipe.to(device).eval()
# Registre o tempo de início
start = time.time()
# Pré-processe a imagem e obtenha informações de dimensão
img, h, w, ori_gray, ori = ordenar_e_preprocessar_imagem(img, modelo)
# Converta a imagem para o formato esperado pelo modelo e coloque-a na GPU
img = torch.FloatTensor(img).unsqueeze(0).to(device)
# Realize a inferência do modelo sem gradientes
with torch.no_grad():
if modelo == 'AngioNet':
img = torch.cat([img, img], dim=0)
logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
# Calcule o tempo decorrido
spent = time.time() - start
spent = f"{spent:.3f} segundos"
# Redimensione o resultado, se necessário
if h != 512 or w != 512:
logit = cv2.resize(logit, (h, w))
# Converta o resultado para um formato booleano
logit = logit.astype(bool)
# Crie uma cópia da imagem original para saída e aplique a máscara
img_out = ori.copy()
img_out[logit, 0] = 255
# Salve o resultado em um arquivo se a opção estiver ativada
if salvar_resultado:
nome_arquivo = f'resultado_{int(time.time())}.png'
cv2.imwrite(nome_arquivo, img_out)
return spent, img_out
except Exception as e:
# Em caso de erro, retorne uma mensagem de erro
return str(e), None
# Carregar modelos pré-treinados
models = {
'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
'AngioNet': torch.jit.load('./model/AngioNet.pt'),
'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
}
from scipy.spatial import distance
from scipy.ndimage import label
import numpy as np
# Adicionar a opção de salvar o resultado em um arquivo
def processar_imagem_de_entrada_wrapper(img, modelo, salvar_resultado=False):
model = models[modelo]
resultado, img_out = processar_imagem_de_entrada(img, modelo, model, salvar_resultado)
# Resto do código permanece inalterado
kmeans = KMeans(n_clusters=2, random_state=0)
flattened_img = img_out[:, :, 0].reshape((-1, 1))
kmeans.fit(flattened_img)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_
# Detecção de doenças usando K-Means
kmeans = KMeans(n_clusters=2, random_state=0)
flattened_img = img_out[:, :, 0].reshape((-1, 1)) # Use o canal de intensidade
kmeans.fit(flattened_img)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_
# Resto do código permanece inalterado
# Extração de características dos clusters
num_clusters = len(cluster_centers)
cluster_features = []
for i in range(num_clusters):
cluster_mask = labels == i # Create a boolean mask for the cluster
# Calcular área do cluster
area = np.sum(cluster_mask)
if area == 0: # Skip empty clusters
continue
# Calcular forma do cluster usando a relação entre área e perímetro
contours, _ = cv2.findContours(np.uint8(cluster_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if len(contours) > 0:
perimeter = cv2.arcLength(contours[0], True)
compactness = 4 * np.pi * area / (perimeter ** 2)
cluster_features.append({'area': area, 'compactness': compactness})
# Decidir se há doença com base nas características dos clusters
has_disease_flag = any(feature['area'] >= 200 and feature['compactness'] < 0.3 for feature in cluster_features)
# Formatar o indicador de doença como uma string
if has_disease_flag:
status_doenca = "Sim"
explanation = "A máquina detectou uma possível doença nos vasos sanguíneos."
else:
status_doenca = "Não"
explanation = "A máquina não detectou nenhuma doença nos vasos sanguíneos."
# Resto do código permanece inalterado
return resultado, img_out, status_doenca, explanation, f"{num_analises} análises realizadas"
# Inicializar a contagem de análises
num_analises = 0
# Criar a interface Gradio
my_app = gr.Interface(
fn=processar_imagem_de_entrada_wrapper,
inputs=[
gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
gr.inputs.Dropdown(['SE-RegUNet 4GF','SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
gr.inputs.Checkbox(label="Salvar Resultado"),
],
outputs=[
gr.outputs.Label(label="Tempo decorrido"),
gr.outputs.Image(type="numpy", label="Imagem de Saída"),
gr.outputs.Label(label="Possui Doença?"),
gr.outputs.Label(label="Explicação"),
gr.outputs.Label(label="Análises Realizadas"),
],
title="Segmentação de Angiograma Coronariano",
description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados.",
theme="default",
layout="vertical",
allow_flagging=False,
)
# Iniciar a interface Gradio
my_app.launch()