import torch import torchvision from PIL import Image import numpy as np from skimage.morphology import erosion from dependecies.segroot.model import SegRoot from dependecies.segroot.dataloader import pad_pair_256, normalize from torchvision.transforms import v2 as transforms import onnxruntime as ort import cv2 as cv import os MODELS_PATH = r"./models" def pad_256(img_path): image = Image.open(img_path) W, H = image.size img, _ = pad_pair_256(image, image) NW, NH = img.size img = torchvision.transforms.ToTensor()(img) img = normalize(img) return img, (H, W, NH, NW) def pad_256_np(np_img): #image = Image.open(img_path) image = Image.fromarray(np_img) W, H = image.size img, _ = pad_pair_256(image, image) NW, NH = img.size img = torchvision.transforms.ToTensor()(img) img = normalize(img) return img, (H, W, NH, NW) def merge_images(files, path=""): is_array = False if type(files[0]) == np.ndarray: is_array = True final_img = [] resize_factor = 0.4 offset0 = 930 offset1 = 305 for index, file in enumerate(files): if is_array: img = file else: img = cv.imread(file) img = cv.cvtColor(img, cv.COLOR_BGR2RGB) #img = cv.resize(img, (0,0), fx=resize_factor, fy=resize_factor) img = cv.rotate(img, cv.ROTATE_90_CLOCKWISE) if index == 0: img = img[0:img.shape[0]-offset0,0:img.shape[1]] final_img = img elif index == len(file)-1: final_img = cv.vconcat([final_img, img]) else: #final_img = np.concatenate((final_img, img), axis=1) img = img[0:img.shape[0]-offset1,0:img.shape[1]] final_img = cv.vconcat([final_img, img]) final_img = cv.resize(final_img, (0,0), fx=resize_factor, fy=resize_factor) #cv.imwrite(path, final_img) print(final_img.shape) return final_img class RootSegmentor(): def __init__(self, model_type): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model_type = model_type if model_type != "seg_model": self.initialize() return def initialize(self): width = 8 depth = 5 weights_path = os.path.join(MODELS_PATH, r"best_segnet-(8,5)-0.6441.pt") if self.model_type == "segroot": #weights_path = os.path.join(r"D:\local_mydev\roots_finetuning\SegRoot0\weights\best_segnet-(8,5)-0.6441.pt" #weights_path = r"D:\local_mydev\SegRoot\weights\best_segnet-(8,5)-0.6441.pt" #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\best_segnet-(8,5)-0.6441.pt" #weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\best_segnet-(8,5)-0.6441.pt") weights_path = os.path.join(MODELS_PATH, r"best_segnet-(8,5)-0.6441.pt") elif self.model_type == "segroot_finetuned": #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned.pt" #weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned.pt") weights_path = os.path.join(MODELS_PATH, r"segroot-(8,5)_finetuned.pt") elif self.model_type == "segroot_finetuned_dec": #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_dec_full.pt" #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_clas.pt" #weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_clas.pt") weights_path = os.path.join(MODELS_PATH, r"segroot-(8,5)_finetuned.pt") self.model = SegRoot(width, depth).to(self.device) if self.device.type == "cpu": print("load weights to cpu") #print(weights_path.as_posix()) self.model.load_state_dict(torch.load(weights_path, map_location="cpu")) else: print("load weights to gpu") #print(weights_path.as_posix()) self.model.load_state_dict(torch.load(weights_path)) for p in self.model.parameters(): p.requires_grad = False self.model.eval() return def predict(self, img_path): if self.model_type == "seg_model": print(str(type(img_path))) if type(img_path) == np.ndarray: img = img_path else: img = cv.imread(img_path) img = cv.cvtColor(img, cv.COLOR_BGR2RGB) weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\roots_model.onnx" weights_path = os.path.join(MODELS_PATH,"roots_model.onnx") ort_sess = ort.InferenceSession(weights_path ,providers=ort.get_available_providers() ) dim = img.shape transforms_list = [] transforms_list.append(transforms.ToTensor()) transforms_list.append(transforms.Resize((800,800))) #transforms_list.append(transforms.CenterCrop(800)) #transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])) apply_t = transforms.Compose(transforms_list) img = apply_t(img) outputs = ort_sess.run(None, {'input': [img.numpy()]}) print(outputs) #np_res = outputs[0][0] output_image = outputs[0][:,:,1] final = cv.resize(output_image, (dim[0], dim[1])) return final else: thres = 0.9 print(str(type(img_path))) if type(img_path) == np.ndarray: img, dims = pad_256_np(img_path) else: img, dims = pad_256(img_path) H, W, NH, NW = dims img = img.to(self.device) img = img.unsqueeze(0) output = self.model(img) output = torch.squeeze(output) torch.cuda.empty_cache() prediction = output prediction[prediction >= thres] = 1.0 prediction[prediction < thres] = 0.0 if self.device.type == "cpu": prediction = prediction.detach().numpy() else: prediction = prediction.cpu().detach().numpy() prediction = erosion(prediction) # reverse padding prediction = prediction[ (NH - H) // 2 : (NH - H) // 2 + H, (NW - W) // 2 : (NW - W) // 2 + W ] return prediction