Spaces:
Build error
Build error
import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.utils.visualization.detection import draw_bbox
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
# Initialize the pose-estimation model once at import time; every call to
# process_and_predict reuses it.
# The checkpoint location can be overridden with the YOLO_NAS_POSE_CHECKPOINT
# environment variable. The default is the original path, which looks like a
# Google Colab location and may not exist on other hosts (e.g. a Space).
YOLO_NAS_POSE_CHECKPOINT = os.environ.get(
    "YOLO_NAS_POSE_CHECKPOINT", "/content/yolo_nas_pose_l_coco_pose.pth"
)
yolo_nas_pose = models.get(
    Models.YOLO_NAS_POSE_L,
    # Presumably the 17 COCO keypoints (checkpoint name says "coco_pose") — confirm.
    num_classes=17,
    checkpoint_path=YOLO_NAS_POSE_CHECKPOINT,
)
def _load_image_from_url(url, timeout=10):
    """Download ``url`` and return it as a 3-channel RGB numpy array.

    Raises:
        gr.Error: if the request fails or the payload is not a readable image
            (Gradio shows this message to the user instead of a stack trace).
    """
    try:
        # A timeout keeps a dead or slow host from blocking the worker forever.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Could not download image from URL: {err}") from err
    try:
        with Image.open(BytesIO(response.content)) as pil_image:
            # Normalize RGBA / palette / grayscale images to RGB so the model
            # always receives 3-channel input.
            return np.array(pil_image.convert("RGB"))
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise gr.Error(f"URL did not return a readable image: {err}") from err


def _first_image_prediction(result):
    """Return the single-image prediction contained in a ``predict`` result.

    The original code read ``result._images_prediction_lst[0]``; that private
    attribute is used when present. Otherwise ``result`` is assumed to already
    be the single-image prediction — NOTE(review): presumably what newer
    super-gradients versions return for one image; confirm against the
    installed version.
    """
    images_predictions = getattr(result, "_images_prediction_lst", None)
    if images_predictions is None:
        return result
    return images_predictions[0]


def _draw_poses(canvas, pose_data):
    """Draw the poses and boxes in ``pose_data`` on ``canvas``; return the rendering."""
    return PoseVisualization.draw_poses(
        image=canvas,
        poses=pose_data.poses,
        boxes=pose_data.bboxes_xyxy,
        scores=pose_data.scores,
        is_crowd=None,
        edge_links=pose_data.edge_links,
        edge_colors=pose_data.edge_colors,
        keypoint_colors=pose_data.keypoint_colors,
        joint_thickness=2,
        box_thickness=2,
        keypoint_radius=5,
    )


def process_and_predict(url=None, image=None, confidence=0.5, iou=0.5):
    """Run YOLO-NAS-Pose on an image from a URL or an upload and visualize it.

    Args:
        url: Optional image URL. When non-blank it takes precedence over ``image``.
        image: Optional uploaded image as a numpy array (from ``gr.Image(type='numpy')``).
        confidence: Threshold forwarded to ``yolo_nas_pose.predict`` as ``conf``.
        iou: Threshold forwarded to ``yolo_nas_pose.predict`` as ``iou``.

    Returns:
        A pair ``(overlay, skeleton)``. ``overlay`` is the input image with the
        detected poses drawn on it. ``skeleton`` has the same poses drawn on an
        all-zero (black) canvas of the same shape. ``(None, None)`` is returned
        when no input is provided: one value per Gradio output component, since
        a single ``None`` does not match the two declared outputs.

    Raises:
        gr.Error: if the URL cannot be downloaded or decoded as an image.
    """
    if url and url.strip():
        image = _load_image_from_url(url.strip())
    elif image is None:
        return None, None

    result = yolo_nas_pose.predict(image, conf=confidence, iou=iou)
    image_prediction = _first_image_prediction(result)
    pose_data = image_prediction.prediction

    overlay = _draw_poses(image_prediction.image, pose_data)
    skeleton = _draw_poses(np.zeros_like(image_prediction.image), pose_data)
    return overlay, skeleton
# Define the Gradio interface. The four inputs map positionally onto
# process_and_predict(url, image, confidence, iou). The two outputs receive the
# pose overlay and the skeleton-only image it returns.
iface = gr.Interface(
    fn=process_and_predict,
    inputs=[
        gr.Textbox(placeholder="Enter Image URL", label="Image URL"),
        gr.Image(label="Upload Image", type="numpy"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="IoU Threshold"),
    ],
    outputs=[
        gr.Image(label="Estimated Pose"),
        gr.Image(label="Skeleton Only"),
    ],
    title="YOLO-NAS-Pose Demo",
    description=(
        "Upload an image, enter an image URL, or use your webcam to run "
        "inference with a pretrained YOLO-NAS-Pose L model. "
        "If both a URL and an upload are given, the URL is used."
    ),
    live=False,
    # Gradio documents this option as "never" / "auto" / "manual", not a bool.
    allow_flagging="never",
)

# Launch the interface
iface.launch()