import argparse
import sys

import cv2
import numpy as np
import tritonclient.grpc as grpcclient

# Class labels; the order must match the class indices produced by the model.
class_names = ['Helmet', 'No_helmet', 'person']

def get_triton_client(url: str = 'localhost:8001'):
    """Create a gRPC client for the Triton Inference Server at `url`."""
    try:
        keepalive_options = grpcclient.KeepAliveOptions(
            keepalive_time_ms=2**31 - 1,
            keepalive_timeout_ms=20000,
            keepalive_permit_without_calls=False,
            http2_max_pings_without_data=2
        )
        triton_client = grpcclient.InferenceServerClient(
            url=url,
            verbose=False,
            keepalive_options=keepalive_options)
    except Exception as e:
        print("channel creation failed: " + str(e))
        sys.exit(1)
    return triton_client

def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
    """Draw one labelled detection box on `img` in place."""
    label = f'{class_names[class_id]}: {confidence:.2f}'
    color = (255, 0, 0)  # BGR
    cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
    cv2.putText(img, label, (x - 10, y - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

def read_image(image_path: str, expected_image_shape):
    """Load an image and preprocess it into the NCHW float tensor the model expects."""
    # The model metadata reports the input shape as (height, width).
    expected_height = expected_image_shape[0]
    expected_width = expected_image_shape[1]
    expected_length = min((expected_height, expected_width))

    original_image: np.ndarray = cv2.imread(image_path)
    [height, width, _] = original_image.shape

    # Pad the image to a square so the aspect ratio survives the resize.
    length = max((height, width))
    image = np.zeros((length, length, 3), np.uint8)
    image[0:height, 0:width] = original_image

    # Factor for mapping detections back to the original image coordinates.
    scale = length / expected_length

    input_image = cv2.resize(image, (expected_width, expected_height))
    input_image = (input_image / 255.0).astype(np.float32)
    # Channel first: HWC -> CHW
    input_image = input_image.transpose(2, 0, 1)
    # Add the batch dimension: CHW -> NCHW
    input_image = np.expand_dims(input_image, axis=0)
    return original_image, input_image, scale
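
# A minimal preprocessing sanity check, assuming a 640x640 model input (the
# common YOLOv8 export size; the real shape is read from the model metadata):
#   original, tensor, scale = read_image('./assets/Image (47).png', (640, 640))
#   # tensor.shape == (1, 3, 640, 640), float32, values in [0, 1]
#   # scale maps model-input coordinates back onto `original`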

def run_inference(model_name: str, input_image: np.ndarray,
                  triton_client: grpcclient.InferenceServerClient):
    """Send one preprocessed image to Triton and return the decoded outputs."""
    inputs = []
    outputs = []
    inputs.append(grpcclient.InferInput('images', input_image.shape, "FP32"))
    # Initialize the input data
    inputs[0].set_data_from_numpy(input_image)

    # Request the tensors produced by the ensemble's postprocessing step.
    outputs.append(grpcclient.InferRequestedOutput('num_detections'))
    outputs.append(grpcclient.InferRequestedOutput('detection_boxes'))
    outputs.append(grpcclient.InferRequestedOutput('detection_scores'))
    outputs.append(grpcclient.InferRequestedOutput('detection_classes'))

    results = triton_client.infer(model_name=model_name,
                                  inputs=inputs,
                                  outputs=outputs)
    num_detections = results.as_numpy('num_detections')
    detection_boxes = results.as_numpy('detection_boxes')
    detection_scores = results.as_numpy('detection_scores')
    detection_classes = results.as_numpy('detection_classes')
    return num_detections, detection_boxes, detection_scores, detection_classes

def main(image_path, model_name, url):
    triton_client = get_triton_client(url)
    # The last two dimensions of the NCHW input are the expected (height, width).
    expected_image_shape = triton_client.get_model_metadata(model_name).inputs[0].shape[-2:]
    original_image, input_image, scale = read_image(image_path, expected_image_shape)
    num_detections, detection_boxes, detection_scores, detection_classes = run_inference(
        model_name, input_image, triton_client)
    print(detection_classes)
    print(detection_boxes)

    # Boxes are [x, y, w, h] in model-input coordinates; rescale them to the
    # original image before drawing.
    for index in range(int(num_detections[0])):
        box = detection_boxes[index]
        draw_bounding_box(original_image,
                          int(detection_classes[index]),
                          detection_scores[index],
                          round(box[0] * scale),
                          round(box[1] * scale),
                          round((box[0] + box[2]) * scale),
                          round((box[1] + box[3]) * scale))
    cv2.imwrite('output.jpg', original_image)
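
# Usage sketch; the script file name (client.py here) is an assumption, and the
# URL must point at your Triton server's gRPC port (8001 by default):
#   python client.py --image_path './assets/Image (47).png' \
#       --model_name yolov8_ensemble --url 172.17.0.1:8001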

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, default='./assets/Image (47).png')
    parser.add_argument('--model_name', type=str, default='yolov8_ensemble')
    parser.add_argument('--url', type=str, default='172.17.0.1:8001')
    args = parser.parse_args()
    main(args.image_path, args.model_name, args.url)