File size: 4,212 Bytes
11aa02b
 
 
 
 
 
 
 
 
 
 
73ad39d
11aa02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e427f9
e879ffa
11aa02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e427f9
11aa02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7398a9f
11aa02b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

import torch
import torchvision
import base64 
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from pathlib import Path
from torch import nn
from torchvision import transforms

# Inference device. NOTE(review): CUDA selection is commented out on purpose;
# the checkpoint is loaded with map_location='cpu' below — confirm before re-enabling.
device = 'cpu' # 'cuda' if torch.cuda.is_available() else 
base_path = str(Path(__file__).parent)  # directory containing this script
path_of_model = base_path + "/cnn_net.pt"  # trained CNN state_dict checkpoint
default_img = base_path + "/base64_img.bin"  # sample Base64 payload pre-filled in the UI

def load_the_model():
  """Build the CNN architecture and load trained weights from ``path_of_model``.

  Returns:
      Net: the network in eval mode, ready for CPU inference.
  """
  class Net(nn.Module):
    """LeNet-style CNN: 1-channel 32x32 input -> 39 class logits."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 feature maps of 5x5 after two conv+pool stages on a 32x32 input.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 39)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

  loaded_model = Net()
  # Use the module-level `device` constant instead of a second hardcoded 'cpu'
  # so the two stay in sync if the device choice ever changes.
  # NOTE(review): on torch>=2.0 prefer torch.load(..., weights_only=True) to
  # avoid unpickling arbitrary objects from the checkpoint — confirm torch version.
  loaded_model.load_state_dict(torch.load(path_of_model, map_location=torch.device(device)))
  loaded_model.eval()
  return loaded_model

def loads_data_classes() -> list:
    """Return the ordered class labels that the model's 39 outputs map to.

    Returns:
        list: 28 Arabic letters, the 'لا' ligature, then the Arabic-Indic
        digits 0-9 — in training order.
    """
    letters = ['ا','ب','ت','ث','ج','ح','خ','د','ذ','ر','ز','س','ش','ص','ض',
               'ط','ظ','ع','غ','ف','ق','ك','ل','لا','م','ن','ه','و','ي']
    digits = ['٠','١','٢','٣','٤','٥','٦','٧','٨','٩']
    return letters + digits

def base64_to_image(base64_file):
    """Decode a Base64 string and persist it as an image file on disk.

    Args:
        base64_file: Base64-encoded image contents (str or bytes).

    Returns:
        str: path of the written file ('decoded.png', in the working dir).
    """
    image_path = 'decoded.png'
    # b64decode accepts both str and bytes; write the raw bytes in one call.
    Path(image_path).write_bytes(base64.b64decode(base64_file))
    return image_path

def read_base64_file(file_path):
    """Read a text file and return its entire contents as one string."""
    return Path(file_path).read_text()

def predict_on_base64(model, base64_file):
    """Decode a Base64 image, run the CNN on it and return the prediction.

    Args:
        model: trained classifier producing (1, 39) logits.
        base64_file: Base64-encoded image data (str or bytes).

    Returns:
        tuple: (predicted label str, probability rounded to 3 decimals,
        PIL.Image of the decoded input for display in the UI).
    """
    path = base64_to_image(base64_file)
    img = Image.open(path)  # kept open for the Gradio image preview

    # Model transforms: single grayscale channel, 32x32 to match the conv stack.
    # Built outside the inference block — it is plain setup, not model work.
    transform_img = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((32, 32)),
    ])

    model.eval()
    with torch.inference_mode():
        # Scale pixel values from [0, 255] into [0, 1].
        custom_image = torchvision.io.read_image(str(path)).type(torch.float32) / 255.
        # Add a batch dimension -> (1, 1, 32, 32).
        batch = transform_img(custom_image).unsqueeze(dim=0)

        logits = model(batch)
        prob = torch.softmax(logits, dim=1)
        # FIX: the original indexed prob[0] with a *flattened* argmax
        # (prob.argmax()); correct only by accident for batch size 1.
        # Taking the max over the class dim is robust.
        sample_prob = round(prob.max(dim=1).values[0].item(), 3)
        pred_idx = logits.argmax(dim=1)[0].item()

    labels = loads_data_classes()
    return labels[pred_idx], sample_prob, img

# Load weights once at import time so every Gradio request reuses the same model.
model = load_the_model() # load the model

def predict(user_base_64_file):
  """Gradio callback: read the uploaded Base64 file and classify its image.

  Args:
      user_base_64_file: Gradio file upload (exposes the temp path via .name).

  Returns:
      tuple: (predicted label, probability, decoded PIL image).
  """
  encoded = read_base64_file(user_base_64_file.name)
  return predict_on_base64(model=model, base64_file=encoded)


# Gradio UI: one file-upload input, three outputs (label, probability, image).
# NOTE(review): newer gradio versions expect allow_flagging="never" rather than
# False (deprecated) — confirm against the pinned gradio version.
demo = gr.Interface(fn=predict,
                                              inputs=gr.File(value = default_img ,label="Upload a Base64 Image File with .txt(utf-8 format) or .bin"),
                                              outputs=[gr.Textbox(label="Predicted Label"), gr.Textbox(label="Probability"), gr.Image(label="Image")],
                                              title="Arabic Letter Recognition", allow_flagging=False
                                              )
# share=True exposes a public tunnel URL in addition to the local server.
demo.launch(share=True)