RingoDingo
commited on
Commit
·
1f6d2ce
1
Parent(s):
d02bda9
Upload 2 files
Browse files- maid_classifier_model.py +94 -0
- maidel_E22.pth +3 -0
maid_classifier_model.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from PIL import Image
|
4 |
+
from torchvision.transforms import ToTensor
|
5 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
6 |
+
|
7 |
+
class MCNN(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(MCNN, self).__init__()
|
10 |
+
|
11 |
+
# Convolution layers
|
12 |
+
self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
|
13 |
+
self.bn1 = nn.BatchNorm2d(64)
|
14 |
+
|
15 |
+
self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
|
16 |
+
self.bn2 = nn.BatchNorm2d(128)
|
17 |
+
|
18 |
+
self.conv3 = nn.Conv2d(128, 256, 3, 1, 1)
|
19 |
+
self.bn3 = nn.BatchNorm2d(256)
|
20 |
+
|
21 |
+
self.conv4 = nn.Conv2d(256, 512, 3, 1, 1) # Added another convolutional layer
|
22 |
+
self.bn4 = nn.BatchNorm2d(512)
|
23 |
+
|
24 |
+
# Pooling layer
|
25 |
+
self.pool = nn.MaxPool2d(2, 2)
|
26 |
+
|
27 |
+
# Fully connected layers
|
28 |
+
self.fc1 = nn.Linear(100352, 2048)
|
29 |
+
self.fc2 = nn.Linear(2048, 1024)
|
30 |
+
self.fc3 = nn.Linear(1024, 512)
|
31 |
+
self.fc4 = nn.Linear(512, 256)
|
32 |
+
self.fc5 = nn.Linear(256, 2) # Two classes
|
33 |
+
|
34 |
+
# Activation and dropout
|
35 |
+
self.relu = nn.ReLU()
|
36 |
+
self.dropout = nn.Dropout(0.2)
|
37 |
+
|
38 |
+
def forward(self, pixel_values, labels=None):
|
39 |
+
x = self.pool(self.relu(self.bn1(self.conv1(pixel_values))))
|
40 |
+
x = self.pool(self.relu(self.bn2(self.conv2(x))))
|
41 |
+
x = self.pool(self.relu(self.bn3(self.conv3(x))))
|
42 |
+
x = self.pool(self.relu(self.bn4(self.conv4(x)))) # Pass through the added conv layer
|
43 |
+
|
44 |
+
x = x.view(x.size(0), -1) # flatten
|
45 |
+
x = self.dropout(self.relu(self.fc1(x)))
|
46 |
+
x = self.dropout(self.relu(self.fc2(x)))
|
47 |
+
x = self.dropout(self.relu(self.fc3(x)))
|
48 |
+
x = self.dropout(self.relu(self.fc4(x)))
|
49 |
+
logits = self.fc5(x)
|
50 |
+
|
51 |
+
loss = None
|
52 |
+
if labels is not None:
|
53 |
+
loss_fct = nn.CrossEntropyLoss()
|
54 |
+
loss = loss_fct(logits.view(-1, 2), labels.view(-1))
|
55 |
+
|
56 |
+
if loss is not None:
|
57 |
+
return logits, loss.item()
|
58 |
+
else:
|
59 |
+
return logits, None
|
60 |
+
|
61 |
+
def preprocess_image(img, desired_size=224):
|
62 |
+
im = img
|
63 |
+
|
64 |
+
# Resize and pad the image
|
65 |
+
old_size = im.size
|
66 |
+
ratio = float(desired_size) / max(old_size)
|
67 |
+
new_size = tuple([int(x*ratio) for x in old_size])
|
68 |
+
im = im.resize(new_size)
|
69 |
+
|
70 |
+
# Create a new image and paste the resized on it
|
71 |
+
new_im = Image.new("RGB", (desired_size, desired_size), "white")
|
72 |
+
new_im.paste(im, ((desired_size-new_size[0])//2,
|
73 |
+
(desired_size-new_size[1])//2))
|
74 |
+
return new_im
|
75 |
+
|
76 |
+
def predict_image(image, model):
|
77 |
+
# Ensure model is in eval mode
|
78 |
+
model.eval()
|
79 |
+
|
80 |
+
# Convert image to tensor
|
81 |
+
transform = ToTensor()
|
82 |
+
input_tensor = transform(image)
|
83 |
+
input_batch = input_tensor.unsqueeze(0)
|
84 |
+
|
85 |
+
# Move tensors to the right device
|
86 |
+
input_batch = input_batch.to(device)
|
87 |
+
|
88 |
+
# Forward pass of the image through the model
|
89 |
+
output = model(input_batch)
|
90 |
+
|
91 |
+
# Convert model output to probabilities using softmax
|
92 |
+
probabilities = torch.nn.functional.softmax(output[0], dim=1)
|
93 |
+
|
94 |
+
return probabilities.cpu().detach().numpy()
|
maidel_E22.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c785bfd1ecccce3fb08e55eb91b11a05d122577a100e20f2d4acb0564cd234b4
|
3 |
+
size 839342059
|