import os
import torch
import random
import warnings
import gradio as gr
from PIL import Image
from model import Model
from torchvision import transforms
from modelscope import snapshot_download


# Download the pretrained SVHN checkpoints and example images from ModelScope
MODEL_DIR = snapshot_download("Genius-Society/svhn", cache_dir="./__pycache__")


def infer(input_img: str, checkpoint_file: str):
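    """Recognize the house-number digits in one image with the selected checkpoint."""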
    try:
        model = Model()
        model.restore(f"{MODEL_DIR}/{checkpoint_file}")
        outstr = ""
        with torch.no_grad():
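            # Preprocessing expected by the model: resize, center-crop, normalize to [-1, 1]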
            transform = transforms.Compose(
                [
                    transforms.Resize([64, 64]),
                    transforms.CenterCrop([54, 54]),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            image = Image.open(input_img)
            image = image.convert("RGB")
            image = transform(image)
            images = image.unsqueeze(dim=0)
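            # Forward pass: the network outputs length logits plus logits for up to five digits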
            (
                length_logits,
                digit1_logits,
                digit2_logits,
                digit3_logits,
                digit4_logits,
                digit5_logits,
            ) = model.eval()(images)
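            # Arg-max over each head gives the predicted length and per-position digits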
            length_prediction = length_logits.argmax(dim=1)
            digit1_prediction = digit1_logits.argmax(dim=1)
            digit2_prediction = digit2_logits.argmax(dim=1)
            digit3_prediction = digit3_logits.argmax(dim=1)
            digit4_prediction = digit4_logits.argmax(dim=1)
            digit5_prediction = digit5_logits.argmax(dim=1)
            output = [
                digit1_prediction.item(),
                digit2_prediction.item(),
                digit3_prediction.item(),
                digit4_prediction.item(),
                digit5_prediction.item(),
            ]

            # Keep only as many digits as the predicted sequence length
            for i in range(length_prediction.item()):
                outstr += str(output[i])

        return outstr

    except Exception as e:
        return f"{e}"


def get_files(dir_path=MODEL_DIR, ext=".pth"):
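    """Return the files in dir_path whose names end with the given extension."""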
    files_and_folders = os.listdir(dir_path)
    outputs = []
    for file in files_and_folders:
        if file.endswith(ext):
            outputs.append(file)

    return outputs


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    models = get_files()
    images = get_files(f"{MODEL_DIR}/examples", ".png")
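    # Pair each example image with a randomly chosen checkpoint for the demo gallery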
    samples = []
    for img in images:
        samples.append(
            [
                f"{MODEL_DIR}/examples/{img}",
                random.choice(models),
            ]
        )

    # Build the Gradio demo: image + checkpoint selector in, recognized digit string out
    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label="Upload an image", type="filepath"),
            gr.Dropdown(
                label="Select a model",
                choices=models,
                value=models[0],
            ),
        ],
        outputs=gr.Textbox(label="Recognition result", show_copy_button=True),
        examples=samples,
        title="Door Number Recognition",
        flagging_mode="never",
        cache_examples=False,
    ).launch()