File size: 3,088 Bytes
e1e7fa2
 
 
 
 
 
 
 
 
 
 
 
eca1055
e1e7fa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718237e
e1e7fa2
 
 
 
 
 
 
 
 
 
 
 
 
718237e
e1e7fa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed3f8d6
e1e7fa2
 
 
 
 
b7472a4
ed3f8d6
 
 
b7472a4
 
e1e7fa2
 
 
b7472a4
 
e1e7fa2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import torch
import requests
import gradio as gr
from tqdm import tqdm
from PIL import Image
from model import Model
from torchvision import transforms
import warnings
# Suppress every warning for the whole process (all categories, all modules);
# presumably to keep the demo's console output clean.
warnings.filterwarnings("ignore")


def download_model(url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth", local_path="model-122000.pth", timeout=30):
    """Download the model checkpoint to ``local_path`` unless it already exists.

    The body is streamed into ``<local_path>.part`` and only renamed to
    ``local_path`` once it has been written completely. Because the existence
    of ``local_path`` is the only "already downloaded" test, writing straight
    to it would let an interrupted or failed download leave a truncated file
    that every later call silently accepts.

    Args:
        url: HTTP(S) URL of the checkpoint.
        local_path: Destination path. If it already exists, nothing is done.
        timeout: Timeout in seconds forwarded to ``requests.get``.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    if os.path.exists(local_path):
        return

    print(f"Downloading file from {url}...")
    tmp_path = f"{local_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an error page (e.g. a 404 body) would be
            # written to disk and later loaded as if it were the checkpoint.
            response.raise_for_status()

            # content-length may be missing; tqdm treats total=None as
            # "unknown size" instead of showing a bogus 0-byte total.
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(tmp_path, 'wb') as file, \
                    tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
                for data in response.iter_content(chunk_size=1024):
                    file.write(data)
                    progress_bar.update(len(data))

        # Atomic rename: local_path only ever appears fully written.
        os.replace(tmp_path, local_path)
    except BaseException:
        # Drop the partial file so the next call starts a clean download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print("Download completed.")


def _infer(path_to_checkpoint_file, path_to_input_image):
    """Predict the digit string shown in a single image.

    Args:
        path_to_checkpoint_file: Checkpoint passed to ``Model.restore``.
        path_to_input_image: Anything ``PIL.Image.open`` accepts (a path here).

    Returns:
        The digit predictions of the first ``length`` digit heads joined into
        a string, where ``length`` is the model's length prediction.
    """
    model = Model()
    model.restore(path_to_checkpoint_file)
    model.eval()
    # NOTE: inference runs on CPU; call .cuda() on the model and `images` to use a GPU.

    transform = transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.CenterCrop([54, 54]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    # `with` releases the underlying file as soon as the pixels are decoded.
    with Image.open(path_to_input_image) as image:
        images = transform(image.convert('RGB')).unsqueeze(dim=0)

    with torch.no_grad():
        # The first output is the length logits; the rest are per-digit logits.
        length_logits, *digit_logits = model(images)

        length = length_logits.max(1)[1].item()
        digits = [logits.max(1)[1].item() for logits in digit_logits]

    # Slicing (instead of indexing) caps the result at the number of digit
    # heads, so a length prediction larger than that cannot raise IndexError.
    return ''.join(str(digit) for digit in digits[:length])


def inference(image_path, weight_path="model-122000.pth"):
    """Gradio callback: return the digit string recognised in an image.

    Args:
        image_path: Path to the input image. Falsy values (e.g. ``None`` for
            an empty input) fall back to the sample ``./images/03.png``.
        weight_path: Checkpoint file to load. If it does not exist yet, the
            default model is downloaded into this path.

    Returns:
        The predicted digit string.
    """
    # Ensure the checkpoint that is actually loaded below is present, rather
    # than always fetching the default file regardless of weight_path.
    download_model(local_path=weight_path)

    if not image_path:
        image_path = './images/03.png'

    return _infer(weight_path, image_path)


if __name__ == '__main__':
    # Sample images offered as clickable examples in the web UI.
    examples = ['./images/03.png', './images/457.png', './images/2003.png']

    # The uploaded image reaches `inference` as a file path; the predicted
    # digit string is shown in a text box.
    demo = gr.Interface(
        fn=inference,
        inputs=gr.Image(type='filepath'),
        outputs=gr.Textbox(),
        examples=examples,
    )
    demo.launch()