import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
import supervision as sv
from transformers import (
    RTDetrForObjectDetection,
    RTDetrImageProcessor,
    VitPoseConfig,
    VitPoseForPoseEstimation,
    VitPoseImageProcessor,
)
# COCO-style 17-keypoint labels, indexed by position in the VitPose output.
KEYPOINT_LABEL_MAP = {
    0: "Nose",
    1: "L_Eye",
    2: "R_Eye",
    3: "L_Ear",
    4: "R_Ear",
    5: "L_Shoulder",
    6: "R_Shoulder",
    7: "L_Elbow",
    8: "R_Elbow",
    9: "L_Wrist",
    10: "R_Wrist",
    11: "L_Hip",
    12: "R_Hip",
    13: "L_Knee",
    14: "R_Knee",
    15: "L_Ankle",
    16: "R_Ankle",
}
class KeypointDetector:
    def __init__(self):
        self.person_detector = None
        self.person_processor = None
        self.pose_model = None
        self.pose_processor = None
        self.load_models()

    def load_models(self):
        """Load all required models."""
        # Object detection model (person detector)
        self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")

        # Pose estimation model
        self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
        self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")
    def pascal_voc_to_coco(self, bboxes: np.ndarray) -> np.ndarray:
        """Convert boxes from Pascal VOC (x1, y1, x2, y2) to COCO (x, y, w, h) format."""
        bboxes = bboxes.copy()  # Create a copy to avoid modifying the input
        bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
        bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
        return bboxes

    def coco_to_xyxy(self, bboxes: np.ndarray) -> np.ndarray:
        """Convert boxes from COCO (x, y, w, h) to xyxy (x1, y1, x2, y2) format."""
        bboxes = bboxes.copy()
        bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
        bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
        return bboxes
    def detect_persons(self, image: Image.Image):
        """Detect persons in the image."""
        inputs = self.person_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.person_detector(**inputs)
        results = self.person_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([(image.height, image.width)]),
            threshold=0.3
        )
        dets = sv.Detections.from_transformers(results[0]).with_nms(0.5)

        # Keep boxes and scores for the person class (label 0 in this checkpoint's label map)
        boxes = dets.xyxy[dets.class_id == 0]
        scores = dets.confidence[dets.class_id == 0]
        return boxes, scores
    def detect_keypoints(self, image: Image.Image):
        """Detect keypoints in the image."""
        # Detect persons first
        boxes, scores = self.detect_persons(image)
        boxes_coco = [self.pascal_voc_to_coco(boxes)]

        # Detect pose keypoints
        pixel_values = self.pose_processor(image, boxes=boxes_coco, return_tensors="pt").pixel_values
        with torch.no_grad():
            outputs = self.pose_model(pixel_values)
        pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=boxes_coco)[0]
        return pose_results, boxes, scores
    def visualize_detections(self, image: Image.Image, pose_results, boxes, scores):
        """Visualize both bounding boxes and keypoints on the image."""
        # Convert image to numpy array if needed
        image_array = np.array(image)

        # Setup detections for bounding boxes
        detections = sv.Detections(
            xyxy=boxes,
            confidence=scores,
            class_id=np.array([0] * len(scores))
        )

        # Create box annotator
        box_annotator = sv.BoxAnnotator(
            color=sv.ColorPalette.DEFAULT,
            thickness=2
        )

        # Create edge annotator for keypoints
        edge_annotator = sv.EdgeAnnotator(
            color=sv.Color.GREEN,
            thickness=3
        )

        # Convert keypoints to supervision format
        key_points = sv.KeyPoints(
            xy=torch.cat([pose_result["keypoints"].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
        )

        # Annotate image with boxes first
        annotated_frame = box_annotator.annotate(
            scene=image_array.copy(),
            detections=detections
        )

        # Then add keypoints
        annotated_frame = edge_annotator.annotate(
            scene=annotated_frame,
            key_points=key_points
        )
        return Image.fromarray(annotated_frame)
    def process_image(self, input_image):
        """Process an image and return the visualization plus a text summary."""
        if input_image is None:
            return None, ""

        # Convert to PIL Image if necessary
        if isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image)
        else:
            image = input_image

        # Detect keypoints and boxes
        pose_results, boxes, scores = self.detect_keypoints(image)

        # Visualize results
        result_image = self.visualize_detections(image, pose_results, boxes, scores)

        # Create detection information text
        info_text = []
        for i, (box, score) in enumerate(zip(boxes, scores)):
            # Box information
            info_text.append(f"\nPerson {i + 1} (confidence: {score:.2f})")
            info_text.append(f"Bounding Box: x1={box[0]:.1f}, y1={box[1]:.1f}, x2={box[2]:.1f}, y2={box[3]:.1f}")

            # Add keypoint information for this person; the post-processed results keep
            # coordinates and confidences in separate "keypoints" and "scores" arrays
            pose_result = pose_results[i]
            for j, (keypoint, confidence) in enumerate(zip(pose_result["keypoints"], pose_result["scores"])):
                x, y = keypoint
                info_text.append(f"Keypoint {KEYPOINT_LABEL_MAP[j]}: x={x:.1f}, y={y:.1f}, confidence={confidence:.2f}")

        return result_image, "\n".join(info_text)
def create_gradio_interface():
    """Create the Gradio interface."""
    detector = KeypointDetector()

    with gr.Blocks() as interface:
        gr.Markdown("# Human Detection and Keypoint Estimation using VitPose")
        gr.Markdown("Upload an image to detect people and their keypoints. The model will:")
        gr.Markdown("1. Detect people in the image (shown as bounding boxes)")
        gr.Markdown("2. Identify keypoints for each detected person (shown as connected green lines)")
        gr.Markdown("Huge shoutout to @NielsRogge and @SangbumChoi for this work!")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image")
                process_button = gr.Button("Detect People & Keypoints")
            with gr.Column():
                output_image = gr.Image(label="Detection Results")
                detection_info = gr.Textbox(
                    label="Detection Information",
                    lines=10,
                    placeholder="Detection details will appear here..."
                )

        process_button.click(
            fn=detector.process_image,
            inputs=input_image,
            outputs=[output_image, detection_info]
        )

        gr.Examples(
            examples=[
                "http://images.cocodataset.org/val2017/000000000139.jpg"
            ],
            inputs=input_image
        )

    return interface
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()