harpreetsahota's picture
Update app.py
003af74
raw
history blame contribute delete
No virus
3.47 kB
from io import BytesIO
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.utils.visualization.detection import draw_bbox
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
# Initialize your pose estimation model
yolo_nas_pose = models.get("yolo_nas_pose_l",
num_classes=17,
checkpoint_path="./yolo_nas_pose_l_coco_pose.pth")
def process_and_predict(url=None,
image=None,
confidence=0.5,
iou=0.5):
# If a URL is provided, use it directly for prediction
if url is not None and url.strip() != "":
response = requests.get(url)
image = Image.open(BytesIO(response.content))
image = np.array(image)
result = yolo_nas_pose.predict(image, conf=confidence,iou=iou)
# If a file is uploaded, read it, convert it to a numpy array and use it for prediction
elif image is not None:
result = yolo_nas_pose.predict(image, conf=confidence,iou=iou)
else:
return None # If no input is provided, return None
# Extract prediction data
image_prediction = result._images_prediction_lst[0]
pose_data = image_prediction.prediction
# Visualize the prediction
output_image = PoseVisualization.draw_poses(
image=image_prediction.image,
poses=pose_data.poses,
boxes=pose_data.bboxes_xyxy,
scores=pose_data.scores,
is_crowd=None,
edge_links=pose_data.edge_links,
edge_colors=pose_data.edge_colors,
keypoint_colors=pose_data.keypoint_colors,
joint_thickness=2,
box_thickness=2,
keypoint_radius=5
)
blank_image = np.zeros_like(image_prediction.image)
skeleton_image = PoseVisualization.draw_poses(
image=blank_image,
poses=pose_data.poses,
boxes=pose_data.bboxes_xyxy,
scores=pose_data.scores,
is_crowd=None,
edge_links=pose_data.edge_links,
edge_colors=pose_data.edge_colors,
keypoint_colors=pose_data.keypoint_colors,
joint_thickness=2,
box_thickness=2,
keypoint_radius=5
)
return output_image, skeleton_image
# Define the Gradio interface
iface = gr.Interface(
fn=process_and_predict,
inputs=[
gr.Textbox(placeholder="Enter Image URL", label="Image URL"),
gr.Image(label="Upload Image", type='numpy'),
gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold"),
gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="IoU Threshold")
],
outputs=[
gr.components.Image(label="Estimated Pose"),
gr.components.Image(label="Skeleton Only")
],
title="YOLO-NAS-Pose Demo",
description="Upload an image, enter an image URL, or use your webcam to use a pretrained YOLO-NAS-Pose L for inference. Get more hands-on with the [starter notebook for inference](https://bit.ly/yn-pose-inference), and learn how to fine-tune your own model with the [fine-tuning notebook](https://bit.ly/yn-pose-fine-tuning). The official home of YOLO-NAS-Pose is SuperGradients, [gives us a ⭐️ on GitHub!](https://github.com/Deci-AI/super-gradients)",
live=True,
allow_flagging=False,
)
# Launch the interface
iface.launch()