OwenElliott commited on
Commit
aae5426
1 Parent(s): a544f2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -3,36 +3,40 @@ from urllib.request import urlopen
3
  from PIL import Image
4
  import timm
5
  import torch
 
6
 
7
- # Load the model
8
  model = timm.create_model("hf_hub:Marqo/nsfw-image-detection-384", pretrained=True)
9
  model = model.eval()
10
 
11
- # Prepare the data transformation
12
  data_config = timm.data.resolve_model_data_config(model)
13
  transforms = timm.data.create_transform(**data_config, is_training=False)
14
 
15
- # Prediction function
16
  def predict(image):
 
17
  with torch.no_grad():
18
- # Transform the image
19
  input_tensor = transforms(image).unsqueeze(0)
20
- # Run the model
21
  output = model(input_tensor).softmax(dim=-1).cpu()
22
- # Get class names
23
  class_names = model.pretrained_cfg["label_names"]
24
- # Create the result dictionary
25
  result = {class_names[i]: float(output[0, i]) for i in range(len(class_names))}
26
- return result
 
 
27
 
28
- # Gradio interface
29
- interface = gr.Interface(
30
  fn=predict,
31
- inputs=gr.Image(type="pil"),
32
- outputs=gr.Label(num_top_classes=3),
 
 
 
33
  title="NSFW Image Detection",
34
- description="Upload an image to detect if it is NSFW or Safe for Work."
 
 
 
 
35
  )
36
 
37
  if __name__ == "__main__":
38
- interface.launch()
 
3
  from PIL import Image
4
  import timm
5
  import torch
6
+ import time
7
 
 
8
  model = timm.create_model("hf_hub:Marqo/nsfw-image-detection-384", pretrained=True)
9
  model = model.eval()
10
 
 
11
  data_config = timm.data.resolve_model_data_config(model)
12
  transforms = timm.data.create_transform(**data_config, is_training=False)
13
 
 
14
  def predict(image):
15
+ start_time = time.time()
16
  with torch.no_grad():
 
17
  input_tensor = transforms(image).unsqueeze(0)
 
18
  output = model(input_tensor).softmax(dim=-1).cpu()
 
19
  class_names = model.pretrained_cfg["label_names"]
 
20
  result = {class_names[i]: float(output[0, i]) for i in range(len(class_names))}
21
+ end_time = time.time()
22
+ inference_time = end_time - start_time
23
+ return result, f"Inference time: {inference_time:.2f} seconds"
24
 
25
+
26
+ demo = gr.Interface(
27
  fn=predict,
28
+ inputs=gr.Image(type="pil", height=512),
29
+ outputs=[
30
+ gr.Label(num_top_classes=2),
31
+ gr.Textbox(label="Inference Time")
32
+ ],
33
  title="NSFW Image Detection",
34
+ description=(
35
+ "Upload an image to detect if it is **NSFW (Not Safe For Work)** or **Safe For Work (SFW)**.\n\n"
36
+ "This app uses the [Marqo/nsfw-image-detection-384](https://huggingface.co/Marqo/nsfw-image-detection-384) "
37
+ "image classification model from Hugging Face's `timm` library."
38
+ )
39
  )
40
 
41
  if __name__ == "__main__":
42
+ demo.launch()