Ramendra commited on
Commit
9ca9208
1 Parent(s): bee69fc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ### 1. Imports
3
+
4
+ import gradio as gr
5
+ import os
6
+ import torch
7
+
8
+ from model import create_model_alexnet
9
+ from timeit import default_timer as timer
10
+ from typing import Tuple, Dict
11
+
12
+ ### 2. Model and transforms preparation ###
13
+
14
+ # Create model_alexnet
15
+ model_alexnet, transforms = create_model_alexnet( num_classes=2)
16
+
17
+ # Load saved weights
18
+ model_alexnet.load_state_dict(torch.load(f="cat_dog_classifier.pth", map_location=torch.device("cpu"))) # load to CPU
19
+
20
+ ### 3. Predict function ###
21
+
22
+ # Create predict function
23
+ def predict(img):
24
+
25
+ # Start the timer
26
+ start_time = timer()
27
+
28
+ model_alexnet.eval()
29
+
30
+ # Reading the image and size transformation
31
+ features = Image.open(img)
32
+ img = auto_transform(features).unsqueeze(0)
33
+
34
+ with torch.inference_mode():
35
+ output = model_alexnet(img).to(device)
36
+ _, predicted = torch.max(output, 1)
37
+
38
+ # Create a prediction label and prediction probability dictionary for each prediction class
39
+ # This is the required format for Gradio's output parameter
40
+ pred_labels_and_probs = 'dog' if predicted.item() ==1 else 'cat'
41
+
42
+ # Calculate the prediction time
43
+ pred_time = round(timer() - start_time, 5)
44
+
45
+ # Return the prediction dictionary and prediction time
46
+ return pred_labels, pred_time
47
+
48
+
49
+ ### 4. Gradio app ###
50
+
51
+ import gradio as gr
52
+
53
+ # Create title, description and article strings
54
+ title = "Classification Demo"
55
+ description = "Cat/Dog classification - Transfer Learning "
56
+
57
+ # Create the Gradio demo
58
+ demo = gr.Interface(fn=predict, # mapping function from input to output
59
+ inputs=gr.Image(type='filepath'), # what are the inputs?
60
+ outputs=[gr.Label(label="Predictions"), # what are the outputs?
61
+ gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
62
+ #examples=example_list,
63
+ title=title,
64
+ description=description,)
65
+
66
+ # Launch the demo!
67
+ demo.launch()