suko commited on
Commit
b6af6fc
·
verified ·
1 Parent(s): 1668fa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -67
app.py CHANGED
@@ -4,112 +4,99 @@ import json
4
  import os
5
  from PIL import Image
6
  import onnxruntime as rt
 
7
  class ONNXModel:
8
  def __init__(self, dir_path) -> None:
9
- """Method to get name of model file. Assumes model is in the parent directory for script."""
10
  model_dir = os.path.dirname(dir_path)
11
  with open(os.path.join(model_dir, "signature.json"), "r") as f:
12
  self.signature = json.load(f)
 
13
  self.model_file = os.path.join(model_dir, self.signature.get("filename"))
14
  if not os.path.isfile(self.model_file):
15
- raise FileNotFoundError(f"Model file does not exist")
16
- # get the signature for model inputs and outputs
17
  self.signature_inputs = self.signature.get("inputs")
18
  self.signature_outputs = self.signature.get("outputs")
19
- self.session = None
20
  if "Image" not in self.signature_inputs:
21
- raise ValueError("ONNX model doesn't have 'Image' input! Check signature.json, and please report issue to Lobe.")
22
- # Look for the version in signature file.
23
- # If it's not found or the doesn't match expected, print a message
24
  version = self.signature.get("export_model_version")
25
  if version is None or version != EXPORT_MODEL_VERSION:
26
- print(
27
- f"There has been a change to the model format. Please use a model with a signature 'export_model_version' that matches {EXPORT_MODEL_VERSION}."
28
- )
29
 
30
  def load(self) -> None:
31
- """Load the model from path to model file"""
32
- # Load ONNX model as session.
33
- self.session = rt.InferenceSession(path_or_bytes=self.model_file)
34
 
35
  def predict(self, image: Image.Image) -> dict:
36
- """
37
- Predict with the ONNX session!
38
- """
39
- # process image to be compatible with the model
40
- img = self.process_image(image, self.signature_inputs.get("Image").get("shape"))
41
- # run the model!
42
- fetches = [(key, value.get("name")) for key, value in self.signature_outputs.items()]
43
- # make the image a batch of 1
44
- feed = {self.signature_inputs.get("Image").get("name"): [img]}
45
- outputs = self.session.run(output_names=[name for (_, name) in fetches], input_feed=feed)
46
- return self.process_output(fetches, outputs)
47
 
48
  def process_image(self, image: Image.Image, input_shape: list) -> np.ndarray:
49
- """
50
- Given a PIL Image, center square crop and resize to fit the expected model input, and convert from [0,255] to [0,1] values.
51
- """
52
  width, height = image.size
53
- # ensure image type is compatible with model and convert if not
54
  if image.mode != "RGB":
55
  image = image.convert("RGB")
56
- # center crop image (you can substitute any other method to make a square image, such as just resizing or padding edges with 0)
57
- if width != height:
58
- square_size = min(width, height)
59
- left = (width - square_size) / 2
60
- top = (height - square_size) / 2
61
- right = (width + square_size) / 2
62
- bottom = (height + square_size) / 2
63
- # Crop the center of the image
64
- image = image.crop((left, top, right, bottom))
65
- # now the image is square, resize it to be the right shape for the model input
66
- input_width, input_height = input_shape[1:3]
67
- if image.width != input_width or image.height != input_height:
68
- image = image.resize((input_width, input_height))
69
 
70
- # make 0-1 float instead of 0-255 int (that PIL Image loads by default)
 
 
 
 
 
 
 
 
71
  image = np.asarray(image) / 255.0
72
- # format input as model expects
73
  return image.astype(np.float32)
74
 
75
- def process_output(self, fetches: dict, outputs: dict) -> dict:
76
- # un-batch since we ran an image with batch size of 1,
77
- # convert to normal python types with tolist(), and convert any byte strings to normal strings with .decode()
78
  out_keys = ["label", "confidence"]
79
- results = {}
80
- for i, (key, _) in enumerate(fetches):
81
- val = outputs[i].tolist()[0]
82
- if isinstance(val, bytes):
83
- val = val.decode()
84
- results[key] = val
85
  confs = results["Confidences"]
86
- labels = self.signature.get("classes").get("Label")
87
  output = [dict(zip(out_keys, group)) for group in zip(labels, confs)]
88
- sorted_output = {"predictions": sorted(output, key=lambda k: k["confidence"], reverse=True)}
89
- return sorted_output
90
- EXPORT_MODEL_VERSION=1
91
  model = ONNXModel(dir_path="model.onnx")
92
  model.load()
93
 
94
-
95
  def predict(image):
96
- image = Image.fromarray(np.uint8(image), 'RGB')
 
97
  prediction = model.predict(image)
98
  for output in prediction["predictions"]:
99
  output["confidence"] = round(output["confidence"], 4)
100
  return prediction
101
 
102
- inputs = gr.Image(type="pil")
103
  outputs = gr.JSON()
104
 
105
- description = "This is a web interface for the Naked Detector model. Upload an image and get predictions for the presence of nudity. \n This model and website are created by KUO SUKO, C110156115 NKUST."
106
-
107
- interface = gr.Interface(title="Naked Detector", fn=predict, inputs=inputs, outputs=outputs, description=description)
108
- interface.launch()
109
-
110
-
111
-
112
-
113
 
 
 
 
 
 
 
 
114
 
 
115
 
 
 
4
  import os
5
  from PIL import Image
6
  import onnxruntime as rt
7
+
8
  class ONNXModel:
9
  def __init__(self, dir_path) -> None:
10
+ """Load model metadata and initialize ONNX session."""
11
  model_dir = os.path.dirname(dir_path)
12
  with open(os.path.join(model_dir, "signature.json"), "r") as f:
13
  self.signature = json.load(f)
14
+
15
  self.model_file = os.path.join(model_dir, self.signature.get("filename"))
16
  if not os.path.isfile(self.model_file):
17
+ raise FileNotFoundError("Model file does not exist.")
18
+
19
  self.signature_inputs = self.signature.get("inputs")
20
  self.signature_outputs = self.signature.get("outputs")
21
+
22
  if "Image" not in self.signature_inputs:
23
+ raise ValueError("ONNX model must have an 'Image' input. Check signature.json.")
24
+
25
+ # Check export version
26
  version = self.signature.get("export_model_version")
27
  if version is None or version != EXPORT_MODEL_VERSION:
28
+ print(f"Warning: Expected model version {EXPORT_MODEL_VERSION}, but found {version}.")
29
+
30
+ self.session = None
31
 
32
  def load(self) -> None:
33
+ """Load the ONNX model with execution providers."""
34
+ self.session = rt.InferenceSession(self.model_file, providers=["CPUExecutionProvider"])
 
35
 
36
  def predict(self, image: Image.Image) -> dict:
37
+ """Process image and run ONNX model inference."""
38
+ img = self.process_image(image, self.signature_inputs["Image"]["shape"])
39
+ feed = {self.signature_inputs["Image"]["name"]: [img]}
40
+ output_names = [self.signature_outputs[key]["name"] for key in self.signature_outputs]
41
+ outputs = self.session.run(output_names=output_names, input_feed=feed)
42
+ return self.process_output(outputs)
 
 
 
 
 
43
 
44
  def process_image(self, image: Image.Image, input_shape: list) -> np.ndarray:
45
+ """Resize and normalize the image."""
 
 
46
  width, height = image.size
 
47
  if image.mode != "RGB":
48
  image = image.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ square_size = min(width, height)
51
+ left = (width - square_size) / 2
52
+ top = (height - square_size) / 2
53
+ right = (width + square_size) / 2
54
+ bottom = (height + square_size) / 2
55
+ image = image.crop((left, top, right, bottom))
56
+
57
+ input_width, input_height = input_shape[1:3]
58
+ image = image.resize((input_width, input_height))
59
  image = np.asarray(image) / 255.0
 
60
  return image.astype(np.float32)
61
 
62
+ def process_output(self, outputs: list) -> dict:
63
+ """Format the model output."""
 
64
  out_keys = ["label", "confidence"]
65
+ results = {key: outputs[i].tolist()[0] for i, key in enumerate(self.signature_outputs)}
 
 
 
 
 
66
  confs = results["Confidences"]
67
+ labels = self.signature["classes"]["Label"]
68
  output = [dict(zip(out_keys, group)) for group in zip(labels, confs)]
69
+ return {"predictions": sorted(output, key=lambda x: x["confidence"], reverse=True)}
70
+
71
+ EXPORT_MODEL_VERSION = 1
72
  model = ONNXModel(dir_path="model.onnx")
73
  model.load()
74
 
 
75
  def predict(image):
76
+ """Run inference on the given image."""
77
+ image = Image.fromarray(np.uint8(image), "RGB")
78
  prediction = model.predict(image)
79
  for output in prediction["predictions"]:
80
  output["confidence"] = round(output["confidence"], 4)
81
  return prediction
82
 
83
+ inputs = gr.Image(image_mode="RGB")
84
  outputs = gr.JSON()
85
 
86
+ description = (
87
+ "This is a web interface for the Naked Detector model. "
88
+ "Upload an image and get predictions for the presence of nudity.\n"
89
+ "Model and website created by KUO SUKO, C110156115 NKUST."
90
+ )
 
 
 
91
 
92
+ interface = gr.Interface(
93
+ fn=predict,
94
+ inputs=inputs,
95
+ outputs=outputs,
96
+ title="Naked Detector",
97
+ description=description
98
+ )
99
 
100
+ interface.launch()
101
 
102
+ # this is changed by ChatGPT, if it run like a shit. don't blame me ><