helmet-detection / main.py
drkareemkamal's picture
Upload 11 files
7861a07 verified
raw
history blame
4.3 kB
import numpy as np
import cv2
import tritonclient.grpc as grpcclient
import sys
import argparse
# Display labels for the detector's classes; draw_bounding_box indexes this
# list with the class id of each detection.
class_names = [
    'Helmet',
    'No_helmet',
    'person',
]
def get_triton_client(url: str = 'localhost:8001'):
    """Create a Triton gRPC inference client for ``url``.

    Args:
        url: ``host:port`` of the Triton server's gRPC endpoint.

    Returns:
        The ``grpcclient.InferenceServerClient`` that was constructed.

    Raises:
        SystemExit: with exit status 1 if building the client raises. The
            reason is printed to stderr first, so a failed run is not
            reported to the shell as a success.
    """
    try:
        keepalive_options = grpcclient.KeepAliveOptions(
            # 2**31 - 1 ms is the largest signed 32-bit value (~24.8 days).
            # NOTE(review): presumably chosen so keepalive pings are
            # effectively never sent -- confirm against the client docs.
            keepalive_time_ms=2**31 - 1,
            keepalive_timeout_ms=20000,
            keepalive_permit_without_calls=False,
            http2_max_pings_without_data=2,
        )
        triton_client = grpcclient.InferenceServerClient(
            url=url,
            verbose=False,
            keepalive_options=keepalive_options,
        )
    except Exception as e:
        # Script-level boundary: report the failure and stop with a
        # non-zero status instead of propagating a traceback.
        print("channel creation failed: " + str(e), file=sys.stderr)
        sys.exit(1)
    return triton_client
def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
    """Draw one labelled detection rectangle onto ``img`` in place.

    Args:
        img: Image array passed to ``cv2.rectangle`` and ``cv2.putText``.
        class_id: Index into the module-level ``class_names`` list. It is
            converted with ``int()`` so a float-typed id also works.
        confidence: Detection score, rendered in the label with two decimals.
        x: Left edge of the box, in pixels.
        y: Top edge of the box, in pixels.
        x_plus_w: Right edge of the box, in pixels.
        y_plus_h: Bottom edge of the box, in pixels.

    Raises:
        IndexError: if ``class_id`` is outside ``class_names``.
    """
    # A plain list cannot be indexed with a float, so coerce the id to int
    # before looking up the label.
    label = f'{class_names[int(class_id)]}: {confidence:.2f}'
    # All three channels are spelled out rather than relying on OpenCV
    # zero-filling a short tuple. For the BGR images cv2.imread produces,
    # (255, 0, 0) is blue.
    color = (255, 0, 0)
    cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
    # The label is anchored slightly up and to the left of the box corner.
    cv2.putText(img, label, (x - 10, y - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
def read_image(image_path: str, expected_image_shape) -> tuple:
    """Load an image and build the channel-first network input from it.

    The image is copied into the top-left corner of a black square whose
    side is the longer image edge, resized to the model's input size,
    scaled to ``[0, 1]`` float32 and reshaped to ``(1, 3, height, width)``.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.
        expected_image_shape: ``(height, width)`` of the model input, i.e.
            the last two dimensions of a channel-first input shape.

    Returns:
        A 3-tuple ``(original_image, input_image, scale)``:
        the unmodified image as loaded, the float32 input tensor, and the
        ratio of the padded square's side to the smaller expected
        dimension. ``scale`` maps model-space coordinates back to the
        original image exactly only when the model input is square.

    Raises:
        OSError: if ``cv2.imread`` cannot load ``image_path``.
    """
    # The shape is (height, width). cv2.resize takes (width, height), while
    # the resulting array is (height, width, channels).
    expected_height = expected_image_shape[0]
    expected_width = expected_image_shape[1]
    expected_length = min((expected_height, expected_width))

    original_image = cv2.imread(image_path)
    if original_image is None:
        # cv2.imread signals a missing or undecodable file by returning None.
        raise OSError(f'could not read image: {image_path!r}')

    height, width = original_image.shape[:2]
    length = max((height, width))

    # Pad to a square so the resize below keeps the aspect ratio for a
    # square model input.
    image = np.zeros((length, length, 3), np.uint8)
    image[0:height, 0:width] = original_image
    scale = length / expected_length

    input_image = cv2.resize(image, (expected_width, expected_height))
    input_image = (input_image / 255.0).astype(np.float32)
    # Channel first: (H, W, C) -> (C, H, W)
    input_image = input_image.transpose(2, 0, 1)
    # Add the leading batch dimension: (C, H, W) -> (1, C, H, W)
    input_image = np.expand_dims(input_image, axis=0)
    return original_image, input_image, scale
def run_inference(model_name: str, input_image: np.ndarray,
                  triton_client: grpcclient.InferenceServerClient):
    """Send one image tensor to the model and fetch the detection outputs.

    Args:
        model_name: Name of the model to call on the Triton server.
        input_image: Array sent as the FP32 input tensor named ``images``.
        triton_client: Client used to issue the inference request.

    Returns:
        A 4-tuple of numpy arrays, in this order: ``num_detections``,
        ``detection_boxes``, ``detection_scores``, ``detection_classes``.
    """
    output_names = (
        'num_detections',
        'detection_boxes',
        'detection_scores',
        'detection_classes',
    )

    # Single input tensor carrying the preprocessed image.
    image_input = grpcclient.InferInput('images', input_image.shape, "FP32")
    image_input.set_data_from_numpy(input_image)

    requested = [grpcclient.InferRequestedOutput(name) for name in output_names]

    response = triton_client.infer(model_name=model_name,
                                   inputs=[image_input],
                                   outputs=requested)

    # Unpack the outputs in the same order they were requested.
    return tuple(response.as_numpy(name) for name in output_names)
def main(image_path, model_name, url, output_path='output.jpg'):
    """Run detection on one image and save an annotated copy.

    Args:
        image_path: Path of the image to analyse.
        model_name: Name of the model on the Triton server.
        url: ``host:port`` of the Triton gRPC endpoint.
        output_path: Where the annotated image is written.

    Raises:
        OSError: if the annotated image cannot be written to ``output_path``.
    """
    triton_client = get_triton_client(url)
    try:
        # The last two dimensions of the first model input give the image
        # size that read_image resizes to.
        expected_image_shape = triton_client.get_model_metadata(model_name).inputs[0].shape[-2:]
        original_image, input_image, scale = read_image(image_path, expected_image_shape)
        num_detections, detection_boxes, detection_scores, detection_classes = run_inference(
            model_name, input_image, triton_client)
    finally:
        # Release the gRPC channel whether or not inference succeeded.
        triton_client.close()

    print(detection_classes)
    print(detection_boxes)

    for index in range(num_detections[0]):
        # Each box is used as (x, y, width, height); the corners are
        # multiplied by ``scale`` to map them onto the original image.
        box = detection_boxes[index]
        draw_bounding_box(original_image,
                          detection_classes[index],
                          detection_scores[index],
                          round(box[0] * scale),
                          round(box[1] * scale),
                          round((box[0] + box[2]) * scale),
                          round((box[1] + box[3]) * scale))

    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(output_path, original_image):
        raise OSError(f'could not write annotated image to {output_path!r}')
if __name__ == '__main__':
    # Command-line entry point: every option is a string flag with a default,
    # so the script also runs with no arguments.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ('--image_path', './assets/Image (47).png'),
        ('--model_name', 'yolov8_ensemble'),
        ('--url', '172.17.0.1:8001'),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    options = cli.parse_args()
    main(options.image_path, options.model_name, options.url)