Mehyaar commited on
Commit
debf278
1 Parent(s): 381fc54

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +89 -0
  2. vehicle_classifier.pth +3 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ transform_test = transforms.Compose([
9
+ transforms.Resize(256),
10
+ transforms.CenterCrop(224),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13
+ ])
14
+
15
+ class_names = [
16
+ 'Auto Rickshaws', 'Bikes', 'Cars', 'Motorcycles',
17
+ 'Planes', 'Ships', 'Trains'
18
+ ]
19
+ class VehicleClassifier(nn.Module):
20
+ def __init__(self):
21
+ super(VehicleClassifier, self).__init__()
22
+
23
+ # Convolutional Layers
24
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
25
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
26
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
27
+ self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
28
+
29
+ # Pooling Layer
30
+ self.pool = nn.MaxPool2d(2, 2)
31
+
32
+ # FC Layers
33
+ self.fc1 = nn.Linear(256 * 14 * 14, 512)
34
+ self.fc2 = nn.Linear(512, 256)
35
+ self.fc3 = nn.Linear(256, 7) # 7 classes for the 7 vehicle categories
36
+
37
+ self.dropout = nn.Dropout(0.5)
38
+
39
+ def forward(self, x):
40
+ # Apply Convolutional Layers with ReLU activation and Pooling
41
+ x = self.pool(F.relu(self.conv1(x)))
42
+ x = self.pool(F.relu(self.conv2(x)))
43
+ x = self.pool(F.relu(self.conv3(x)))
44
+ x = self.pool(F.relu(self.conv4(x)))
45
+
46
+ # Flatten the tensor before feeding into the FCL
47
+ x = x.view(-1, 256 * 14 * 14)
48
+ x = F.relu(self.fc1(x))
49
+ x = self.dropout(x)
50
+ x = F.relu(self.fc2(x))
51
+ x = self.dropout(x)
52
+ x = self.fc3(x)
53
+ return x
54
+ model = VehicleClassifier().to('cpu')
55
+ model.load_state_dict(torch.load('vehicle_classifier.pth', map_location=torch.device('cpu')))
56
+ model.eval()
57
+
58
+ def predict(image):
59
+ try:
60
+ image = Image.open(image).convert('RGB')
61
+ image = transform_test(image).unsqueeze(0) # Add batch dimension
62
+
63
+ print(f"Image shape after transformation: {image.shape}")
64
+
65
+ with torch.no_grad():
66
+ outputs = model(image)
67
+ print(f"Model output: {outputs}")
68
+ _, predicted = torch.max(outputs, 1)
69
+
70
+ prediction = class_names[predicted.item()]
71
+ print(f"Predicted class: {prediction}")
72
+
73
+ return prediction
74
+ except Exception as e:
75
+ print(f"Error during prediction: {e}")
76
+ traceback.print_exc()
77
+ return "An error occurred during prediction."
78
+
79
+
80
+ interface = gr.Interface(
81
+ fn=predict,
82
+ inputs=gr.Image(type='filepath'),
83
+ outputs=gr.Label(num_top_classes=1),
84
+ title="Vehicle Classification",
85
+ description="Upload an image of a vehicle, and the model will predict its type."
86
+ )
87
+
88
+ # Launch the interface
89
+ interface.launch(share=True)
vehicle_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0ed7dbfc5083a96209bad40a0eb4e7889505456a0ffcb999406fff96e14c68d
3
+ size 104853810