DHEIVER's picture
Update app.py
14ef73d
raw
history blame
5.63 kB
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import time
device = "cuda" if torch.cuda.is_available() else "cpu"
print("torch: ", torch.__version__)
def ordenar_arquivos(img, modelo):
ori = img.copy()
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
h, w = img.shape
img_out = preprocessamento(img, modelo)
return img_out, h, w, img, ori
def preprocessamento(img, modelo='SE-RegUNet 4GF'):
img = cv2.resize(img, (512, 512))
img = unsharp_masking(img).astype(np.uint8)
if modelo == 'AngioNet' or modelo == 'UNet3+':
img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
img_out = np.expand_dims(img, axis=0)
elif modelo == 'SE-RegUNet 4GF':
clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
image1 = clahe1.apply(img)
image2 = clahe2.apply(img)
img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
image2 = np.float32((image2 - image2.min()) / (image2.max() - image2.min() + 1e-6))
img_out = np.stack((img, image1, image2), axis=0)
else:
clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
image1 = clahe1.apply(img)
image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
img_out = np.stack((image1,) * 3, axis=0)
return img_out
def processar_imagem_de_entrada(img, modelo, pipe):
img = img.copy()
pipe = pipe.to(device).eval()
start = time.time()
img, h, w, ori_gray, ori = ordenar_arquivos(img, modelo)
img = torch.FloatTensor(img).unsqueeze(0).to(device)
with torch.no_grad():
if modelo == 'AngioNet':
img = torch.cat([img, img], dim=0)
logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
spent = time.time() - start
spent = f"{spent:.3f} segundos"
if h != 512 or w != 512:
logit = cv2.resize(logit, (h, w))
logit = logit.astype(bool)
img_out = ori.copy()
img_out[logit, 0] = 255
return spent, img_out
# Define a função detect_disease atualizada
def detect_disease(img, limiar_porcentagem_estenose=20):
# Convertendo a imagem segmentada em uma máscara binária (0 para fundo e 1 para segmentação)
binary_mask = (img > 0).astype(np.uint8)
# Calculando a área total da segmentação
total_area = np.sum(binary_mask)
# Calculando a média das intensidades dos pixels na segmentação
mean_intensity = np.mean(img)
# Definindo um limiar adaptativo proporcional à média das intensidades
limiar_adaptativo = 0.5 * mean_intensity
# Calculando a área da estenose usando o limiar adaptativo
estenose_area = np.sum(binary_mask * (img >= limiar_adaptativo))
# Calculando a porcentagem de área de estenose em relação à área total da segmentação
porcentagem_estenose = (estenose_area / total_area) * 100
# Se a porcentagem de área de estenose for maior que o limiar, consideramos como "Disease Detected"
if porcentagem_estenose > limiar_porcentagem_estenose:
disease_status = "Disease Detected"
else:
disease_status = "No Disease"
return disease_status
def processar_imagem_de_entrada_wrapper(img, modelo):
model = models[modelo]
spent, segmented_img = processar_imagem_de_entrada(img, modelo, model)
disease_status = detect_disease(segmented_img)
return spent, segmented_img, disease_status
# Load the models outside the function
models = {
'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
'AngioNet': torch.jit.load('./model/AngioNet.pt'),
'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
}
for model_name, model in models.items():
models[model_name] = model.to(device).eval()
my_app = gr.Interface(
fn=processar_imagem_de_entrada_wrapper,
inputs=[
gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
],
outputs=[
gr.outputs.Label(label="Tempo decorrido"),
gr.outputs.Image(type="numpy", label="Imagem de Saída"),
gr.outputs.Label(label="Status da Doença"),
],
title="Segmentação de Angiograma Coronariano",
description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados. Faça o upload de uma imagem de angiograma e selecione um modelo para visualizar o resultado da segmentação.\n\nSelecione uma imagem de angiograma coronariano e um modelo de segmentação no painel à esquerda.\n\nStatus da Doença:\n- 'Disease Detected': Indica que a segmentação detectou uma área significativa de estenose.\n- 'No Disease': Indica que a segmentação não detectou estenose significativa.\nCom base no resultado, é recomendado consultar um profissional de saúde para avaliação e orientação adicional, se necessário.",
theme="default",
allow_flagging="never", # O parâmetro "allow_flagging" deve receber uma string ('auto', 'manual', ou 'never')
)
my_app.launch()