Spaces:
Runtime error
Runtime error
File size: 5,287 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import numpy as np
import cv2, torch
import os
import json
import argparse
import pycocotools.mask as mask_util
from tqdm import tqdm
def bbox_expand(img_height, img_width, bbox, exp_ratio):
    """Grow an ``[x_min, y_min, x_max, y_max]`` box about its centre and clip it to the image.

    Each horizontal side moves outward by ``(exp_ratio - 1) / 2`` times the box
    width and each vertical side by the same fraction of the box height, so the
    unclipped box is scaled by ``exp_ratio``. Coordinates are rounded with
    ``np.round``; minima are clamped to 0 and maxima to ``img_width - 1`` /
    ``img_height - 1`` (i.e. the result is an inclusive pixel box).

    Returns:
        tuple of four ints ``(x_min, y_min, x_max, y_max)``.
    """
    left, top, right, bottom = bbox[:]
    half_growth = (exp_ratio - 1) / 2
    pad_x = (right - left) * half_growth
    pad_y = (bottom - top) * half_growth

    def lower_edge(value):
        # Anything left of / above the image snaps to pixel 0.
        return 0 if value < 0 else np.round(value)

    def upper_edge(value, last_pixel):
        # Anything past the last pixel snaps to that pixel.
        return last_pixel if value > last_pixel else np.round(value)

    corners = (
        lower_edge(left - pad_x),
        lower_edge(top - pad_y),
        upper_edge(right + pad_x, img_width - 1),
        upper_edge(bottom + pad_y, img_height - 1),
    )
    return tuple(int(c) for c in corners)
def make_crop_and_mask(img_info, pred, file_list, crop_save_dir, mask_save_dir, args):
    """Crop each confidently detected person out of one image and build its label map.

    Predicted instances are visited highest score first. Every instance whose
    score is at least ``args.conf_thres`` has its segmentation decoded and
    painted into a single-channel label map (0 = background, k = k-th kept
    person). An instance is dropped when more than ``args.overlap_threshold``
    of its mask area is already claimed by earlier (higher-scoring) instances;
    otherwise only its unclaimed pixels are kept. For each kept instance the
    detection box is expanded by ``args.exp_ratio`` with ``bbox_expand`` and
    the corresponding image region is written to ``crop_save_dir``. The label
    map is saved as ``<image stem>_mask.npy`` in ``mask_save_dir`` and a
    record describing the image is appended to ``file_list``.

    Args:
        img_info: image record with ``file_name``, ``id``, ``width`` and
            ``height`` keys.
        pred: per-image predictions indexed by ``img_info['id'] - 1``; each
            entry has an ``instances`` list whose items carry ``score``,
            ``segmentation`` (decoded with ``pycocotools.mask.decode``) and
            ``bbox`` in ``[x, y, w, h]`` form.
        file_list: list of per-image records; the new record is appended in place.
        crop_save_dir: directory for the cropped person images.
        mask_save_dir: directory for the label-map ``.npy`` file.
        args: options providing ``img_dir``, ``conf_thres``,
            ``overlap_threshold`` and ``exp_ratio``.

    Returns:
        ``({'root': file_list}, file_list)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
        AssertionError: if the image has no predicted instances, or none survives filtering.
    """
    img_name = img_info['file_name']
    # img_info['id'] starts from 1; assumes pred[i] belongs to image id i + 1
    # (TODO confirm against the ordering of the detection dump).
    img_id = img_info['id'] - 1
    img_w = img_info['width']
    img_h = img_info['height']
    img_path = os.path.join(args.img_dir, img_name)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image: ' + img_path)

    instances = pred[img_id]['instances']
    assert len(instances) > 0, 'image without instance prediction'
    # Highest scores first: they win contested pixels, and the early `break`
    # on the confidence threshold below is only valid on sorted input.
    # sorted() is stable, so already-sorted input keeps its exact order.
    instances = sorted(instances, key=lambda ins: ins['score'], reverse=True)

    exp_bbox = []
    ori_bbox = []
    bbox_name_list = []
    bbox_score_list = []
    person_idx = 0
    # Label map: 0 = background, k = k-th kept person.
    # NOTE(review): uint8 caps this at 255 people per image.
    panoptic_seg = np.zeros((img_h, img_w), dtype=np.uint8)
    for instance in instances:
        score = instance['score']
        if score < args.conf_thres:
            break  # sorted descending, so all remaining instances are below threshold
        mask = mask_util.decode(instance['segmentation'])
        mask_area = mask.sum()
        if mask_area == 0:
            continue
        intersect_area = ((mask > 0) & (panoptic_seg > 0)).sum()
        if intersect_area * 1.0 / mask_area > args.overlap_threshold:
            continue  # mostly covered by people that were already kept
        if intersect_area > 0:
            # Keep only the pixels no earlier person has claimed.
            mask = mask & (panoptic_seg == 0)
            if not mask.any():
                # Fully covered; only reachable when overlap_threshold >= 1.
                continue
        person_idx += 1
        panoptic_seg = np.where(mask == 0, panoptic_seg, person_idx)
        bbox_score_list.append(score)

        x_min, y_min, box_w, box_h = instance['bbox']  # [x, y, w, h] format
        x_max, y_max = x_min + box_w, y_min + box_h
        exp_x_min, exp_y_min, exp_x_max, exp_y_max = bbox_expand(
            img_h, img_w, [x_min, y_min, x_max, y_max], args.exp_ratio)
        # The expanded maxima are clamped to width/height - 1, i.e. inclusive,
        # hence the + 1 on the slice ends.
        crop_img = img[exp_y_min:exp_y_max + 1, exp_x_min:exp_x_max + 1, :]
        exp_bbox.append([exp_x_min, exp_y_min, exp_x_max, exp_y_max])
        ori_bbox.append([x_min, y_min, x_max, y_max])
        bbox_name = os.path.splitext(img_name)[0] + '_' + str(person_idx) + '_msrcnn.jpg'
        bbox_name_list.append(bbox_name)
        cv2.imwrite(os.path.join(crop_save_dir, bbox_name), crop_img)
    assert person_idx > 0, 'image without instance'

    mask_name = os.path.splitext(img_name)[0] + '_mask.npy'
    np.save(os.path.join(mask_save_dir, mask_name), panoptic_seg)

    ############## json writing ##################
    item = {}
    item['dataset'] = 'CIHP'
    item['im_name'] = img_name
    item['img_height'] = img_h
    item['img_width'] = img_w
    item['center'] = [img_h / 2, img_w / 2]  # height first, then width
    item['person_num'] = person_idx
    item['person_bbox'] = exp_bbox
    item['real_person_bbox'] = ori_bbox
    item['person_bbox_score'] = bbox_score_list
    item['bbox_name'] = bbox_name_list
    item['mask_name'] = mask_name
    file_list.append(item)
    json_file = {'root': file_list}
    return json_file, file_list
def get_arguments(argv=None):
    """Parse the options for the person-cropping script.

    Args:
        argv: optional list of argument strings to parse. The default,
            ``None``, makes argparse read ``sys.argv[1:]`` exactly as before;
            passing a list lets other scripts or tests drive the parser.

    Returns:
        argparse.Namespace with ``exp_ratio``, ``overlap_threshold``,
        ``conf_thres``, ``img_dir``, ``save_dir``, ``img_list`` and ``det_res``.
    """
    parser = argparse.ArgumentParser(description="crop person val/test demo for inference")
    parser.add_argument("--exp_ratio", type=float, default=1.2,
                        help="factor by which each detected person box is enlarged before cropping")
    parser.add_argument("--overlap_threshold", type=float, default=0.5,
                        help="skip an instance when more than this fraction of its mask "
                             "is already covered by previously kept instances")
    parser.add_argument("--conf_thres", type=float, default=0.5,
                        help="minimum detection score for an instance to be kept")
    # NOTE(review): the path defaults below point at the original author's
    # machine; pass explicit paths when running anywhere else.
    parser.add_argument("--img_dir", type=str,
                        default='/data03/v_xuyunqiu/data/instance-level_human_parsing/Testing/Images',
                        help="directory containing the input images")
    parser.add_argument("--save_dir", type=str,
                        default='/data03/v_xuyunqiu/Projects/experiment_data/testing/resnest_200_TTA_mask_nms_all_data',
                        help="output root for crop_pic/, crop_mask/ and crop.json")
    parser.add_argument("--img_list", type=str,
                        default='/data03/v_xuyunqiu/Projects/pycococreator/annotations/CIHP_test.json',
                        help="JSON file whose 'images' entries list the images to process")
    parser.add_argument("--det_res", type=str,
                        default='/data02/v_xuyunqiu/detectron2-ResNeSt/tools/output_cihp_inference_resnest/inference_TTA/instances_predictions.pth',
                        help="detection results file loaded with torch.load")
    return parser.parse_args(argv)
def main(args):
    """Crop the detected people of every listed image and write the crop index.

    Reads the image list (``args.img_list``, JSON with an ``images`` array)
    and the detection results (``args.det_res``), runs ``make_crop_and_mask``
    on each image, and writes ``{'root': [per-image records]}`` to
    ``<save_dir>/crop.json``. Crops go to ``<save_dir>/crop_pic`` and label
    maps to ``<save_dir>/crop_mask``; both directories are created if missing.
    """
    with open(args.img_list, encoding='UTF-8') as f:
        img_info_list = json.load(f)
    # NOTE(review): torch.load may unpickle arbitrary objects; only point
    # --det_res at trusted detection dumps.
    pred = torch.load(args.det_res)
    crop_save_dir = os.path.join(args.save_dir, 'crop_pic')
    os.makedirs(crop_save_dir, exist_ok=True)
    mask_save_dir = os.path.join(args.save_dir, 'crop_mask')
    os.makedirs(mask_save_dir, exist_ok=True)
    file_list = []
    for img_info in tqdm(img_info_list['images']):
        # make_crop_and_mask appends this image's record to file_list in place.
        _, file_list = make_crop_and_mask(img_info, pred, file_list, crop_save_dir, mask_save_dir, args)
    # Build the payload after the loop so an empty image list still produces
    # a valid file instead of failing on an unbound name.
    json_file = {'root': file_list}
    with open(os.path.join(args.save_dir, 'crop.json'), 'w') as f:
        json.dump(json_file, f, indent=2)
if __name__ == '__main__':
    # Script entry point: parse the command line, then process every listed image.
    args = get_arguments()
    main(args=args)
|