"""Helpers for a Gradio demo: keypoint/skeleton annotation and pose inference.

The functions here keep UI state in a plain ``dict`` (``global_state`` /
``state``), draw keypoints and limbs on PIL images, run the pose model on a
support/query image pair and render the result with matplotlib.
"""
import collections
import collections.abc
import random

import gradio as gr
import matplotlib.patheffects as mpe
import numpy as np
import psutil
import torch
from matplotlib import pyplot as plt
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from PIL import ImageDraw, Image, ImageEnhance
from torchvision import transforms

from demo import Resize_Pad
from EdgeCape.models import *


def process_img(support_image, global_state):
    """Register a newly loaded support image and clear all annotations.

    Args:
        support_image: The image to annotate.
        global_state: Mutable UI state; ``global_state['images']`` must exist.

    Returns:
        ``(support_image, global_state)``.
    """
    global_state['images']['image_orig'] = support_image
    global_state['images']['image_kp'] = support_image
    reset_kp(global_state)
    return support_image, global_state


def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
    """Build a batch of symmetric 0/1 adjacency matrices from edge lists.

    Args:
        num_pts: Number of nodes; every matrix is ``num_pts x num_pts``.
        skeleton: Sequence (one entry per batch element) of edge lists, each
            edge being an ``(i, j)`` pair of node indices. An entry may be
            empty, which yields an all-zero matrix.
        device: Torch device the result is allocated on.
        normalization_fix: Unused; kept for interface compatibility.

    Returns:
        Float tensor of shape ``(len(skeleton), num_pts, num_pts)``.
    """
    matrices = []
    for edge_list in skeleton:
        # reshape(-1, 2) makes an empty edge list a valid (0, 2) index tensor.
        edges = torch.as_tensor(edge_list, dtype=torch.long, device=device).reshape(-1, 2)
        adj = torch.zeros(num_pts, num_pts, device=device)
        adj[edges[:, 0], edges[:, 1]] = 1
        matrices.append(adj)
    if not matrices:
        return torch.zeros(0, num_pts, num_pts, device=device)
    adj_mx = torch.stack(matrices, dim=0)
    # Symmetrise: an edge present in either direction is present in both.
    return torch.max(adj_mx, torch.transpose(adj_mx, 1, 2))


def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
                 skeleton=None, prediction=None, radius=6, in_color=None,
                 original_skeleton=None, img_alpha=0.6, target_keypoints=None):
    """Render keypoints and skeleton edges over the support and query images.

    Three figures are created in order: support image with
    ``original_skeleton``, support image with ``skeleton``, and query image
    with ``prediction`` and ``skeleton``.

    Args:
        support_img: ``(H, W, C)`` array; min-max normalised before display.
        query_img: Query image array; min-max normalised before display.
        support_kp: Support keypoints, indexable as ``kp[k, :2]``.
        support_w: Per-keypoint weights for the support panels.
        query_kp: Unused; kept for interface compatibility.
        query_w: Per-keypoint weights for the query panel.
        skeleton: Edge list (``list`` of index pairs) or adjacency matrix.
        prediction: Sequence whose last element holds the predicted keypoints
            in normalised coordinates; it is scaled by the image height.
        radius: Keypoint marker radius.
        in_color: ``None`` (rainbow palette), a colour string, or an iterator
            of colours consumed with ``next``.
        original_skeleton: Edge list or adjacency matrix selecting which edges
            advance the colour cycle. Overridden when ``skeleton`` is a list.
        img_alpha: Alpha used when drawing the images.
        target_keypoints: Optional ground truth; query keypoints further than
            ``256 * 0.05`` from it are highlighted.

    Returns:
        The ``matplotlib.pyplot`` module; the current figure is the last panel.

    Raises:
        ValueError: If ``in_color`` is neither ``None``, a string nor iterable.
    """
    img_height = support_img.shape[0]
    prediction = prediction[-1] * img_height
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.cpu().numpy()

    if isinstance(skeleton, list):
        # Built on CPU: the result is converted to numpy straight away, so
        # requiring a GPU here would only break CPU-only hosts.
        skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton],
                                     device='cpu').cpu().numpy()[0]
        original_skeleton = skeleton
    elif isinstance(skeleton, torch.Tensor):
        skeleton = skeleton.detach().cpu().numpy()
    if isinstance(original_skeleton, list):
        # An edge list cannot be indexed as [i][j]; turn it into an adjacency
        # matrix exactly as is done for `skeleton`.
        original_skeleton = adj_mx_from_edges(num_pts=100, skeleton=[original_skeleton],
                                              device='cpu').cpu().numpy()[0]

    support_img = (support_img - np.min(support_img)) / (np.max(support_img) - np.min(support_img))
    query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))

    error_mask = None
    panels = zip([support_img, support_img, query_img],
                 [support_w, support_w, query_w],
                 [support_kp, support_kp, prediction],
                 [original_skeleton, skeleton, skeleton])
    # NOTE(review): every call opens three figures and none is closed here;
    # callers of a long-running app may want to close them after rendering.
    for panel_idx, (img, weights, keypoint, adj) in enumerate(panels):
        color = in_color
        f, axes = plt.subplots()
        plt.imshow(img, alpha=img_alpha)

        # On the query panel, flag keypoints that are far from the target.
        if panel_idx == 2 and target_keypoints is not None:
            error = np.linalg.norm(keypoint - target_keypoints, axis=-1)
            error_mask = error > (256 * 0.05)

        for k in range(keypoint.shape[0]):
            if not weights[k] > 0:
                continue
            kp = keypoint[k, :2]
            c = (1, 0, 0, 0.75) if weights[k] == 1 else (0, 0, 1, 0.6)
            if error_mask is not None and error_mask[k]:
                c = (1, 1, 0, 0.75)
                patch = plt.Circle(kp, radius, color=c,
                                   path_effects=[mpe.withStroke(linewidth=8, foreground='black'),
                                                 mpe.withStroke(linewidth=4, foreground='white'),
                                                 mpe.withStroke(linewidth=2, foreground='black'),
                                                 ],
                                   zorder=260)
                axes.add_patch(patch)
                axes.text(kp[0], kp[1], k, fontsize=10, color='black',
                          ha="center", va="center", zorder=320, )
            else:
                patch = plt.Circle(kp, radius, color=c,
                                   path_effects=[mpe.withStroke(linewidth=2, foreground='black')],
                                   zorder=200)
                axes.add_patch(patch)
                axes.text(kp[0], kp[1], k, fontsize=(radius + 4), color='white',
                          ha="center", va="center", zorder=300,
                          path_effects=[
                              mpe.withStroke(linewidth=max(1, int((radius + 4) / 5)),
                                             foreground='black')])
        # One redraw per panel is enough; the artists above are all added.
        plt.draw()

        if adj is not None:
            # Scale edge weights so the strongest edge is drawn 6 wide.
            draw_skeleton = np.asarray(adj)
            max_skel_val = np.max(draw_skeleton)
            if max_skel_val > 0:
                draw_skeleton = draw_skeleton / max_skel_val * 6
            for i in range(1, keypoint.shape[0]):
                for j in range(0, i):
                    both_visible = weights[i] > 0 and weights[j] > 0
                    if both_visible and original_skeleton[i][j] > 0:
                        if color is None:
                            num_colors = int((skeleton > 0.05).sum() / 2)
                            color = iter(plt.cm.rainbow(np.linspace(0, 1, num_colors + 1)))
                            c = next(color)
                        elif isinstance(color, str):
                            c = color
                        # collections.Iterable no longer exists on Python 3.10+.
                        elif isinstance(color, collections.abc.Iterable):
                            c = next(color)
                        else:
                            raise ValueError("Color must be a string or an iterable")
                    if both_visible and skeleton[i][j] > 0:
                        width = draw_skeleton[i][j]
                        stroke_width = width + (width / 3)
                        patch = plt.Line2D([keypoint[i, 0], keypoint[j, 0]],
                                           [keypoint[i, 1], keypoint[j, 1]],
                                           linewidth=width, color=c, alpha=0.6,
                                           path_effects=[mpe.withStroke(linewidth=stroke_width,
                                                                        foreground='black')],
                                           zorder=1)
                        axes.add_artist(patch)

        plt.axis('off')  # hide the axis
        plt.subplots_adjust(0, 0, 1, 1, 0, 0)
    return plt


def process(query_img, state,
            cfg_path='configs/test/1shot_split1.py',
            checkpoint_path='ckpt/1shot_split1.pth'):
    """Run the pose model on ``query_img`` using the annotated support image.

    Args:
        query_img: Query image, convertible with ``np.array`` to ``(H, W, C)``.
        state: UI state holding ``'original_support_image'``, ``'kp_src'``,
            ``'skeleton'`` and ``state['images']['image_orig']``. An empty
            skeleton is replaced in place by ``[(0, 0)]``.
        cfg_path: Path of the mmcv config file.
        checkpoint_path: Path of the model checkpoint.

    Returns:
        ``(plot, state)`` where ``plot`` is the value returned by
        :func:`plot_results`.
    """
    cfg = Config.fromfile(cfg_path)
    img_size = cfg.model.encoder_config.img_size

    # numpy shape is (rows, cols, channels); column 0 of kp_src is scaled by
    # the row count and column 1 by the column count, then the pair is flipped.
    rows, cols, _ = state['original_support_image'].shape
    kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
    # NOTE(review): the // 4 presumably matches a 4x downscale applied to the
    # image on which the keypoints were clicked - confirm against set_query.
    kp_src_np[:, 0] = kp_src_np[:, 0] / (rows // 4) * img_size
    kp_src_np[:, 1] = kp_src_np[:, 1] / (cols // 4) * img_size
    kp_src_np = np.flip(kp_src_np, 1).copy()
    kp_src_tensor = torch.tensor(kp_src_np).float()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(img_size, img_size)])

    if len(state['skeleton']) == 0:
        state['skeleton'] = [(0, 0)]

    support_img = preprocess(state['images']['image_orig']).flip(0)[None]
    np_query = np.array(query_img)[:, :, ::-1].copy()
    q_img = preprocess(np_query).flip(0)[None]

    # Create heatmap from keypoints
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([img_size, img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor), torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(
        data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(
        torch.tensor(target_weight_s).float()[None])

    kp_center = kp_src_tensor.mean(dim=0)
    kp_scale = kp_src_tensor.max(dim=0)[0] - kp_src_tensor.min(dim=0)[0]
    data = {
        'img_s': [support_img],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [state['skeleton']],
                       'query_skeleton': state['skeleton'],
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [kp_center],
                       'query_center': kp_center,
                       'sample_scale': [kp_scale],
                       'query_scale': kp_scale,
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }

    # Load model
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model.eval()

    with torch.no_grad():
        outputs = model(**data)

    # visualize results
    vis_s_weight = target_weight_s[0]
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    support_kp = kp_src_3d
    out = plot_results(vis_s_image,
                       vis_q_image,
                       support_kp,
                       vis_s_weight,
                       None,
                       vis_s_weight,
                       outputs['skeleton'],
                       torch.tensor(outputs['points']).squeeze(),
                       original_skeleton=state['skeleton'],
                       img_alpha=1.0,
                       )
    return out, state


def update_examples(support_img, posed_support, query_img, state, r=0.015, width=0.02):
    """Load an example and draw its keypoints and limbs on the support images.

    Relies on ``set_query`` and ``COLORS`` being available at module scope.

    Args:
        support_img: Example support image (PIL).
        posed_support: Replaced by the value returned from ``set_query``.
        query_img: Passed through unchanged.
        state: UI state; ``'color_idx'`` and ``'original_support_image'`` are
            overwritten, ``'kp_src'`` and ``'skeleton'`` are read.
        r: Keypoint radius as a fraction of the image width.
        width: Limb line width as a fraction of the image width.

    Returns:
        ``(support_img, posed_support, query_img, state)``.
    """
    state['color_idx'] = 0
    state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
    support_img, posed_support, _ = set_query(support_img, state, example=True)
    img_w, _ = support_img.size
    draw_pose = ImageDraw.Draw(support_img)
    draw_limb = ImageDraw.Draw(posed_support)
    r = int(r * img_w)
    width = int(width * img_w)

    # kp_src entries are indexed [0] then [1] but drawn as (entry[1], entry[0]).
    for pixel in state['kp_src']:
        bbox = [(pixel[1] - r, pixel[0] - r), (pixel[1] + r, pixel[0] + r)]
        draw_pose.ellipse(bbox, fill=(255, 0, 0, 255))
        draw_limb.ellipse(bbox, fill=(255, 0, 0, 255))

    for limb in state['skeleton']:
        # Tuples, because ImageDraw documents coordinate pairs as tuples.
        point_a = tuple(state['kp_src'][limb[0]][::-1])
        point_b = tuple(state['kp_src'][limb[1]][::-1])
        if state['color_idx'] < len(COLORS):
            c = COLORS[state['color_idx']]
            state['color_idx'] += 1
        else:
            c = random.choices(range(256), k=3)
        draw_limb.line([point_a, point_b], fill=tuple(c), width=width)
    return support_img, posed_support, query_img, state


def get_select_coords(global_state, evt: gr.SelectData):
    """Record a clicked keypoint and draw it on the keypoint image.

    Only click (point) selection is supported.

    Returns:
        ``(global_state, image_with_new_point)``.
    """
    xy = evt.index
    global_state["points"].append(xy)
    image_raw = global_state['images']['image_kp']
    image_draw = update_image_draw(image_raw, xy, global_state)
    global_state['images']['image_kp'] = image_draw
    return global_state, image_draw


def get_closest_point_idx(pts_list, xy):
    """Return the index of the point in ``pts_list`` nearest to ``xy``.

    Ties resolve to the lowest index.

    Raises:
        ValueError: If ``pts_list`` is empty.
    """
    x, y = xy
    return min(
        range(len(pts_list)),
        key=lambda idx: (pts_list[idx][0] - x) ** 2 + (pts_list[idx][1] - y) ** 2,
    )


def reset_skeleton(global_state):
    """Discard all limbs and return the keypoint-only image."""
    image = global_state["images"]["image_kp"]
    global_state["images"]["image_skel"] = image
    global_state["skeleton"] = []
    global_state["curr_type_point"] = "start"
    global_state["prev_point"] = None
    # select_skeleton tracks the pending endpoint under this key.
    global_state["prev_point_idx"] = None
    return image


def reset_kp(global_state):
    """Discard all keypoints and limbs.

    Returns:
        The original image twice (keypoint view, skeleton view).
    """
    image = global_state["images"]["image_orig"]
    global_state["images"]["image_kp"] = image
    global_state["images"]["image_skel"] = image
    global_state["skeleton"] = []
    global_state["points"] = []
    global_state["curr_type_point"] = "start"
    global_state["prev_point"] = None
    # select_skeleton tracks the pending endpoint under this key.
    global_state["prev_point_idx"] = None
    return image, image


def select_skeleton(global_state, evt: gr.SelectData, ):
    """Handle a click while defining limbs.

    Clicks alternate between choosing a limb's start and end keypoint; each
    click snaps to the nearest recorded keypoint. On the second click the limb
    is stored in ``global_state['skeleton']`` and drawn.

    Returns:
        ``(global_state, skeleton_image)``.
    """
    xy = evt.index
    pts_list = global_state["points"]
    image_raw = global_state['images']['image_skel']
    if not pts_list:
        # Nothing to snap to yet; ignore the click instead of crashing.
        return global_state, image_raw

    closest_point_idx = get_closest_point_idx(pts_list, xy)
    if global_state["curr_type_point"] == "end":
        prev_point_idx = global_state["prev_point_idx"]
        # Draw between the two snapped keypoints so the picture matches the
        # edge that is actually recorded.
        points = [pts_list[prev_point_idx], pts_list[closest_point_idx]]
        image_draw = draw_limbs_on_image(image_raw, points)
        global_state['images']['image_skel'] = image_draw
        global_state['skeleton'].append([prev_point_idx, closest_point_idx])
        global_state["curr_type_point"] = "start"
        global_state["prev_point_idx"] = None
    else:
        global_state["prev_point_idx"] = closest_point_idx
        global_state["curr_type_point"] = "end"
    return global_state, global_state['images']['image_skel']


def reverse_point_pairs(points):
    """Return a new list with the two coordinates of every point swapped."""
    return [[p[1], p[0]] for p in points]


def update_image_draw(image, points, global_state):
    """Draw ``points`` on ``image``; fewer than two stored points uses alpha 0.5."""
    alpha = 0.5 if len(global_state["points"]) < 2 else 1.0
    return draw_points_on_image(image, points, alpha=alpha)


def print_memory_usage():
    """Print system memory usage and, when CUDA is available, GPU memory usage."""
    # Print system memory usage
    print(f"System memory usage: {psutil.virtual_memory().percent}%")

    # Print GPU memory usage
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
        print(
            f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
        device_properties = torch.cuda.get_device_properties(device)
        available_memory = device_properties.total_memory - \
            torch.cuda.max_memory_allocated()
        print(f"Available GPU memory: {available_memory / 1e9} GB")
    else:
        print("No GPU available")


def draw_limbs_on_image(image, points,):
    """Return an RGB copy of ``image`` with a randomly coloured limb drawn.

    Args:
        image: PIL image.
        points: ``(start, end)`` pair; nothing is drawn if either is ``None``.
    """
    color = tuple(random.choices(range(256), k=3))
    overlay_rgba = Image.new("RGBA", image.size, 0)
    overlay_draw = ImageDraw.Draw(overlay_rgba)
    p_start, p_target = points
    if p_start is not None and p_target is not None:
        p_draw = int(p_start[0]), int(p_start[1])
        t_draw = int(p_target[0]), int(p_target[1])
        overlay_draw.line(
            (p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
            fill=color,
            width=10,
        )
    return Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB")


def draw_points_on_image(image, points, radius_scale=0.01, alpha=1.):
    """Return an RGB copy of ``image`` with one red keypoint drawn.

    Args:
        image: PIL image.
        points: A single ``(x, y)`` location, or ``None`` to draw nothing.
        radius_scale: Marker radius as a fraction of the image width.
        alpha: Values below 1 brighten the image by a factor of 1.1 first.
    """
    if alpha < 1:
        image = ImageEnhance.Brightness(image).enhance(1.1)
    overlay_rgba = Image.new("RGBA", image.size, 0)
    overlay_draw = ImageDraw.Draw(overlay_rgba)
    p_color = (255, 0, 0)
    rad_draw = int(image.size[0] * radius_scale)
    if points is not None:
        cx, cy = int(points[0]), int(points[1])
        overlay_draw.ellipse(
            (cx - rad_draw, cy - rad_draw, cx + rad_draw, cy + rad_draw),
            fill=p_color,
        )
    return Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB")