# phenotyping_pipeline/bgremover.py
# Author: Andres Felipe Ruiz-Hurtado
import cv2 as cv
import numpy as np
from PIL import Image
import glob
import pathlib
import sys
import u2net_utils
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# import torch.optim as optim
from u2net_utils.data_loader import RescaleT
from u2net_utils.data_loader import ToTensor
from u2net_utils.data_loader import ToTensorLab
from u2net_utils.data_loader import SalObjDataset
from u2net_utils.model import U2NET # full size version 173.6 MB
from u2net_utils.model import U2NETP # small version u2net 4.7 MB
from torchvision import models
import onnxruntime as ort
from torchvision.transforms import v2 as transforms  # the v2 transforms API is used throughout
# MODEL_PATH = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_gpu\models\u2net.pth"
# MODEL_PATH = r"D:\CIAT\catalogue\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models"
# MODEL_PATH = r"D:\local_mydata\models\spidermites\best_models"
MODEL_PATH = "./models"
#************************
# Grounded-SAM dependencies; these are only needed by the Segmentor class
# below, so the module still imports when they are missing.
try:
    from loguru import logger
    from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
    # import subprocess
    # Grounding DINO
    import GroundingDINO.groundingdino.datasets.transforms as T
    from GroundingDINO.groundingdino.models import build_model
    # from GroundingDINO.groundingdino.util import box_ops
    from GroundingDINO.groundingdino.util.slconfig import SLConfig
    from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
    from huggingface_hub import hf_hub_download
except ImportError:
    logger = None  # Segmentor is unusable without these dependencies
import gc
def clear():
    """Release unreferenced objects and empty the CUDA cache."""
    gc.collect()
    torch.cuda.empty_cache()
# normalize the predicted SOD probability map
def normPRED(d):
    """Min-max normalize a saliency prediction tensor to [0, 1]."""
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn
class BackgroundRemover():
    def __init__(self):
        # Load the pre-trained U2NET saliency model
        # model_dir = "/workspace/u2net.pth"
        # model_dir = "D:/local_mydata/models/u2net.pth"
        # model_dir = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_gpu\models\u2net.pth"
        model_dir = os.path.join(MODEL_PATH, "u2net.pth")
        net = U2NET(3, 1)
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(model_dir))
            net.cuda()
        else:
            net.load_state_dict(torch.load(model_dir, map_location='cpu'))
        net.eval()
        self.net = net
    def remove_background(self, filepath_image):
        """Run U2NET on a single image file and return the saliency mask
        resized to the original image dimensions (float array in [0, 1])."""
        img_name_list = [filepath_image]
        test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                            lbl_name_list=[],
                                            transform=transforms.Compose([RescaleT(320),
                                                                          ToTensorLab(flag=0)]))
        test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)
        net = self.net
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inference:", img_name_list[i_test].split(os.sep)[-1])
            inputs_test = data_test['image']
            inputs_test = inputs_test.type(torch.FloatTensor)
            if torch.cuda.is_available():
                inputs_test = Variable(inputs_test.cuda())
            else:
                inputs_test = Variable(inputs_test)
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
            # normalization: keep the finest-scale output (d1) and min-max scale it
            pred = d1[:, 0, :, :]
            pred = normPRED(pred)
            predict = pred.squeeze()
            predict_np = predict.cpu().data.numpy()
            # resize the mask back to the original image size
            img = cv.imread(filepath_image)
            w = img.shape[1]
            h = img.shape[0]
            # NOTE: interpolation must be passed as a keyword argument;
            # passing it positionally would set cv.resize's (unused) dst parameter
            imo = cv.resize(predict_np, (w, h), interpolation=cv.INTER_LINEAR)
            return imo
    def remove_background_save(self, path_in, path_out, path_out_mask=None):
        """Remove the background from path_in and write the masked image to
        path_out; optionally write the raw mask to path_out_mask."""
        print("remove_background_save")
        mask_torch = self.remove_background(path_in)
        mask = (mask_torch * 255).astype(np.uint8)
        img = cv.imread(path_in)
        # binarize the saliency mask (the threshold of 80 was chosen empirically)
        ret, binary_mask = cv.threshold(mask, 80, 255, cv.THRESH_BINARY)
        binary_mask = np.uint8(binary_mask)
        res = cv.bitwise_and(img, img, mask=binary_mask)
        cv.imwrite(path_out, res)
        if path_out_mask is not None:
            cv.imwrite(path_out_mask, mask)
    def remove_background_dir(self, path_in, path_out):
        """Process every *.jpg in path_in, skipping files already present in path_out."""
        img_name_list = glob.glob(os.path.join(path_in, "*.jpg"))
        for img_name in img_name_list:
            img_name_output = img_name.replace(path_in, path_out)
            if not os.path.exists(img_name_output):
                self.remove_background_save(img_name, img_name_output)
                print(img_name_output)
    def remove_background_gradio(self, np_image):
        """Variant of remove_background for in-memory (Gradio) images.
        Returns (mask, masked_image)."""
        w = np_image.shape[1]
        h = np_image.shape[0]
        image = np_image
        imidx = np.array([0])
        # Build a dummy label with the shape the SalObjDataset transforms expect
        label_3 = np.zeros(image.shape)
        label = np.zeros(label_3.shape[0:2])
        if 3 == len(label_3.shape):
            label = label_3[:, :, 0]
        elif 2 == len(label_3.shape):
            label = label_3
        if 3 == len(image.shape) and 2 == len(label.shape):
            label = label[:, :, np.newaxis]
        elif 2 == len(image.shape) and 2 == len(label.shape):
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]
        sample = {'imidx': imidx, 'image': image, 'label': label}
        print(image.shape)
        print(label.shape)
        eval_transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
        sample = eval_transform(sample)
        net = self.net
        inputs_test = sample['image']
        inputs_test = inputs_test.type(torch.FloatTensor)
        inputs_test = inputs_test.unsqueeze(0)  # add the batch dimension
        print(inputs_test.shape)
        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
        # normalization: keep the finest-scale output (d1) and min-max scale it
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        predict = pred.squeeze()
        predict_np = predict.cpu().data.numpy()
        imo = cv.resize(predict_np, (w, h), interpolation=cv.INTER_LINEAR)
        mask = (imo * 255).astype(np.uint8)
        # binarize the saliency mask (the threshold of 80 was chosen empirically)
        ret, binary_mask = cv.threshold(mask, 80, 255, cv.THRESH_BINARY)
        binary_mask = np.uint8(binary_mask)
        res = cv.bitwise_and(np_image, np_image, mask=binary_mask)
        return mask, res
    def apply_mask(self, input, mask, threshold):
        """Binarize a 3-channel mask at the given threshold and apply it to input."""
        mask = cv.cvtColor(mask, cv.COLOR_BGR2GRAY)
        ret, binary_mask = cv.threshold(mask, threshold, 255, cv.THRESH_BINARY)
        print("apply mask")
        print(input.shape, input.dtype)
        print(binary_mask.shape, binary_mask.dtype)
        res = cv.bitwise_and(input, input, mask=binary_mask)
        return res, binary_mask
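# Example usage of BackgroundRemover (a sketch; assumes ./models/u2net.pth
# exists and that "leaf.jpg" is a hypothetical local image):
#   remover = BackgroundRemover()
#   remover.remove_background_save("leaf.jpg", "leaf_nobg.jpg", path_out_mask="leaf_mask.png")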
def get_transform(train=True):
    """Build the evaluation transform used by DamageClassifier: resize,
    center-crop, tensor conversion, and ImageNet normalization."""
    transforms_list = []
    # if train:
    #     transforms_list.append(transforms.RandomHorizontalFlip(0.5))
    transforms_list.append(transforms.Resize(256))
    transforms_list.append(transforms.CenterCrop(256))
    transforms_list.append(transforms.ToTensor())
    transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return transforms.Compose(transforms_list)
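# Example (a sketch): applying the eval transform to a hypothetical PIL image
# yields a normalized 3x256x256 tensor.
#   tensor = get_transform(train=False)(Image.open("leaf.jpg"))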
class DamageClassifier():
    def __init__(self):
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model_name = ""
    def initialize(self, model_name):
        # Load the requested backbone and its fine-tuned spidermite-damage weights
        if model_name == "Resnet18":
            model_filepath = os.path.join(MODEL_PATH, "resnet18_SpidermitesModel.pth")
            model = models.resnet18(weights='IMAGENET1K_V1')
        elif model_name == "Resnet152":
            model_filepath = os.path.join(MODEL_PATH, "short_resnet152_SpidermitesModel_44_44.pth")
            model = models.resnet152(weights='IMAGENET1K_V1')
        elif model_name == "Googlenet":
            # NOTE: kept from the original mapping, which loads a RegNet checkpoint here
            model_filepath = os.path.join(MODEL_PATH, "regnet_x_32gf_SpidermitesModel.pth")
            model = models.regnet_x_32gf(weights='IMAGENET1K_V1')
        elif model_name == "Regnet32":
            # NOTE: kept from the original mapping, which loads a ResNet-18 checkpoint here
            model_filepath = os.path.join(MODEL_PATH, "short_resnet18_SpidermitesModel.pth")
            model = models.resnet18(weights='IMAGENET1K_V1')
        else:
            raise ValueError("Unknown model name: " + model_name)
        # Replace the final fully connected layer with a 4-class head
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 4)
        if torch.cuda.is_available():
            model.load_state_dict(torch.load(model_filepath))
            model.cuda()
        else:
            model.load_state_dict(torch.load(model_filepath, map_location='cpu'))
        model.eval()
        self.model = model
        self.model_name = model_name
        return
    def inference(self, np_image, model_name):
        if model_name == "Regnet":
            # ONNX Runtime path
            model_filepath = os.path.join(MODEL_PATH, "regnet_x_32gf_SpidermitesModel.onnx")
            ort_sess = ort.InferenceSession(model_filepath,
                                            providers=ort.get_available_providers())
            transforms_list = []
            transforms_list.append(transforms.ToTensor())
            transforms_list.append(transforms.Resize(512))
            transforms_list.append(transforms.CenterCrop(512))
            transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
            apply_t = transforms.Compose(transforms_list)
            img = apply_t(np_image)
            outputs = ort_sess.run(None, {'input': [img.numpy()]})
            np_res = outputs[0][0]
            final_res = {'0-(No damage)': np_res[0],
                         '1-3-(Moderately damaged)': np_res[1],
                         '4-7-(Damaged)': np_res[2],
                         '8-10-(Severely damaged)': np_res[3]}
            return final_res
        else:
            # PyTorch path: (re)load the model if a different one was requested
            if self.model_name != model_name:
                self.initialize(model_name)
            with torch.no_grad():
                print("inference")
                print(np_image.shape)
                pil_image = Image.fromarray(np_image.astype('uint8'))
                data_transforms = get_transform(train=False)
                img = data_transforms(pil_image)
                inputs = img.to(self.device)
                outputs = self.model(inputs.unsqueeze(0))
                _, preds = torch.max(outputs, 1)
                # convert logits to class probabilities
                res = nn.Softmax(dim=1)(outputs)
                np_res = res[0].cpu().numpy()
                final_res = {'0-(No damage)': np_res[0],
                             '1-3-(Moderately damaged)': np_res[1],
                             '4-7-(Damaged)': np_res[2],
                             '8-10-(Severely damaged)': np_res[3]}
                return final_res
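# Example usage of DamageClassifier (a sketch; assumes the fine-tuned weights
# exist under ./models and that "leaf_nobg.jpg" is a hypothetical pre-masked image):
#   classifier = DamageClassifier()
#   scores = classifier.inference(cv.imread("leaf_nobg.jpg"), "Resnet18")
#   print(max(scores, key=scores.get))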
class ColorCheckerDetector():
    def __init__(self):
        return
    def process(self, np_image_mask, np_image):
        """Locate the color checker in a mask, then return the annotated image,
        a rotation-corrected image, and a perspective-corrected crop."""
        ret, binary_mask = cv.threshold(np_image_mask, 80, 255, cv.THRESH_BINARY)
        binary_mask_C = cv.cvtColor(binary_mask, cv.COLOR_BGR2GRAY)  # change to single channel
        (contours, hierarchy) = cv.findContours(binary_mask_C, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
        # use the largest contour; cv.findContours returns them in no guaranteed order
        main_contour = max(contours, key=cv.contourArea)
        # compute the center of the contour
        moments = cv.moments(main_contour)
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        # axis-aligned bounding rectangle
        bb_x, bb_y, bb_w, bb_h = cv.boundingRect(binary_mask_C)
        # minimum-area (rotated) bounding rectangle
        rect = cv.minAreaRect(main_contour)
        box = cv.boxPoints(rect)
        box = np.int64(box)
        # fit a line through the box corners to estimate the checker's tilt
        rows, cols = binary_mask_C.shape[:2]
        # [vx, vy, x, y] = cv.fitLine(main_contour, cv.DIST_L2, 0, 0.01, 0.01)
        [vx, vy, x, y] = cv.fitLine(box, cv.DIST_L2, 0, 0.01, 0.01)
        lefty = int((-x * vy / vx) + y)
        righty = int(((cols - x) * vy / vx) + y)
        angle = np.arctan2(np.abs(righty - lefty), cols)
        # rotate the image around the contour center to level the checker
        M_rot = cv.getRotationMatrix2D((cx, cy), -angle * 180.0 / np.pi, 1.0)
        rotated = cv.warpAffine(np_image, M_rot, (binary_mask.shape[1], binary_mask.shape[0]))
        # perspective transform: map the rotated rectangle to a fixed-size crop
        input_pts = box.astype(np.float32)
        maxHeight = 200
        maxWidth = 290
        output_pts = np.float32([[0, 0],
                                 [maxWidth - 1, 0],
                                 [maxWidth - 1, maxHeight - 1],
                                 [0, maxHeight - 1]])
        M_per = cv.getPerspectiveTransform(input_pts, output_pts)
        corrected = cv.warpPerspective(np_image, M_per, (maxWidth, maxHeight), flags=cv.INTER_LINEAR)
        # draw diagnostics: contour, bounding boxes, fitted line
        res = cv.drawContours(np_image, main_contour, -1, (255, 255, 0), 5)
        res = cv.rectangle(res, (bb_x, bb_y), (bb_x + bb_w, bb_y + bb_h), (0, 255, 0), 5)
        res = cv.drawContours(res, [box], 0, (0, 0, 255), 5)
        res = cv.line(res, (cols - 1, righty), (0, lefty), (0, 0, 255), 5)
        return [res, rotated, corrected]
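# Example usage of ColorCheckerDetector (a sketch; "checker_mask.jpg" and
# "photo.jpg" are hypothetical local files):
#   detector = ColorCheckerDetector()
#   annotated, rotated, corrected = detector.process(cv.imread("checker_mask.jpg"), cv.imread("photo.jpg"))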
class BatchProcessor():
    def __init__(self):
        return
    def batch_process(self, input_dir, output_dir, output_suffixes=["output"], format="jpg", pattern='**/*.tiff', processing_fc=None, output_format=None):
        if processing_fc is None:
            print("Processing function is None")
            return
        if output_format is None:
            output_format = format
        # Get the list of files in the folder and its subfolders
        # (`pattern` is rebuilt from `format`; requires Python >= 3.10 for root_dir)
        pattern = '**/*.' + format
        files = glob.glob(pattern, root_dir=input_dir, recursive=True)
        for file in files:
            filepath = os.path.join(input_dir, file)
            basename = os.path.basename(filepath)
            extra_path = file.replace(basename, "")
            # NOTE: use a per-file output dir; reassigning output_dir itself
            # would nest every subsequent file one level deeper
            file_output_dir = os.path.join(output_dir, extra_path)
            # Build one output filepath per requested suffix
            output_filepaths = []
            for suffix in output_suffixes:
                output_filepaths.append(os.path.join(file_output_dir, basename.replace("." + format, "_" + suffix + "." + output_format)))
            if not os.path.exists(output_filepaths[0]):  # process only if the first output file does not exist
                if not os.path.exists(file_output_dir):  # create subfolders if necessary
                    pathlib.Path(file_output_dir).mkdir(parents=True, exist_ok=True)
                processing_fc(filepath, output_filepaths)  # process and save the file
                print(file)
                print(output_filepaths[0])
                print("****")
class Segmentor():
    def __init__(self):
        self.sam_predictor = None
        self.groundingdino_model = None
        # self.sam_checkpoint = './sam_vit_h_4b8939.pth'
        # self.sam_checkpoint = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\sam_vit_h_4b8939.pth"
        self.sam_checkpoint = r"D:\local_mydev\Grounded-Segment-Anything\sam_vit_h_4b8939.pth"
        # self.config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
        self.config_file = r"D:\local_mydev\gsam\GroundingDINO\groundingdino\config\GroundingDINO_SwinT_OGC.py"
        self.ckpt_repo_id = "ShilongLiu/GroundingDINO"
        self.ckpt_filename = "groundingdino_swint_ogc.pth"
        self.device = 'cpu'
        self.load_sam_model(self.device)
        self.load_groundingdino_model(self.device)
        return
    def get_sam_vit_h_4b8939(self):
        # Download hook for the SAM checkpoint; currently a no-op.
        return
        # if not os.path.exists('./sam_vit_h_4b8939.pth'):
        #     logger.info("get sam_vit_h_4b8939.pth...")
        #     result = subprocess.run(['wget', '-nv', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
        #     print(f'wget sam_vit_h_4b8939.pth result = {result}')
    def load_sam_model(self, device):
        # initialize SAM from the local checkpoint
        sam_checkpoint = self.sam_checkpoint
        self.get_sam_vit_h_4b8939()
        logger.info("initialize SAM model...")
        sam_model = build_sam(checkpoint=sam_checkpoint).to(device)
        self.sam_predictor = SamPredictor(sam_model)
        self.sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
    def get_grounding_output(self, model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
        # GroundingDINO expects a lowercase caption terminated with a period
        caption = caption.lower().strip()
        if not caption.endswith("."):
            caption = caption + "."
        model = model.to(device)
        image = image.to(device)
        with torch.no_grad():
            outputs = model(image[None], captions=[caption])
        logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
        boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
        # filter output by box confidence
        logits_filt = logits.clone()
        boxes_filt = boxes.clone()
        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
        logits_filt = logits_filt[filt_mask]  # num_filt, 256
        boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
        # map token-level logits back to phrases
        tokenizer = model.tokenizer
        tokenized = tokenizer(caption)
        # build predictions
        pred_phrases = []
        for logit, box in zip(logits_filt, boxes_filt):
            pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
            if with_logits:
                pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
            else:
                pred_phrases.append(pred_phrase)
        return boxes_filt, pred_phrases
    def load_model_hf(self, model_config_path, repo_id, filename, device='cpu'):
        # Build the GroundingDINO model from its config and load HF-hosted weights
        args = SLConfig.fromfile(model_config_path)
        model = build_model(args)
        args.device = device
        cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
        checkpoint = torch.load(cache_file, map_location=device)
        log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
        print("Model loaded from {} \n => {}".format(cache_file, log))
        _ = model.eval()
        return model
    def load_groundingdino_model(self, device):
        # initialize the GroundingDINO model
        logger.info("initialize groundingdino model...")
        self.groundingdino_model = self.load_model_hf(self.config_file, self.ckpt_repo_id, self.ckpt_filename, device=device)
        logger.info(f"initialize groundingdino model...{type(self.groundingdino_model)}")
    def show_mask(self, mask, random_color=False):
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        # NOTE: this override forces an opaque red overlay regardless of the
        # branches above; remove it to honor random_color
        color = np.array([1.0, 0, 0, 1.0])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        return mask_image
    def process(self, np_image, text_prompt):
        """Run GroundingDINO with text_prompt, then SAM on the resulting boxes.
        Returns [input_image, mask_overlay, ...]."""
        results = [np_image]
        sam_predictor = self.sam_predictor
        groundingdino_model = self.groundingdino_model
        image = np_image
        box_threshold = 0.3
        text_threshold = 0.25
        # numpy images are (height, width, channels)
        H, W = image.shape[0], image.shape[1]
        # run the GroundingDINO model
        groundingdino_device = 'cpu'
        image_dino = Image.fromarray(image)
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        print(image.shape)
        image_dino, _ = transform(image_dino, None)  # 3, h, w
        boxes_filt, pred_phrases = self.get_grounding_output(
            groundingdino_model, image_dino, text_prompt, box_threshold, text_threshold, device=groundingdino_device
        )
        if sam_predictor:
            sam_predictor.set_image(image)
            # convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1)
            for i in range(boxes_filt.size(0)):
                boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
                boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
                boxes_filt[i][2:] += boxes_filt[i][:2]
            transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
            masks, _, _, _ = sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            print("RESULTS*************")
            print(len(masks))
            for mask in masks:
                print(type(mask), mask.shape)
                mask_img = self.show_mask(mask.cpu().numpy())
                print(type(mask_img), mask_img.shape)
                results.append(mask_img)
        return results
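# Example usage of Segmentor (a sketch; requires the Grounded-Segment-Anything
# dependencies and the local checkpoints configured in Segmentor.__init__;
# "plant.jpg" is a hypothetical input image):
#   segmentor = Segmentor()
#   overlays = segmentor.process(np.array(Image.open("plant.jpg").convert("RGB")), "leaf")
#   # overlays[0] is the input image; overlays[1:] are RGBA mask overlays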