import argparse import copy import os import pickle import random import cv2 import numpy as np import string import torch from mmcv import Config, DictAction from mmcv.cnn import fuse_conv_bn from mmcv.runner import load_checkpoint from mmpose.core import wrap_fp16_model from mmpose.models import build_posenet from torchvision import transforms from models import * import torchvision.transforms.functional as F from tools.visualization import plot_results, plot_query_results, plot_modified_query import ast import shutil COLORS = [ [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]] class Resize_Pad: def __init__(self, w=256, h=256): self.w = w self.h = h def __call__(self, image): _, w_1, h_1 = image.shape ratio_1 = w_1 / h_1 # check if the original and final aspect ratios are the same within a margin if round(ratio_1, 2) != 1: # padding to preserve aspect ratio if ratio_1 > 1: # Make the image higher hp = int(w_1 - h_1) hp = hp // 2 image = F.pad(image, (hp, 0, hp, 0), 0, "constant") return F.resize(image, [self.h, self.w]) else: wp = int(h_1 - w_1) wp = wp // 2 image = F.pad(image, (0, wp, 0, wp), 0, "constant") return F.resize(image, [self.h, self.w]) else: return F.resize(image, [self.h, self.w]) def transform_keypoints_to_pad_and_resize(keypoints, image_size): trans_keypoints = keypoints.clone() h, w = image_size[:2] ratio_1 = w / h if ratio_1 > 1: # width is bigger than height - pad height hp = int(w - h) hp = hp // 2 trans_keypoints[:, 1] = keypoints[:, 1] + hp trans_keypoints *= (256. / w) else: # height is bigger than width - pad width wp = int(image_size[1] - image_size[0]) wp = wp // 2 trans_keypoints[:, 0] = keypoints[:, 0] + wp trans_keypoints *= (256. / h) return trans_keypoints def parse_args(): parser = argparse.ArgumentParser(description='Pose Anything Demo') parser.add_argument('--support_points', help='support keypoints text descriptions') parser.add_argument('--support_skeleton', help='list of keypoints skeleton') parser.add_argument('--query', help='Image file') parser.add_argument('--config', default=None, help='test config file path') parser.add_argument('--checkpoint', default=None, help='checkpoint file') parser.add_argument('--outdir', default='output', help='checkpoint file') parser.add_argument( '--fuse-conv-bn', action='store_true', help='Whether to fuse conv and bn, this will slightly increase' 'the inference speed') parser.add_argument( '--cfg-options', nargs='+', action=DictAction, default={}, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. For example, ' "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") args = parser.parse_args() return args def merge_configs(cfg1, cfg2): # Merge cfg2 into cfg1 # Overwrite cfg1 if repeated, ignore if value is None. cfg1 = {} if cfg1 is None else cfg1.copy() cfg2 = {} if cfg2 is None else cfg2 for k, v in cfg2.items(): if v: cfg1[k] = v return cfg1 def main(): random.seed(0) np.random.seed(0) torch.manual_seed(0) args = parse_args() cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True cfg.data.test.test_mode = True os.makedirs(args.outdir, exist_ok=True) # Load data point_descriptions = ast.literal_eval(args.support_points) query_img = cv2.imread(args.query) if query_img is None: raise ValueError('Fail to read image') # just a placeholder, we don't have input keypoints kp_src = torch.zeros((len(point_descriptions), 2)) preprocess = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), Resize_Pad(cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size)]) if args.support_skeleton is not None: skeleton = ast.literal_eval(args.support_skeleton) if len(skeleton) == 0: skeleton = [(0, 0)] model_device = "cuda" if torch.cuda.is_available() else "cpu" query_img = preprocess(query_img).flip(0)[None].to(model_device) # Create heatmap from keypoints genHeatMap = TopDownGenerateTargetFewShot() data_cfg = cfg.data_cfg data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size]) data_cfg['joint_weights'] = None data_cfg['use_different_joint_weights'] = False kp_src_3d = torch.concatenate((kp_src, torch.zeros(kp_src.shape[0], 1)), dim=-1) kp_src_3d_weight = torch.concatenate((torch.ones_like(kp_src), torch.zeros(kp_src.shape[0], 1)), dim=-1) # everything that is related to the support image is used as placeholder target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1) target_s = torch.tensor(target_s).float()[None] target_weight_s = torch.tensor(target_weight_s).float()[None].to(model_device) data = { 'img_s': [0], 'img_q': query_img, 'target_s': [target_s], 'target_weight_s': [target_weight_s], 'target_q': None, 'target_weight_q': None, 'return_loss': False, 'img_metas': [{'sample_skeleton': [skeleton], 'query_skeleton': skeleton, 'sample_point_descriptions': np.array([point_descriptions]), 'sample_joints_3d': [kp_src_3d], 'query_joints_3d': kp_src_3d, 'sample_center': [kp_src.mean(dim=0)], 'query_center': kp_src.mean(dim=0), 'sample_scale': [kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0]], 'query_scale': kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0], 'sample_rotation': [0], 'query_rotation': 0, 'sample_bbox_score': [1], 'query_bbox_score': 1, 'query_image_file': '', 'sample_image_file': [''], }] } # Load model model = build_posenet(cfg.model) fp16_cfg = cfg.get('fp16', None) if fp16_cfg is not None: wrap_fp16_model(model) load_checkpoint(model, args.checkpoint, map_location='cpu') if args.fuse_conv_bn: model = fuse_conv_bn(model) model.to(model_device) model.eval() with torch.no_grad(): outputs = model(**data) # visualize results vis_q_weight = target_weight_s[0] vis_q_image = query_img[0].detach().cpu().numpy().transpose(1, 2, 0) name_idx = plot_query_results(vis_q_image, vis_q_weight, skeleton, torch.tensor(outputs['points']).squeeze(0), out_dir=args.outdir) shutil.copyfile(args.query, f'./{args.outdir}/{str(name_idx)}_query_in.png') if __name__ == '__main__': main()