ttheland commited on
Commit
6928524
·
1 Parent(s): b7baf49

deploying to spaces

Browse files
Files changed (2) hide show
  1. app.py +127 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from torchvision import models, transforms
5
+ import time
6
+ import os
7
+ import copy
8
+ import pickle
9
+ from PIL import Image
10
+ import datetime
11
+ import gdown
12
+ import urllib.request
13
+ import gradio as gr
14
+ import markdown
15
+
16
+ url = 'https://drive.google.com/file/d/1qKiyp4r8SqUtz2ZWk3E6oZhyhl6t8lyG/view?usp=sharing'
17
+ path_class_names = "./class_names_restnet_leeds_butterfly.pkl"
18
+ gdown.download(url, path_class_names, quiet=False)
19
+
20
+ url = 'https://drive.google.com/file/d/1Ep2YWU4M-yVkF7AFP3aD1sVhuriIDzFe/view?usp=sharing'
21
+ path_model = "./model_state_restnet_leeds_butterfly.pth"
22
+ gdown.download(url, path_model, quiet=False)
23
+
24
+ url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Red_postman_butterfly_%28Heliconius_erato%29.jpg/1599px-Red_postman_butterfly_%28Heliconius_erato%29.jpg"
25
+ path_input = "./h_erato.jpg"
26
+ urllib.request.urlretrieve(url, filename=path_input)
27
+
28
+ url = "https://www.ukbutterflies.co.uk/photo_album/source/664a285ca7b4379147d598ea5127228f.jpg"
29
+ path_input = "./d_plexippus.jpg"
30
+ urllib.request.urlretrieve(url, filename=path_input)
31
+
32
+ # normalisation
33
+ data_transforms_test = transforms.Compose([
34
+ transforms.Resize(256),
35
+ transforms.CenterCrop(224),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
38
+ ])
39
+
40
+ class_names = pickle.load(open(path_class_names, "rb"))
41
+
42
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
43
+
44
+ model_ft = models.resnet18(pretrained=True)
45
+ num_ftrs = model_ft.fc.in_features
46
+ model_ft.fc = nn.Linear(num_ftrs, len(class_names))
47
+ model_ft = model_ft.to(device)
48
+ model_ft.load_state_dict(copy.deepcopy(torch.load(path_model,device)))
49
+
50
+ # Proper labeling
51
+ id_to_name = {
52
+ '001_Danaus Plexippus': 'Danaus plexippus - Monarch',
53
+ '002_Heliconius Charitonius': 'Heliconius charitonius - Zebra Longwing',
54
+ '003_Heliconius Erato': 'Heliconius erato - Red Postman',
55
+ '004_Junonia Coenia': 'Junonia coenia - Common Buckeye',
56
+ '005_Lycaena Phlaeas': 'Lycaena phlaeas - Small Copper',
57
+ '006_Nymphalis Antiopa': 'Nymphalis antiopa - Mourning Cloak',
58
+ '007_Papilio Cresphontes': 'Papilio cresphontes - Giant Swallowtail',
59
+ '008_Pieris Rapae': 'Pieris rapae - Cabbage White',
60
+ '009_Vanessa Atalanta': 'Vanessa atalanta - Red Admiral',
61
+ '010_Vanessa Cardui': 'Vanessa cardui - Painted Lady',
62
+ }
63
+
64
+ def do_inference(img):
65
+ img_t = data_transforms_test(img)
66
+ batch_t = torch.unsqueeze(img_t, 0)
67
+ model_ft.eval()
68
+ # We don't need gradients for test, so wrap in
69
+ # no_grad to save memory
70
+ with torch.no_grad():
71
+ batch_t = batch_t.to(device)
72
+ # forward propagation
73
+ output = model_ft( batch_t)
74
+ # get prediction
75
+ probs = torch.nn.functional.softmax(output, dim=1)
76
+ output = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
77
+ probs = probs.cpu().numpy()[0]
78
+ probs = probs[output]
79
+ labels = np.array(class_names)[output]
80
+ return {id_to_name[labels[i]]: round(float(probs[i]),2) for i in range(len(labels))}
81
+
82
+ im = gr.inputs.Image(shape=(512, 512), image_mode='RGB',
83
+ invert_colors=False, source="upload",
84
+ type="pil")
85
+ title = "Butterfly Classification Demo"
86
+ description = "A pretrained ResNet18 CNN trained on the Leeds Butterfly Dataset. Libraries: PyTorch, Gradio."
87
+ examples = [['./h_erato.jpg'],['d_plexippus.jpg']]
88
+ article_text = markdown.markdown('''
89
+
90
+ <h1 style="color:white">PyTorch image classification - A pretrained ResNet18 CNN trained on the <a href="http://www.josiahwang.com/dataset/leedsbutterfly/">Leeds Butterfly Dataset</a></h1>
91
+ <br>
92
+ <p>The Leeds Butterfly Dataset consists of 832 images in 10 classes:</p>
93
+ <ul>
94
+ <li>Danaus plexippus - Monarch</li>
95
+ <li>Heliconius charitonius - Zebra Longwing</li>
96
+ <li>Heliconius erato - Red Postman</li>
97
+ <li>Lycaena phlaeas - Small Copper</li>
98
+ <li>Junonia coenia - Common Buckeye</li>
99
+ <li>Nymphalis antiopa - Mourning Cloak</li>
100
+ <li>Papilio cresphontes - Giant Swallowtail</li>
101
+ <li>Pieris rapae - Cabbage White</li>
102
+ <li>Vanessa atalanta - Red Admiral</li>
103
+ <li>Vanessa cardui - Painted Lady</li>
104
+ </ul>
105
+ <br>
106
+ <p>Part of a dissertation project. Author: <a href="https://github.com/ttheland">ttheland</a></p>
107
+ ''')
108
+
109
+ # enable queue
110
+ enable_queue = True
111
+
112
+ iface = gr.Interface(
113
+ do_inference,
114
+ im,
115
+ gr.outputs.Label(num_top_classes=2),
116
+ live=False,
117
+ interpretation=None,
118
+ title=title,
119
+ description=description,
120
+ article= article_text,
121
+ examples=examples,
122
+ enable_queue=enable_queue
123
+ )
124
+
125
+ iface.test.launch()
126
+
127
+ iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torchvision
2
+ gdown
3
+ markdown