vasu0508 commited on
Commit
4ee4e2e
·
1 Parent(s): efe6781
Files changed (4) hide show
  1. app.py +167 -0
  2. models/model.ckpt +3 -0
  3. requirements.txt +89 -0
  4. scripts/rename.py +33 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from PIL import Image
4
+ import pytorch_lightning as pl
5
+ import torch.nn as nn
6
+ from torchvision import transforms as T
7
+ from torchvision import models
8
+ import matplotlib.pyplot as plt
9
+ import onnxruntime as ort
10
+ from glob import glob
11
+ import streamlit as st
12
+ import numpy as np
13
+ from torchmetrics.functional import accuracy
14
+ from torchmetrics import Accuracy
15
+
16
+ #Define the labels
17
+ labels = ['Defect', 'Non-Defect']
18
+
19
+ # Define the sample images
20
+ sample_images = {
21
+ "Defect01": "pics/Defect/2.jpg",
22
+ "Defect02": "pics/Defect/6.jpg",
23
+ "Defect03": "pics/Defect/8.jpg",
24
+ "Non-Defect01": "pics/nDefect/3.jpg",
25
+ "Non-Defect02": "pics/nDefect/4.jpg",
26
+ "Non-Defect03": "pics/nDefect/8.jpg"
27
+ }
28
+
29
+ class DefectResNet(pl.LightningModule):
30
+ def __init__(self, n_classes=2):
31
+ super(DefectResNet, self).__init__()
32
+
33
+ # จำนวนของพันธุ์output (2)
34
+ self.n_classes = n_classes
35
+
36
+ #เปลี่ยน layer สุดท้าย
37
+ self.backbone = models.resnet50(pretrained=True)
38
+ # self.backbone = models.resnet152(pretrained=True)
39
+ # self.backbone = models.vgg19(pretrained=True)
40
+ for param in self.backbone.parameters():
41
+ param.requires_grad = False
42
+
43
+ # เปลี่ยน fc layer เป็น output ขนาด 2
44
+ self.backbone.fc = torch.nn.Linear(self.backbone.fc.in_features, n_classes) #For ResNet base mdoel
45
+ # self.backbone.classifier[6] = torch.nn.Linear(self.backbone.classifier[6].in_features, n_classes) #For VGG bse model
46
+
47
+ self.entropy_loss = nn.CrossEntropyLoss()
48
+ self.accuracy = Accuracy(task="multiclass", num_classes=2)
49
+
50
+ self.save_hyperparameters(logger=False)
51
+
52
+ def forward(self, x):
53
+ preds = self.backbone(x)
54
+ return preds
55
+
56
+ def training_step(self, batch, batch_idx):
57
+ x, y = batch
58
+ logits = self.backbone(x)
59
+ loss = self.entropy_loss(logits, y)
60
+ y_pred = torch.argmax(logits, dim=1)
61
+ self.log("train_loss", loss)
62
+ self.log("train_acc", self.accuracy(y_pred, y))
63
+ return loss
64
+
65
+ def validation_step(self, batch, batch_idx):
66
+ x, y = batch
67
+ logits = self.backbone(x)
68
+ loss = self.entropy_loss(logits, y)
69
+ y_pred = torch.argmax(logits, dim=1)
70
+ self.log("val_loss", loss)
71
+ self.log("val_acc", self.accuracy(y_pred, y))
72
+ return loss
73
+
74
+ def configure_optimizers(self):
75
+ self.optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
76
+ return {
77
+ "optimizer": self.optimizer,
78
+ "monitor": "val_loss",
79
+ }
80
+
81
+ def test_step(self, batch, batch_idx):
82
+ x, y = batch
83
+ logits = self.backbone(x)
84
+ loss = self.entropy_loss(logits, y)
85
+ y_pred = torch.argmax(logits, dim=1)
86
+ self.log("val_loss", loss)
87
+ self.log("val_acc", self.accuracy(y_pred, y))
88
+ return loss
89
+
90
+ def _shared_eval_step(self, batch, batch_idx):
91
+ x, y = batch
92
+ y_hat = self.model(x)
93
+ logits = self.backbone(x)
94
+ loss = self.entropy_loss(logits, y)
95
+ acc = accuracy(y_hat, y)
96
+ return loss, acc
97
+
98
+ # Load the model on the appropriate device
99
+ loadmodel = DefectResNet()
100
+ def load_checkpoint(checkpoint):
101
+ loadmodel.load_state_dict(checkpoint["state_dict"])
102
+ load_checkpoint(torch.load("models/model.ckpt", map_location=torch.device('cpu')))
103
+ loadmodel.eval()
104
+
105
+ transform = T.Compose([
106
+ T.Resize((224, 224)),
107
+ T.ToTensor()
108
+ ])
109
+
110
+ def predict(image):
111
+ image = transform(image).unsqueeze(0)
112
+
113
+ # Perform the prediction
114
+ with torch.no_grad():
115
+ logits = loadmodel(image)
116
+ probs = F.softmax(logits, dim=1)
117
+ return probs
118
+
119
+ # Define the Streamlit app
120
+ def app():
121
+ predictions = None
122
+ st.title("Digital textile printing defect classification for industrial.")
123
+ uploaded_file = st.file_uploader("Upload your image...", type=["jpg"])
124
+
125
+ with st.expander("Or choose from sample here..."):
126
+ sample = st.selectbox(label = "Select here", options = list(sample_images.keys()), label_visibility="hidden")
127
+ col1, col2, col3 = st.columns(3)
128
+ with col1:
129
+ st.image(sample_images["Defect01"], caption="Defect01", use_column_width=True)
130
+ with col2:
131
+ st.image(sample_images["Defect02"], caption="Defect02", use_column_width=True)
132
+ with col3:
133
+ st.image(sample_images["Defect03"], caption="Defect03", use_column_width=True)
134
+ col1, col2, col3 = st.columns(3)
135
+ with col1:
136
+ st.image(sample_images["Non-Defect01"], caption="Non-Defect01", use_column_width=True)
137
+ with col2:
138
+ st.image(sample_images["Non-Defect02"], caption="Non-Defect02", use_column_width=True)
139
+ with col3:
140
+ st.image(sample_images["Non-Defect03"], caption="Non-Defect03", use_column_width=True)
141
+
142
+ # If an image is uploaded, make a prediction on it
143
+ if uploaded_file is not None:
144
+ image = Image.open(uploaded_file)
145
+ st.image(image, caption="Uploaded Image", use_column_width=True)
146
+ predictions = predict(image)
147
+ elif sample:
148
+ image = Image.open(sample_images[sample])
149
+ st.image(image, caption=sample.capitalize() + " Image", use_column_width=True)
150
+ predictions = predict(image)
151
+
152
+ # Show predictions with their probabilities
153
+ if predictions is not None:
154
+ # st.write(predictions)
155
+ st.subheader(f'Predictions : {labels[torch.argmax(predictions[0]).item()]}')
156
+ for pred, prob in zip(labels, predictions[0]):
157
+ st.write(f"{pred}: {prob * 100:.2f}%")
158
+ st.progress(prob.item())
159
+ else:
160
+ st.write("No predictions.")
161
+ st.subheader("Credits")
162
+ st.write("By : Settapun Laoaree | AI-Builders")
163
+ st.markdown("Source : [Github](https://github.com/ShokulSet/DefectDetection-AIBuilders) [Hugging Face](https://huggingface.co/spaces/sh0kul/DefectDetection-Deploy)")
164
+
165
+ # Run the app
166
+ if __name__ == "__main__":
167
+ app()
models/model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5601b234e608862cde6159ba32bd77a3e5e2b23e41ce488ee778bf4154419090
3
+ size 94409193
requirements.txt ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.4
2
+ aiosignal==1.3.1
3
+ altair==4.2.2
4
+ async-timeout==4.0.2
5
+ attrs==23.1.0
6
+ blinker==1.6.2
7
+ cachetools==5.3.0
8
+ certifi==2023.5.7
9
+ charset-normalizer==3.1.0
10
+ click==8.1.3
11
+ cmake==3.26.3
12
+ coloredlogs==15.0.1
13
+ contourpy==1.0.7
14
+ cycler==0.11.0
15
+ decorator==5.1.1
16
+ entrypoints==0.4
17
+ filelock==3.12.0
18
+ flatbuffers==23.5.9
19
+ fonttools==4.39.4
20
+ frozenlist==1.3.3
21
+ fsspec==2023.5.0
22
+ gitdb==4.0.10
23
+ GitPython==3.1.31
24
+ humanfriendly==10.0
25
+ idna==3.4
26
+ importlib-metadata==6.6.0
27
+ Jinja2==3.1.2
28
+ jsonschema==4.17.3
29
+ kiwisolver==1.4.4
30
+ lightning-utilities==0.8.0
31
+ lit==16.0.5
32
+ markdown-it-py==2.2.0
33
+ MarkupSafe==2.1.2
34
+ matplotlib==3.7.1
35
+ mdurl==0.1.2
36
+ mpmath==1.3.0
37
+ multidict==6.0.4
38
+ networkx==3.1
39
+ numpy==1.24.3
40
+ nvidia-cublas-cu11==11.10.3.66
41
+ nvidia-cuda-cupti-cu11==11.7.101
42
+ nvidia-cuda-nvrtc-cu11==11.7.99
43
+ nvidia-cuda-runtime-cu11==11.7.99
44
+ nvidia-cudnn-cu11==8.5.0.96
45
+ nvidia-cufft-cu11==10.9.0.58
46
+ nvidia-curand-cu11==10.2.10.91
47
+ nvidia-cusolver-cu11==11.4.0.1
48
+ nvidia-cusparse-cu11==11.7.4.91
49
+ nvidia-nccl-cu11==2.14.3
50
+ nvidia-nvtx-cu11==11.7.91
51
+ onnxruntime==1.14.1
52
+ packaging==23.1
53
+ pandas==2.0.1
54
+ Pillow==9.5.0
55
+ protobuf==3.20.3
56
+ pyarrow==12.0.0
57
+ pydeck==0.8.1b0
58
+ Pygments==2.15.1
59
+ Pympler==1.0.1
60
+ pyparsing==3.0.9
61
+ pyrsistent==0.19.3
62
+ python-dateutil==2.8.2
63
+ pytorch-lightning==2.0.2
64
+ pytz==2023.3
65
+ pytz-deprecation-shim==0.1.0.post0
66
+ PyYAML==6.0
67
+ requests==2.31.0
68
+ rich==13.3.5
69
+ six==1.16.0
70
+ smmap==5.0.0
71
+ streamlit==1.23.1
72
+ sympy==1.12
73
+ tenacity==8.2.2
74
+ toml==0.10.2
75
+ toolz==0.12.0
76
+ torch==2.0.1
77
+ torchmetrics==0.11.4
78
+ torchvision==0.15.2
79
+ tornado==6.3.2
80
+ tqdm==4.65.0
81
+ triton==2.0.0
82
+ typing_extensions==4.6.0
83
+ tzdata==2023.3
84
+ tzlocal==4.3
85
+ urllib3==2.0.2
86
+ validators==0.20.0
87
+ watchdog==3.0.0
88
+ yarl==1.9.2
89
+ zipp==3.15.0
scripts/rename.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python 3 code to rename multiple
2
+ # files in a directory or defect_path
3
+
4
+ # importing os module
5
+ import os
6
+
7
+ # Function to rename multiple files
8
+ def main():
9
+
10
+ defect_path = "../pics/Defect"
11
+ ndefect_path = "../pics/nDefect"
12
+ for count, filename in enumerate(os.listdir(defect_path)):
13
+ dst = f"{str(count)}.jpg"
14
+ src =f"{defect_path}/{filename}" # defect_pathname/filename, if .py file is outside defect_path
15
+ dst =f"{defect_path}/{dst}"
16
+
17
+ # rename() function will
18
+ # rename all the files
19
+ os.rename(src, dst)
20
+
21
+ for count, filename in enumerate(os.listdir(ndefect_path)):
22
+ dst = f"{str(count)}.jpg"
23
+ src =f"{ndefect_path}/{filename}" # defect_pathname/filename, if .py file is outside defect_path
24
+ dst =f"{ndefect_path}/{dst}"
25
+ # rename() function will
26
+ # rename all the files
27
+ os.rename(src, dst)
28
+
29
+ # Driver Code
30
+ if __name__ == '__main__':
31
+
32
+ # Calling main() function
33
+ main()