from sys import platform

import numpy as np
import onnxruntime as rt
from einops import rearrange

if platform in {"win32", "win64"}:
    import onnxruntime.tools.add_openvino_win_libs as utils

    utils.add_openvino_libs_to_path()


class Predictor:
    def __init__(self, model_config):
        """
        Initialize the Predictor class.

        Args:
            model_config (dict): Model configuration containing path_to_model,
                path_to_class_list, provider, threshold, and topk values.
        """
        self.config = model_config
        self.provider = self.config["provider"]
        self.threshold = self.config["threshold"]
        self.labels = {}
        self.model_init(self.config["path_to_model"])
        self.create_labels()

    def create_labels(self):
        """
        Create a dictionary of labels from the provided path_to_class_list.

        Each line of the class list is expected to hold a tab-separated
        "<index>\t<label>" pair.
        """
        with open(self.config["path_to_class_list"], "r") as f:
            labels = [line.strip() for line in f]
        labels = self.decode_preds(labels)
        idx_lbl_pairs = [x.split("\t") for x in labels]
        self.labels = {int(x[0]): x[1] for x in idx_lbl_pairs}

    def softmax(self, x):
        """Numerically stable softmax along axis 1."""
        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=1, keepdims=True)

    def predict(self, x):
        """
        Make a prediction using the provided input frames.

        Args:
            x (list): List of input frames.

        Returns:
            dict: Dictionary containing predicted labels and confidence values,
                or None if the top confidence falls below the threshold.
        """
        # Normalize pixel values to [0, 1] and reshape the clip from
        # (time, height, width, channels) to (batch, channels, time, height, width).
        clip = np.array(x).astype(np.float32) / 255.0
        clip = rearrange(clip, "t h w c -> 1 c t h w")
        prediction = self.model([self.output_name], {self.input_name: clip})[0]
        prediction = self.softmax(prediction)
        prediction = np.squeeze(prediction)
        # Indices of the top-k classes, highest confidence first.
        topk_labels = prediction.argsort()[-self.config["topk"] :][::-1]
        topk_confidence = prediction[topk_labels]
        result = [self.labels[lbl_idx] for lbl_idx in topk_labels]
        if np.max(topk_confidence) < self.threshold:
            return None
        return {
            "labels": dict(enumerate(result)),
            "confidence": dict(enumerate(topk_confidence)),
        }

    def model_init(self, path_to_model: str) -> None:
        """
        Load and initialize the ONNX model from the provided path.

        Args:
            path_to_model (str): Path to the ONNX model file.

        Returns:
            None
        """
        session = rt.InferenceSession(path_to_model, providers=[self.provider])
        self.input_name = session.get_inputs()[0].name
        self.output_name = session.get_outputs()[0].name
        self.model = session.run

    def decode_preds(self, data):
        """Re-decode strings read under the cp1251 locale back to UTF-8 on Windows."""
        if platform in {"win32", "win64"}:
            data = [i.encode("cp1251").decode("utf-8") for i in data]
        return data
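

# Usage sketch: a minimal example of driving Predictor end to end. The config
# keys match what Predictor reads above; the file paths, provider, threshold,
# topk, and clip shape below are hypothetical placeholders, not values taken
# from this module.
if __name__ == "__main__":
    config = {
        "path_to_model": "model.onnx",        # hypothetical path to an ONNX model
        "path_to_class_list": "classes.tsv",  # hypothetical path; "<idx>\t<label>" per line
        "provider": "CPUExecutionProvider",   # a standard ONNX Runtime provider
        "threshold": 0.5,
        "topk": 5,
    }
    predictor = Predictor(config)

    # Dummy clip: 16 RGB frames of 224x224, matching the "t h w c" layout
    # that predict() rearranges into the model's input format.
    frames = [
        np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(16)
    ]
    result = predictor.predict(frames)
    if result is None:
        print("No prediction above the confidence threshold")
    else:
        for i, label in result["labels"].items():
            print(f"{label}: {result['confidence'][i]:.3f}")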