rayyanphysicist commited on
Commit
e13d4d7
·
verified ·
1 Parent(s): 13d2d77

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +129 -0
  2. requirements.txt +5 -0
  3. vanilla_cnn_se.pth +3 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ # Define the VanillaCNN_SE class
10
+ class SEBlock(nn.Module):
11
+ def __init__(self, channels, reduction_ratio=16):
12
+ super(SEBlock, self).__init__()
13
+ self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
14
+ self.fc1 = nn.Linear(channels, channels // reduction_ratio)
15
+ self.fc2 = nn.Linear(channels // reduction_ratio, channels)
16
+ self.sigmoid = nn.Sigmoid()
17
+
18
+ def forward(self, x):
19
+ batch_size, channels, _, _ = x.size()
20
+ y = self.global_avg_pool(x).view(batch_size, channels)
21
+ y = torch.relu(self.fc1(y))
22
+ y = self.sigmoid(self.fc2(y)).view(batch_size, channels, 1, 1)
23
+ return x * y
24
+
25
+ class VanillaCNN_SE(nn.Module):
26
+ def __init__(self, num_classes):
27
+ super(VanillaCNN_SE, self).__init__()
28
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
29
+ self.bn1 = nn.BatchNorm2d(64)
30
+ self.se1 = SEBlock(64)
31
+ self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
32
+ self.bn2 = nn.BatchNorm2d(128)
33
+ self.se2 = SEBlock(128)
34
+ self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
35
+ self.bn3 = nn.BatchNorm2d(256)
36
+ self.se3 = SEBlock(256)
37
+ self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
38
+ self.bn4 = nn.BatchNorm2d(512)
39
+ self.se4 = SEBlock(512)
40
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
41
+ self.fc1 = nn.Linear(512 * 14 * 14, 1024)
42
+ self.fc2 = nn.Linear(1024, num_classes)
43
+
44
+ def forward(self, x):
45
+ x = self.pool(torch.relu(self.bn1(self.conv1(x))))
46
+ x = self.se1(x)
47
+ x = self.pool(torch.relu(self.bn2(self.conv2(x))))
48
+ x = self.se2(x)
49
+ x = self.pool(torch.relu(self.bn3(self.conv3(x))))
50
+ x = self.se3(x)
51
+ x = self.pool(torch.relu(self.bn4(self.conv4(x))))
52
+ x = self.se4(x)
53
+ x = x.view(x.size(0), -1)
54
+ x = torch.relu(self.fc1(x))
55
+ x = self.fc2(x)
56
+ return x
57
+
58
+ # Load the model
59
+ @st.cache_resource
60
+
61
+ def load_model():
62
+ model = VanillaCNN_SE(num_classes=12) # Update num_classes as per your dataset
63
+ model.load_state_dict(torch.load("vanilla_cnn_se.pth", map_location=torch.device('cpu')))
64
+ model.eval()
65
+ return model
66
+
67
+ model = load_model()
68
+
69
+ # Define class names
70
+ class_names = [
71
+ "Maize", "Common wheat", "Common Chickweed", "Loose Silky-bent",
72
+ "Charlock", "Cleavers", "Sugar beet", "Fat Hen", "Scentless Mayweed",
73
+ "Small-flowered Cranesbill", "Shepherd’s Purse", "Black-grass"
74
+ ]
75
+
76
+ # Define transformations
77
+ transform = transforms.Compose([
78
+ transforms.Resize((224, 224)),
79
+ transforms.ToTensor()
80
+ ])
81
+
82
+ def mask_image(image):
83
+ # Convert PIL image to OpenCV format
84
+ image_np = np.array(image)
85
+ hsv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
86
+
87
+ # Define green color range
88
+ lower_green = np.array([30, 40, 40])
89
+ upper_green = np.array([90, 255, 255])
90
+
91
+ # Create a mask for the green area
92
+ mask = cv2.inRange(hsv_img, lower_green, upper_green)
93
+ masked_img = cv2.bitwise_and(image_np, image_np, mask=mask)
94
+
95
+ # Convert back to PIL image
96
+ return Image.fromarray(masked_img)
97
+
98
+ def predict_class(image):
99
+ # Transform the image for the model
100
+ image_tensor = transform(image).unsqueeze(0)
101
+
102
+ # Predict the class
103
+ with torch.no_grad():
104
+ outputs = model(image_tensor)
105
+ _, predicted = torch.max(outputs, 1)
106
+ return class_names[predicted.item()]
107
+
108
+ # Streamlit UI
109
+ st.title("Plant Seedling Classification")
110
+
111
+ st.write("Upload an image to classify the plant seedling and view the masked image.")
112
+
113
+ # File uploader
114
+ uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
115
+
116
+ if uploaded_file is not None:
117
+ # Load the image
118
+ image = Image.open(uploaded_file).convert("RGB")
119
+
120
+ # Mask the image
121
+ masked_image = mask_image(image)
122
+
123
+ # Predict the class
124
+ predicted_class = predict_class(image)
125
+
126
+ # Display results
127
+ st.image(image, caption="Original Image", use_column_width=True)
128
+ st.image(masked_image, caption="Masked Image", use_column_width=True)
129
+ st.write(f"### Predicted Class: {predicted_class}")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ opencv-python
5
+ numpy
vanilla_cnn_se.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:298c1f3392f0c833e00405804ccd18f95e58e78bd3076d6eef0ac14dba447062
3
+ size 417508402