import os os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import gradio as gr import torch import cv2 import numpy as np from preprocess import unsharp_masking import glob import time device = "cuda" if torch.cuda.is_available() else "cpu" print( "torch: ", torch.__version__, ) def ordenar_arquivos(img, modelo): ori = img.copy() img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) h, w = img.shape img_out = preprocessamento(img, modelo) return img_out, h, w, img, ori def preprocessamento(img, modelo='SE-RegUNet 4GF'): img = cv2.resize(img, (512, 512)) img = unsharp_masking(img).astype(np.uint8) if modelo == 'AngioNet' or modelo == 'UNet3+': img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6)) img_out = np.expand_dims(img, axis=0) elif modelo == 'SE-RegUNet 4GF': clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8)) image1 = clahe1.apply(img) image2 = clahe2.apply(img) img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6)) image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6)) image2 = np.float32((image2 - image2.min()) / (image2.max() - image2.min() + 1e-6)) img_out = np.stack((img, image1, image2), axis=0) else: clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) image1 = clahe1.apply(img) image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6)) img_out = np.stack((image1,) * 3, axis=0) return img_out def processar_imagem_de_entrada(img, modelo, pipe): img = img.copy() pipe = pipe.to(device).eval() start = time.time() img, h, w, ori_gray, ori = ordenar_arquivos(img, modelo) img = torch.FloatTensor(img).unsqueeze(0).to(device) with torch.no_grad(): if modelo == 'AngioNet': img = torch.cat([img, img], dim=0) logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8) spent = time.time() - start spent = f"{spent:.3f} segundos" if h != 512 or w != 512: logit = cv2.resize(logit, (h, w)) logit = logit.astype(bool) img_out = ori.copy() img_out[logit, 0] = 255 return spent, img_out # Load the models outside the function models = { 'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'), 'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'), 'AngioNet': torch.jit.load('./model/AngioNet.pt'), 'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'), 'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'), 'UNet3+': torch.jit.load('./model/UNet3plus.pt'), } def processar_imagem_de_entrada_wrapper(img, modelo): model = models[modelo] spent, img_out = processar_imagem_de_entrada(img, modelo, model) # Define the function `has_disease` def has_disease(img_out): """ Checks if the angiogram has disease based on the segmentation. Args: img_out: The segmented angiogram. Returns: True if the angiogram has disease, False otherwise. """ percentage_of_vessels = np.sum(img_out) / (img_out.shape[0] * img_out.shape[1]) if percentage_of_vessels > 0.5: return True else: return False has_disease = has_disease(img_out) loc_doenca = np.where(img_out == 1) if has_disease: print("A doença está localizada nas seguintes coordenadas:") print(loc_doenca) else: print("Não há doença no angiograma.") return spent, img_out, has_disease, loc_doenca my_app = gr.Interface( fn=processar_imagem_de_entrada_wrapper, inputs=[ gr.inputs.Image(label="Angiograma:", shape=(512, 512)), gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'), ], outputs=[ gr.outputs.Label(label="Tempo decorrido"), gr.outputs.Image(type="numpy", label="Imagem de Saída"), gr.outputs.Label(label="Possui Doença?"), ], title="Segmentação de Angiograma Coronariano", description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados. Faça o upload de uma imagem de angiograma e selecione um modelo para visualizar o resultado da segmentação.\n\nSelecione uma imagem de angiograma coronariano e um modelo de segmentação no painel à esquerda.", theme="default", layout="vertical", allow_flagging=False, ) my_app.launch()