# coding:utf-8
"""Helpers that load point/box/scatter prompts and convert them to tensors."""
import os
import numpy as np
import cv2
from typing import Optional

import torch
# from models.transforms import ResizeLongestSide
# from .transforms import ResizeLongestSide
from torchvision import transforms


def get_prompt_inp_scatter(scatter_file_):
    """Read the scatter prompt image stored at ``scatter_file_``.

    The file is read with ``cv2.IMREAD_UNCHANGED`` so no channel or depth
    conversion is applied.

    Args:
        scatter_file_: Path of the image file to read.

    Returns:
        The array returned by ``cv2.imread``. Note that ``cv2.imread`` does
        not raise on an unreadable/missing file; it returns ``None`` instead.
    """
    scatter_mask = cv2.imread(scatter_file_, cv2.IMREAD_UNCHANGED)
    return scatter_mask


def pre_scatter_prompt(scatter, filp, device):
    """Convert a scatter prompt image into a normalized tensor on ``device``.

    Args:
        scatter: Image array as returned by ``get_prompt_inp_scatter``.
            NOTE(review): the 3-value mean/std below presumably requires a
            3-channel image -- confirm against the data.
        filp: When truthy, the image is mirrored horizontally first.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The transformed image tensor moved to ``device``.
    """
    if filp:
        # flipCode=1 mirrors around the vertical axis (left <-> right).
        scatter = cv2.flip(scatter, 1)

    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    scatter_torch = img_transform(scatter)
    return scatter_torch.to(device)


def get_prompt_inp(txt_file_, filp=False, img_width=1024.0):
    """Parse a label text file into point, label and box prompts.

    Each non-blank line must hold at least nine whitespace-separated fields:
    ``x1 y1 x2 y2 x3 y3 x4 y4 classname [extra...]``. The four corners are
    reduced to their axis-aligned bounds; the point prompt is the centre of
    those bounds.

    Args:
        txt_file_: Path of the label file.
        filp: When truthy, x coordinates are mirrored as ``img_width - x``.
        img_width: Width used for mirroring (defaults to 1024.0, the value
            that was previously hard-coded).

    Returns:
        A tuple ``(points, labels, boxes, None)`` where ``points`` is a list
        of ``[x_center, y_center]``, ``labels`` a list of class-name strings
        and ``boxes`` a list of ``[[xmin, ymin], [xmax, ymax]]``. The fourth
        element (masks) is always ``None``.

    Raises:
        ValueError: If a non-blank line has fewer than nine fields or a
            coordinate cannot be parsed as a float.
    """
    points = []
    labels = []
    boxes = []

    # The context manager guarantees the handle is closed even on parse errors.
    with open(txt_file_) as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            if len(fields) < 9:
                raise ValueError(
                    f"{txt_file_}:{line_no}: expected at least 9 fields, "
                    f"got {len(fields)}"
                )

            coords = [float(value) for value in fields[:8]]
            classname = fields[8]
            xs = coords[0::2]
            ys = coords[1::2]

            xmin, xmax = min(xs), max(xs)
            ymin, ymax = min(ys), max(ys)

            if filp:
                # Mirroring reverses the order of the bounds, so the mirrored
                # xmax becomes the new xmin (keeps xmin <= xmax).
                xmin, xmax = img_width - xmax, img_width - xmin

            x_center = (xmin + xmax) / 2
            y_center = (ymin + ymax) / 2

            points.append([x_center, y_center])
            labels.append(classname)
            boxes.append([[xmin, ymin], [xmax, ymax]])

    return points, labels, boxes, None


def pre_prompt(points=None, boxes=None, masks=None, device=None):
    """Convert raw prompts to float tensors on ``device``.

    Points and boxes are divided by 16.0; masks are converted unchanged.
    NOTE(review): 16.0 presumably maps pixel coordinates onto a 16x
    downsampled feature grid -- confirm against the model.

    Args:
        points: Point prompts (anything ``torch.as_tensor`` accepts) or None.
        boxes: Box prompts (anything ``torch.as_tensor`` accepts) or None.
        masks: Mask prompts (anything ``torch.as_tensor`` accepts) or None.
        device: Device on which the tensors are created.

    Returns:
        A tuple ``(points_torch, boxes_torch, masks_torch)``; every entry is
        ``None`` when the corresponding input was ``None``.
    """
    # Identity checks are required here: ``!= None`` on an array/tensor is an
    # elementwise comparison whose truth value is ambiguous.
    points_torch = points
    if points is not None:
        points_torch = torch.as_tensor(points, dtype=torch.float, device=device)
        points_torch = points_torch / 16.0

    boxes_torch = boxes
    if boxes is not None:
        boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device)
        boxes_torch = boxes_torch / 16.0

    masks_torch = masks
    if masks is not None:
        masks_torch = torch.as_tensor(masks, dtype=torch.float, device=device)

    return points_torch, boxes_torch, masks_torch


# Earlier variant of pre_prompt, kept for reference. It depends on the
# ResizeLongestSide import that is commented out at the top of this file.
#
# def pre_prompt(
#         point_coords: Optional[np.ndarray] = None,
#         point_labels: Optional[np.ndarray] = None,
#         box: Optional[np.ndarray] = None,
#         mask_input: Optional[np.ndarray] = None,
#         device=None,
#         original_size = [1024, 1024]
# ):
#     # transform = ResizeLongestSide(1024)
#     # Transform input prompts
#     coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
#     if point_coords is not None:
#         assert (
#                 point_labels is not None
#         ), "point_labels must be supplied if point_coords is supplied."
#         point_coords = transform.apply_coords(point_coords, original_size)
#         coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
#         labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
#         coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
#     if box is not None:
#         box = transform.apply_boxes(box, original_size)
#         box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
#         box_torch = box_torch[None, :]
#     if mask_input is not None:
#         mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device)
#         mask_input_torch = mask_input_torch[None, :, :, :]
#
#     return coords_torch, labels_torch, box_torch, mask_input_torch


if __name__ == '__main__':
    # Smoke test: parse the first label file and print the resulting prompts.
    txt_dir = './ISAID/train/trainprompt/sub_labelTxt/'
    txt_list = os.listdir(txt_dir)
    txt_file_0 = os.path.join(txt_dir, txt_list[0])

    # ``filp`` is passed explicitly; the previous call omitted it.
    points, labels, boxes, masks = get_prompt_inp(txt_file_0, False)
    print(points)
    print(labels)
    print(boxes)

    boxes_torch = torch.as_tensor(boxes, dtype=torch.float)
    boxes_torch = boxes_torch / 16.0
    print(boxes_torch, boxes_torch.shape)