kuko6 commited on
Commit
663913f
1 Parent(s): 8cc92e2

Upload 6 files

Browse files
Files changed (4) hide show
  1. app.py +29 -0
  2. labels.json +154 -0
  3. model.py +46 -0
  4. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms.functional as TF
4
+ from model import NeuralNetwork
5
+ import json
6
+
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ def pokemon_classifier(inp):
10
+ model = NeuralNetwork()
11
+ model.load_state_dict(torch.load('model_best.pt', map_location=torch.device(device)))
12
+ model.eval()
13
+
14
+ with open('labels.json') as f:
15
+ labels = json.load(f)
16
+
17
+ x = TF.to_tensor(inp)
18
+ x = TF.resize(x, 64, antialias=True)
19
+ x = x.to(device)
20
+ x = x.unsqueeze(0)
21
+
22
+ with torch.no_grad():
23
+ y_pred = model(x)
24
+ pokemon = torch.argmax(y_pred, dim=1).item()
25
+
26
+ return labels[str(pokemon)]
27
+
28
+ demo = gr.Interface(fn=pokemon_classifier, inputs=gr.Image(type="pil"), outputs="text")
29
+ demo.launch()
labels.json ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "Abra",
3
+ "1": "Aerodactyl",
4
+ "2": "Alakazam",
5
+ "3": "Arbok",
6
+ "4": "Arcanine",
7
+ "5": "Articuno",
8
+ "6": "Beedrill",
9
+ "7": "Bellsprout",
10
+ "8": "Blastoise",
11
+ "9": "Bulbasaur",
12
+ "10": "Butterfree",
13
+ "11": "Caterpie",
14
+ "12": "Chansey",
15
+ "13": "Charizard",
16
+ "14": "Charmander",
17
+ "15": "Charmeleon",
18
+ "16": "Clefable",
19
+ "17": "Clefairy",
20
+ "18": "Cloyster",
21
+ "19": "Cubone",
22
+ "20": "Dewgong",
23
+ "21": "Diglett",
24
+ "22": "Ditto",
25
+ "23": "Dodrio",
26
+ "24": "Doduo",
27
+ "25": "Dragonair",
28
+ "26": "Dragonite",
29
+ "27": "Dratini",
30
+ "28": "Drowzee",
31
+ "29": "Dugtrio",
32
+ "30": "Eevee",
33
+ "31": "Ekans",
34
+ "32": "Electabuzz",
35
+ "33": "Electrode",
36
+ "34": "Exeggcute",
37
+ "35": "Exeggutor",
38
+ "36": "Farfetchd",
39
+ "37": "Fearow",
40
+ "38": "Flareon",
41
+ "39": "Gastly",
42
+ "40": "Gengar",
43
+ "41": "Geodude",
44
+ "42": "Gloom",
45
+ "43": "Golbat",
46
+ "44": "Goldeen",
47
+ "45": "Golduck",
48
+ "46": "Golem",
49
+ "47": "Graveler",
50
+ "48": "Grimer",
51
+ "49": "Growlithe",
52
+ "50": "Gyarados",
53
+ "51": "Haunter",
54
+ "52": "Hitmonchan",
55
+ "53": "Hitmonlee",
56
+ "54": "Horsea",
57
+ "55": "Hypno",
58
+ "56": "Ivysaur",
59
+ "57": "Jigglypuff",
60
+ "58": "Jolteon",
61
+ "59": "Jynx",
62
+ "60": "Kabuto",
63
+ "61": "Kabutops",
64
+ "62": "Kadabra",
65
+ "63": "Kakuna",
66
+ "64": "Kangaskhan",
67
+ "65": "Kingler",
68
+ "66": "Koffing",
69
+ "67": "Krabby",
70
+ "68": "Lapras",
71
+ "69": "Lickitung",
72
+ "70": "Machamp",
73
+ "71": "Machoke",
74
+ "72": "Machop",
75
+ "73": "Magikarp",
76
+ "74": "Magmar",
77
+ "75": "Magnemite",
78
+ "76": "Magneton",
79
+ "77": "Mankey",
80
+ "78": "Marowak",
81
+ "79": "Meowth",
82
+ "80": "Metapod",
83
+ "81": "Mew",
84
+ "82": "Mewtwo",
85
+ "83": "Moltres",
86
+ "84": "Mr.Mime",
87
+ "85": "Muk",
88
+ "86": "Nidoking",
89
+ "87": "Nidoqueen",
90
+ "88": "Nidoran-f",
91
+ "89": "Nidoran-m",
92
+ "90": "Nidorina",
93
+ "91": "Nidorino",
94
+ "92": "Ninetales",
95
+ "93": "Oddish",
96
+ "94": "Omanyte",
97
+ "95": "Omastar",
98
+ "96": "Onix",
99
+ "97": "Paras",
100
+ "98": "Parasect",
101
+ "99": "Persian",
102
+ "100": "Pidgeot",
103
+ "101": "Pidgeotto",
104
+ "102": "Pidgey",
105
+ "103": "Pikachu",
106
+ "104": "Pinsir",
107
+ "105": "Poliwag",
108
+ "106": "Poliwhirl",
109
+ "107": "Poliwrath",
110
+ "108": "Ponyta",
111
+ "109": "Porygon",
112
+ "110": "Primeape",
113
+ "111": "Psyduck",
114
+ "112": "Raichu",
115
+ "113": "Rapidash",
116
+ "114": "Raticate",
117
+ "115": "Rattata",
118
+ "116": "Rhydon",
119
+ "117": "Rhyhorn",
120
+ "118": "Sandshrew",
121
+ "119": "Sandslash",
122
+ "120": "Scyther",
123
+ "121": "Seadra",
124
+ "122": "Seaking",
125
+ "123": "Seel",
126
+ "124": "Shellder",
127
+ "125": "Slowbro",
128
+ "126": "Slowpoke",
129
+ "127": "Snorlax",
130
+ "128": "Spearow",
131
+ "129": "Squirtle",
132
+ "130": "Starmie",
133
+ "131": "Staryu",
134
+ "132": "Tangela",
135
+ "133": "Tauros",
136
+ "134": "Tentacool",
137
+ "135": "Tentacruel",
138
+ "136": "Vaporeon",
139
+ "137": "Venomoth",
140
+ "138": "Venonat",
141
+ "139": "Venusaur",
142
+ "140": "Victreebel",
143
+ "141": "Vileplume",
144
+ "142": "Voltorb",
145
+ "143": "Vulpix",
146
+ "144": "Wartortle",
147
+ "145": "Weedle",
148
+ "146": "Weepinbell",
149
+ "147": "Weezing",
150
+ "148": "Wigglytuff",
151
+ "149": "Zapdos",
152
+ "150": "Zubat"
153
+ }
154
+
model.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ class NeuralNetwork(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.conv = nn.Sequential(
8
+ nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
9
+ nn.BatchNorm2d(64),
10
+ nn.ReLU(),
11
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
12
+ nn.Dropout2d(p=0.3),
13
+
14
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
15
+ nn.BatchNorm2d(128),
16
+ nn.ReLU(),
17
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
18
+ nn.Dropout2d(p=0.3),
19
+
20
+ nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
21
+ nn.BatchNorm2d(256),
22
+ nn.ReLU(),
23
+ nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
24
+ nn.Dropout2d(p=0.4),
25
+
26
+ nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
27
+ nn.BatchNorm2d(256),
28
+ nn.ReLU(),
29
+ nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
30
+ )
31
+ self.flatten = nn.Flatten()
32
+ self.fc = nn.Sequential(
33
+ nn.Linear(256*4*4, 256),
34
+ nn.BatchNorm1d(256),
35
+ nn.ReLU(),
36
+ nn.Dropout(p=0.5),
37
+ nn.Linear(256, 151),
38
+ )
39
+
40
+ def forward(self, x):
41
+ out = self.conv(x)
42
+ out = self.flatten(out)
43
+ out = self.fc(out)
44
+
45
+ return out
46
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ torchvision