root committed
Commit 11aa02b · 1 Parent(s): 40b1173
Add application file

- app.py +114 -0
- base64_img.bin +3 -0
- cnn_net.pt +3 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torchvision
import base64
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from pathlib import Path
from torch import nn
from torchvision import transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_path = str(Path(__file__).parent)
path_of_model = base_path + "/cnn_net.pt"
default_img = base_path + "/base64_img.bin"

def load_the_model():
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 39)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = torch.flatten(x, 1)  # flatten all dimensions except batch
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    loaded_model = Net().to(device)
    # map_location keeps loading robust when the checkpoint was saved on GPU
    loaded_model.load_state_dict(torch.load(path_of_model, map_location=device))
    loaded_model.eval()
    return loaded_model

def loads_data_classes() -> list:
    """
    Load the class labels for the prediction.

    Returns:
        list: A list of class labels.
    """
    class_labels = ['ا','ب','ت','ث','ج','ح','خ','د','ذ','ر','ز','س','ش','ص','ض','ط','ظ','ع','غ','ف',
                    'ق','ك','ل','لا','م','ن','ه','و','ي','٠','١','٢','٣','٤','٥','٦','٧','٨','٩']
    return class_labels

def base64_to_image(base64_file):
    # Decode the Base64 string to binary data and write it out as an image file
    image_data = base64.b64decode(base64_file)

    image_path = 'decoded.png'
    with open(image_path, 'wb') as output_file:
        output_file.write(image_data)

    return image_path

def read_base64_file(file_path):
    with open(file_path, 'r') as file:
        base64_string = file.read()
    return base64_string

def predict_on_base64(model, base64_file):
    path = base64_to_image(base64_file)
    img = Image.open(path)  # keep a PIL copy to display in the UI
    model.eval()
    with torch.inference_mode():
        custom_image = torchvision.io.read_image(str(path)).type(torch.float32)
        # Divide the pixel values by 255 to scale them into [0, 1]
        custom_image = custom_image / 255.
        # Apply the model transformations
        transform_img = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((32, 32)),
        ])

        # Transform target image
        custom_image_transformed = transform_img(custom_image)
        # Add an extra batch dimension
        custom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0)

        # Make a prediction on the batched image
        custom_image_pred = model(custom_image_transformed_with_batch_size.to(device))
        # Convert logits to probabilities
        prob = torch.softmax(custom_image_pred, dim=1)
        # Probability of the predicted class
        sample_prob = round(prob[0][prob.argmax()].item(), 3)
        # Index of the highest logit, mapped to its class label
        test_pred_labels = custom_image_pred.argmax(dim=1).item()
        labels = loads_data_classes()
        test_pred_labels = labels[test_pred_labels]

    return test_pred_labels, sample_prob, img

model = load_the_model()  # load the model once at startup

def predict(user_base_64_file):
    base64_string = read_base64_file(user_base_64_file.name)  # read the uploaded file as a Base64 string
    prediction, probability, img = predict_on_base64(model=model, base64_file=base64_string)  # run the model to get a prediction
    return prediction, probability, img


demo = gr.Interface(fn=predict,
                    inputs=gr.File(value=default_img, label="Upload a Base64 Image File"),
                    outputs=[gr.Textbox(label="Predicted Label"), gr.Textbox(label="Probability"), gr.Image(label="Image")],
                    title="Arabic Letter Recognition", allow_flagging="never"
                    )
demo.launch(share=True)
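For local testing, an input file like base64_img.bin can be produced by Base64-encoding any image of a handwritten Arabic character. A minimal sketch (letter.png is a hypothetical input path, not part of the commit):

import base64

# Encode a sample image (hypothetical path) into the Base64 text file the app expects
with open("letter.png", "rb") as image_file:
    encoded = base64.b64encode(image_file.read())

with open("base64_img.bin", "wb") as out_file:
    out_file.write(encoded)
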
base64_img.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ba29ee5c5a1969449ca242baf100450addfe9f5d66b8a41f0fa37b311fd7deb
size 496
cnn_net.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d1c91753db110ed44205879307a8533ec57a64b5940406c8ff1132f3cc92bb41
size 260024
requirements.txt
ADDED
@@ -0,0 +1,4 @@
Flask
torch==2.3.0+cpu
torchvision==0.18.0+cpu
-f https://download.pytorch.org/whl/torch_stable.html
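Note: the trailing -f line points pip at the PyTorch wheel index so the +cpu builds of torch and torchvision can resolve. gradio is imported by app.py but not pinned here, presumably because the hosting runtime provides it.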