root committed
Commit 11aa02b
1 Parent(s): 40b1173

Add application file

Files changed (4)
  1. app.py +114 -0
  2. base64_img.bin +3 -0
  3. cnn_net.pt +3 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,114 @@
+
+ import torch
+ import torchvision
+ import base64
+ import torch.nn.functional as F
+ import gradio as gr
+ from PIL import Image
+ from pathlib import Path
+ from torch import nn
+ from torchvision import transforms
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ base_path = str(Path(__file__).parent)
+ path_of_model = base_path + "/cnn_net.pt"
+ default_img = base_path + "/base64_img.bin"
+
+ def load_the_model():
+     class Net(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.conv1 = nn.Conv2d(1, 6, 5)
+             self.pool = nn.MaxPool2d(2, 2)
+             self.conv2 = nn.Conv2d(6, 16, 5)
+             self.fc1 = nn.Linear(16 * 5 * 5, 120)
+             self.fc2 = nn.Linear(120, 84)
+             self.fc3 = nn.Linear(84, 39)
+
+         def forward(self, x):
+             x = self.pool(F.relu(self.conv1(x)))
+             x = self.pool(F.relu(self.conv2(x)))
+             x = torch.flatten(x, 1)  # flatten all dimensions except batch
+             x = F.relu(self.fc1(x))
+             x = F.relu(self.fc2(x))
+             x = self.fc3(x)
+             return x
+
+     loaded_model = Net().to(device)
+     # map_location keeps loading working on CPU-only machines
+     loaded_model.load_state_dict(torch.load(path_of_model, map_location=device))
+     loaded_model.eval()
+     return loaded_model
+
+ def loads_data_classes() -> list:
+     """
+     Load the class labels for the prediction.
+
+     Returns:
+         list: A list of class labels.
+     """
+     class_labels = ['ا','ب','ت','ث','ج','ح','خ','د','ذ','ر','ز','س','ش','ص','ض','ط','ظ','ع','غ','ف',
+                     'ق','ك','ل','لا','م','ن','ه','و','ي','٠','١','٢','٣','٤','٥','٦','٧','٨','٩']
+     return class_labels
+
+ def base64_to_image(base64_file):
+     # Decode the Base64 string to binary data and write it out as an image file
+     image_data = base64.b64decode(base64_file)
+
+     image_path = 'decoded.png'
+     with open(image_path, 'wb') as output_file:
+         output_file.write(image_data)
+
+     return image_path
+
+ def read_base64_file(file_path):
+     with open(file_path, 'r') as file:
+         base64_string = file.read()
+     return base64_string
+
+ def predict_on_base64(model, base64_file):
+     path = base64_to_image(base64_file)
+     img = Image.open(path)
+     model.eval()
+     with torch.inference_mode():
+         custom_image = torchvision.io.read_image(str(path)).type(torch.float32)
+         # Divide the image pixel values by 255 to get them between [0, 1]
+         custom_image = custom_image / 255.
+         # Apply the model transformations
+         transform_img = transforms.Compose([
+             transforms.Grayscale(),
+             transforms.Resize((32, 32)),
+         ])
+
+         # Transform target image
+         custom_image_transformed = transform_img(custom_image)
+         # Add an extra dimension to the image (batch size)
+         custom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0)
+
+         # Make a prediction on the image with the extra batch dimension
+         custom_image_pred = model(custom_image_transformed_with_batch_size.to(device))
+         # Get the class probabilities
+         prob = torch.softmax(custom_image_pred, dim=1)
+         # Probability of the predicted class
+         sample_prob = round(prob[0][prob.argmax()].item(), 3)
+         # Index of the highest logit, mapped to its class label
+         test_pred_labels = custom_image_pred.argmax(dim=1).item()
+         labels = loads_data_classes()
+         test_pred_labels = labels[test_pred_labels]
+
+     return test_pred_labels, sample_prob, img
+
+ model = load_the_model()  # load the model
+
+ def predict(user_base_64_file):
+     base64_string = read_base64_file(user_base_64_file.name)  # read the Base64 string from the uploaded file
+     prediction, probability, img = predict_on_base64(model=model, base64_file=base64_string)  # run the model and get the prediction
+     return prediction, probability, img
+
+
+ demo = gr.Interface(fn=predict,
+                     inputs=gr.File(value=default_img, label="Upload a Base64 Image File"),
+                     outputs=[gr.Textbox(label="Predicted Label"), gr.Textbox(label="Probability"), gr.Image(label="Image")],
+                     title="Arabic Letter Recognition", allow_flagging="never"  # "never" is the string form expected by current Gradio
+                     )
+ demo.launch(share=True)
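
For reference, the app expects the uploaded file to be plain text containing a Base64-encoded image, like the bundled base64_img.bin. A minimal sketch of producing such a file from an ordinary image, where letter.png and letter_b64.txt are illustrative names rather than files from this commit:

import base64

# Read any image (e.g. a scan of a handwritten Arabic letter) and
# write its Base64 encoding to a text file the Gradio File input accepts.
with open("letter.png", "rb") as img_file:
    encoded = base64.b64encode(img_file.read()).decode("ascii")

with open("letter_b64.txt", "w") as out_file:
    out_file.write(encoded)

Uploading letter_b64.txt through the interface then runs predict(), which returns the predicted character, its probability, and the decoded image.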
base64_img.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9ba29ee5c5a1969449ca242baf100450addfe9f5d66b8a41f0fa37b311fd7deb
+ size 496
cnn_net.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1c91753db110ed44205879307a8533ec57a64b5940406c8ff1132f3cc92bb41
+ size 260024
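
Since load_the_model() restores cnn_net.pt with load_state_dict, the checkpoint is presumably a PyTorch state dict rather than a pickled module. A minimal sketch of that save/restore convention, using a toy layer instead of the actual Net from app.py:

import torch
from torch import nn

# Save only the parameters (a state dict), then restore them into a
# freshly constructed module of the same architecture.
layer = nn.Linear(4, 2)
torch.save(layer.state_dict(), "example_state_dict.pt")

restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("example_state_dict.pt", map_location="cpu"))
restored.eval()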
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ Flask
+ torch==2.3.0+cpu
+ torchvision==0.18.0+cpu
+ -f https://download.pytorch.org/whl/torch_stable.html
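
The +cpu wheels resolve through the extra index given by the -f line, so the file should be installed with pip install -r requirements.txt as written. Note that app.py also imports gradio and PIL (Pillow), which are not pinned here and would need to be provided by the runtime or installed separately. A quick sanity check of the pinned packages, assuming the install succeeded:

import torch
import torchvision

print(torch.__version__)          # expected: 2.3.0+cpu
print(torchvision.__version__)    # expected: 0.18.0+cpu
print(torch.cuda.is_available())  # False on the CPU-only build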