File size: 8,127 Bytes
7206ed3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
# Built from https://huggingface.co/spaces/hlydecker/MegaDetector_v5
# Built from https://huggingface.co/spaces/sofmi/MegaDetector_DLClive/blob/main/app.py
# Built from https://huggingface.co/spaces/Neslihan/megadetector_dlcmodels/blob/main/app.py
import os
import yaml
import numpy as np
from matplotlib import cm
import gradio as gr
from PIL import Image, ImageColor, ImageFont, ImageDraw
# check git lfs pull!!
from DLC_models.download_utils import DownloadModel
from dlclive import DLCLive, Processor
from viz_utils import save_results_as_json, draw_keypoints_on_image, draw_bbox_w_text, save_results_only_dlc
from detection_utils import predict_md, crop_animal_detections, predict_dlc
from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples
# import pdb
#########################################
# Input params - Global vars

# MegaDetector weight files, keyed by model name.
MD_models_dict = {
    'md_v5a': "MD_models/md_v5a.0.0.pt",
    'md_v5b': "MD_models/md_v5b.0.0.pt",
}

# DLC model target directories, keyed by model name. The pipeline reuses a
# directory that already exists and is non-empty; otherwise it downloads the
# model into it.
# Disabled models (uncomment inside the dict to re-enable):
#   'full_cat': "DLC_models/DLC_Cat/",
#   'full_dog': "DLC_models/DLC_Dog/",
DLC_models_dict = {
    'full_human': "DLC_models/DLC_human_dancing/",
    'full_macaque': 'DLC_models/DLC_monkey/',
    'primate_face': "DLC_models/DLC_FacialLandmarks/",
}

# Font files (disabled):
# FONTS = {'amiko': "fonts/Amiko-Regular.ttf",
#          'nature': "fonts/LoveNature.otf",
#          'painter': "fonts/PainterDecorator.otf",
#          'animals': "fonts/UncialAnimals.ttf",
#          'zen': "fonts/ZEN.TTF"}
#####################################################
def _resolve_dlc_model_dir(dlc_model_name):
    """Return the local directory holding the DLC model, downloading it if needed.

    A target directory that already exists and is non-empty is reused as-is
    (a previous execution, likely within the same day, downloaded it);
    otherwise the path returned by ``DownloadModel`` is used.
    """
    target_dir = DLC_models_dict[dlc_model_name]
    # TODO: can we ask the user whether to reload dlc model if a directory is found?
    if os.path.isdir(target_dir) and os.listdir(target_dir):
        return target_dir
    return DownloadModel(dlc_model_name, target_dir)


def _load_dlc_label_map(model_dir):
    """Map keypoint ids to joint names using ``pose_cfg.yaml`` in ``model_dir``."""
    pose_cfg_path = os.path.join(model_dir, 'pose_cfg.yaml')
    with open(pose_cfg_path, "r", encoding="utf-8") as stream:
        pose_cfg_dict = yaml.safe_load(stream)
    # pose_cfg_dict['all_joints'] is a list of one-element lists, so take the
    # single id from each and pair it with the matching joint name.
    return {joint_ids[0]: joint_name
            for joint_ids, joint_name in zip(pose_cfg_dict['all_joints'],
                                             pose_cfg_dict['all_joints_names'])}


def predict_pipeline(img_input,
                     mega_model_input,
                     dlc_model_input_str,
                     flag_dlc_only,
                     flag_show_str_labels,
                     bbox_likelihood_th,
                     kpts_likelihood_th,
                     font_style,
                     font_size,
                     keypt_color,
                     marker_size,
                     ):
    """Run MegaDetector + DeepLabCut on an image and return the annotated result.

    Args:
        img_input: input image; presumably a PIL image (``.resize`` is called
            with a ``(width, height)`` tuple) — confirm against the UI inputs.
        mega_model_input: key into ``MD_models_dict`` selecting the MegaDetector
            weights (unused when ``flag_dlc_only`` is set).
        dlc_model_input_str: key into ``DLC_models_dict`` selecting the DLC model.
        flag_dlc_only: if true, skip MegaDetector and run DLC on the full image.
        flag_show_str_labels: forwarded to ``draw_keypoints_on_image``.
        bbox_likelihood_th: confidence threshold passed to
            ``crop_animal_detections``; a bbox is drawn only when its
            confidence is strictly greater than this value.
        kpts_likelihood_th: keypoint threshold forwarded to ``predict_dlc``.
        font_style, font_size, keypt_color, marker_size: drawing options.

    Returns:
        ``(annotated_image, results_file)``. In DLC-only mode the input image
        is drawn on in place and returned, with the file from
        ``save_results_only_dlc``; otherwise a resized copy of the input with
        the crops pasted back is returned, with the file from
        ``save_results_as_json``.
    """
    if not flag_dlc_only:
        ############################################################
        # Run MegaDetector
        md_results = predict_md(img_input,
                                MD_models_dict[mega_model_input],
                                size=640)
        # Obtain animal crops for bboxes with confidence above th
        list_crops = crop_animal_detections(img_input,
                                            md_results,
                                            bbox_likelihood_th)

    ############################################################
    # Get DLC model and label map. The label map is read from the same
    # directory that DLC inference loads the model from.
    path_to_DLCmodel = _resolve_dlc_model_dir(dlc_model_input_str)
    map_label_id_to_str = _load_dlc_label_map(path_to_DLCmodel)

    ##############################################################
    # Run DLC and visualise results
    dlc_proc = Processor()
    # Keypoint-drawing options shared by both modes. Keypoints are in pixel
    # coordinates; normalized coordinates would require md_results.xyxyn.
    draw_kwargs = dict(use_normalized_coordinates=False,
                       font_style=font_style,
                       font_size=font_size,
                       keypt_color=keypt_color,
                       marker_size=marker_size)

    # if required: ignore MD crops and run DLC on full image [mostly for testing]
    if flag_dlc_only:
        list_kpts_per_crop = predict_dlc([np.asarray(img_input)],
                                         kpts_likelihood_th,
                                         path_to_DLCmodel,
                                         dlc_proc)
        # Draw keypoints directly on the input image
        draw_keypoints_on_image(img_input,
                                list_kpts_per_crop[0],  # a numpy array with shape [num_keypoints, 2].
                                map_label_id_to_str,
                                flag_show_str_labels,
                                **draw_kwargs)
        download_file = save_results_only_dlc(list_kpts_per_crop[0],
                                              map_label_id_to_str,
                                              dlc_model_input_str)
        return img_input, download_file

    # Compute kpts for each crop
    list_kpts_per_crop = predict_dlc(list_crops,
                                     kpts_likelihood_th,
                                     path_to_DLCmodel,
                                     dlc_proc)
    # resize input image to match megadetector output
    img_background = img_input.resize((md_results.ims[0].shape[1],
                                       md_results.ims[0].shape[0]))
    # Convert the detections once instead of on every iteration; each row is
    # indexed as [x_min, y_min, ..., confidence at index 4, ...].
    detections = md_results.xyxy[0].tolist()

    # NOTE(review): crop ic is matched with detection row ic. This assumes
    # crop_animal_detections keeps rows in order and only drops trailing ones;
    # if it skips an earlier row (e.g. a non-animal class), bboxes and crops
    # misalign — confirm.
    for ic, (np_crop, kpts_crop) in enumerate(zip(list_crops, list_kpts_per_crop)):
        detection = detections[ic]
        img_crop = Image.fromarray(np_crop)
        # Draw keypts on crop
        draw_keypoints_on_image(img_crop,
                                kpts_crop,  # a numpy array with shape [num_keypoints, 2].
                                map_label_id_to_str,
                                flag_show_str_labels,
                                **draw_kwargs)
        # Paste crop back with its top-left corner at the bbox's (x_min, y_min)
        img_background.paste(img_crop,
                             box=tuple(int(coord) for coord in detection[:2]))
        # Plot bbox only when its confidence is strictly above the threshold
        if bbox_likelihood_th < detection[4]:
            draw_bbox_w_text(img_background,
                             detection,
                             font_style=font_style,
                             font_size=font_size)  # TODO: add selectable color for bbox?

    # Save detection results as json
    download_file = save_results_as_json(md_results,
                                         list_kpts_per_crop,
                                         map_label_id_to_str,
                                         bbox_likelihood_th,
                                         dlc_model_input_str,
                                         mega_model_input)
    return img_background, download_file
#########################################################
# Build the Gradio user interface and launch the app.

# The available model names (dict keys, in insertion order) are handed to the
# input builder.
inputs = gradio_inputs_for_MD_DLC(list(MD_models_dict),
                                  list(DLC_models_dict))
outputs = gradio_outputs_for_MD_DLC()
gr_title, gr_description, examples = gradio_description_and_examples()

# launch
demo = gr.Interface(
    predict_pipeline,
    inputs=inputs,
    outputs=outputs,
    title=gr_title,
    description=gr_description,
    examples=examples,
    theme="huggingface",
)
# NOTE(review): `enable_queue` in launch() and string themes like "huggingface"
# are deprecated in newer Gradio releases (enable_queue was removed in 4.x in
# favour of demo.queue().launch(...)). Kept as-is because the pinned Gradio
# version is not visible here — confirm before upgrading.
demo.launch(enable_queue=True, share=True)
|