import gradio as gr import numpy as np import json import os from PIL import Image import onnxruntime as rt class ONNXModel: def __init__(self, dir_path) -> None: """Method to get name of model file. Assumes model is in the parent directory for script.""" model_dir = os.path.dirname(dir_path) with open(os.path.join(model_dir, "signature.json"), "r") as f: self.signature = json.load(f) self.model_file = os.path.join(model_dir, self.signature.get("filename")) if not os.path.isfile(self.model_file): raise FileNotFoundError(f"Model file does not exist") # get the signature for model inputs and outputs self.signature_inputs = self.signature.get("inputs") self.signature_outputs = self.signature.get("outputs") self.session = None if "Image" not in self.signature_inputs: raise ValueError("ONNX model doesn't have 'Image' input! Check signature.json, and please report issue to Lobe.") # Look for the version in signature file. # If it's not found or the doesn't match expected, print a message version = self.signature.get("export_model_version") if version is None or version != EXPORT_MODEL_VERSION: print( f"There has been a change to the model format. Please use a model with a signature 'export_model_version' that matches {EXPORT_MODEL_VERSION}." ) def load(self) -> None: """Load the model from path to model file""" # Load ONNX model as session. self.session = rt.InferenceSession(path_or_bytes=self.model_file) def predict(self, image: Image.Image) -> dict: """ Predict with the ONNX session! """ # process image to be compatible with the model img = self.process_image(image, self.signature_inputs.get("Image").get("shape")) # run the model! fetches = [(key, value.get("name")) for key, value in self.signature_outputs.items()] # make the image a batch of 1 feed = {self.signature_inputs.get("Image").get("name"): [img]} outputs = self.session.run(output_names=[name for (_, name) in fetches], input_feed=feed) return self.process_output(fetches, outputs) def process_image(self, image: Image.Image, input_shape: list) -> np.ndarray: """ Given a PIL Image, center square crop and resize to fit the expected model input, and convert from [0,255] to [0,1] values. """ width, height = image.size # ensure image type is compatible with model and convert if not if image.mode != "RGB": image = image.convert("RGB") # center crop image (you can substitute any other method to make a square image, such as just resizing or padding edges with 0) if width != height: square_size = min(width, height) left = (width - square_size) / 2 top = (height - square_size) / 2 right = (width + square_size) / 2 bottom = (height + square_size) / 2 # Crop the center of the image image = image.crop((left, top, right, bottom)) # now the image is square, resize it to be the right shape for the model input input_width, input_height = input_shape[1:3] if image.width != input_width or image.height != input_height: image = image.resize((input_width, input_height)) # make 0-1 float instead of 0-255 int (that PIL Image loads by default) image = np.asarray(image) / 255.0 # format input as model expects return image.astype(np.float32) def process_output(self, fetches: dict, outputs: dict) -> dict: # un-batch since we ran an image with batch size of 1, # convert to normal python types with tolist(), and convert any byte strings to normal strings with .decode() out_keys = ["label", "confidence"] results = {} for i, (key, _) in enumerate(fetches): val = outputs[i].tolist()[0] if isinstance(val, bytes): val = val.decode() results[key] = val confs = results["Confidences"] labels = self.signature.get("classes").get("Label") output = [dict(zip(out_keys, group)) for group in zip(labels, confs)] sorted_output = {"predictions": sorted(output, key=lambda k: k["confidence"], reverse=True)} return sorted_output EXPORT_MODEL_VERSION=1 model = ONNXModel(dir_path="model.onnx") model.load() def predict(image): image = Image.fromarray(np.uint8(image), 'RGB') prediction = model.predict(image) for output in prediction["predictions"]: output["confidence"] = round(output["confidence"], 4) return prediction inputs = gr.Image(type="pil") outputs = gr.JSON() description = "This is a web interface for the Naked Detector model. Upload an image and get predictions for the presence of nudity. \n This model and website are created by KUO SUKO, C110156115 NKUST." interface = gr.Interface(title="Naked Detector", fn=predict, inputs=inputs, outputs=outputs, description=description) interface.launch()