Spaces:
Running
Running
MuGeminorum
committed on
Commit
•
de12ba7
1
Parent(s):
58e5f22
add show copy btn
Browse files
app.py
CHANGED
@@ -7,10 +7,14 @@ from PIL import Image
|
|
7 |
from model import Model
|
8 |
from torchvision import transforms
|
9 |
import warnings
|
|
|
10 |
warnings.filterwarnings("ignore")
|
11 |
|
12 |
|
13 |
-
def download_model(
|
|
|
|
|
|
|
14 |
# Check if the file exists
|
15 |
if not os.path.exists(local_path):
|
16 |
print(f"Downloading file from {url}...")
|
@@ -18,13 +22,13 @@ def download_model(url="https://huggingface.co/MuGeminorum/SVHN-Recognition/reso
|
|
18 |
response = requests.get(url, stream=True)
|
19 |
|
20 |
# Get the total file size in bytes
|
21 |
-
total_size = int(response.headers.get('content-length', 0))
|
22 |
|
23 |
# Initialize the tqdm progress bar
|
24 |
-
progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
|
25 |
|
26 |
# Open a local file with write-binary mode
|
27 |
-
with open(local_path, 'wb') as file:
|
28 |
for data in response.iter_content(chunk_size=1024):
|
29 |
# Update the progress bar
|
30 |
progress_bar.update(len(data))
|
@@ -42,22 +46,31 @@ def _infer(path_to_checkpoint_file, path_to_input_image):
|
|
42 |
model = Model()
|
43 |
model.restore(path_to_checkpoint_file)
|
44 |
# model.cuda()
|
45 |
-
outstr = ''
|
46 |
|
47 |
with torch.no_grad():
|
48 |
-
transform = transforms.Compose(
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
54 |
|
55 |
image = Image.open(path_to_input_image)
|
56 |
-
image = image.convert('RGB')
|
57 |
image = transform(image)
|
58 |
images = image.unsqueeze(dim=0) # .cuda()
|
59 |
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
length_prediction = length_logits.max(1)[1]
|
63 |
digit1_prediction = digit1_logits.max(1)[1]
|
@@ -71,7 +84,7 @@ def _infer(path_to_checkpoint_file, path_to_input_image):
|
|
71 |
digit2_prediction.item(),
|
72 |
digit3_prediction.item(),
|
73 |
digit4_prediction.item(),
|
74 |
-
digit5_prediction.item()
|
75 |
]
|
76 |
|
77 |
for i in range(length_prediction.item()):
|
@@ -89,23 +102,19 @@ def inference(image_path, weight_path="model-122000.pth"):
|
|
89 |
)
|
90 |
|
91 |
if not image_path:
|
92 |
-
image_path = './examples/03.png'
|
93 |
|
94 |
return _infer(weight_path, image_path)
|
95 |
|
96 |
|
97 |
-
if __name__ == '__main__':
|
98 |
-
example_images = [
|
99 |
-
'./examples/03.png',
|
100 |
-
'./examples/457.png',
|
101 |
-
'./examples/2003.png'
|
102 |
-
]
|
103 |
|
104 |
iface = gr.Interface(
|
105 |
fn=inference,
|
106 |
-
inputs=gr.Image(type=
|
107 |
-
outputs=gr.Textbox(label=
|
108 |
-
examples=example_images
|
109 |
)
|
110 |
|
111 |
iface.launch()
|
|
|
7 |
from model import Model
|
8 |
from torchvision import transforms
|
9 |
import warnings
|
10 |
+
|
11 |
warnings.filterwarnings("ignore")
|
12 |
|
13 |
|
14 |
+
def download_model(
|
15 |
+
url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth",
|
16 |
+
local_path="model-122000.pth",
|
17 |
+
):
|
18 |
# Check if the file exists
|
19 |
if not os.path.exists(local_path):
|
20 |
print(f"Downloading file from {url}...")
|
|
|
22 |
response = requests.get(url, stream=True)
|
23 |
|
24 |
# Get the total file size in bytes
|
25 |
+
total_size = int(response.headers.get("content-length", 0))
|
26 |
|
27 |
# Initialize the tqdm progress bar
|
28 |
+
progress_bar = tqdm(total=total_size, unit="B", unit_scale=True)
|
29 |
|
30 |
# Open a local file with write-binary mode
|
31 |
+
with open(local_path, "wb") as file:
|
32 |
for data in response.iter_content(chunk_size=1024):
|
33 |
# Update the progress bar
|
34 |
progress_bar.update(len(data))
|
|
|
46 |
model = Model()
|
47 |
model.restore(path_to_checkpoint_file)
|
48 |
# model.cuda()
|
49 |
+
outstr = ""
|
50 |
|
51 |
with torch.no_grad():
|
52 |
+
transform = transforms.Compose(
|
53 |
+
[
|
54 |
+
transforms.Resize([64, 64]),
|
55 |
+
transforms.CenterCrop([54, 54]),
|
56 |
+
transforms.ToTensor(),
|
57 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
58 |
+
]
|
59 |
+
)
|
60 |
|
61 |
image = Image.open(path_to_input_image)
|
62 |
+
image = image.convert("RGB")
|
63 |
image = transform(image)
|
64 |
images = image.unsqueeze(dim=0) # .cuda()
|
65 |
|
66 |
+
(
|
67 |
+
length_logits,
|
68 |
+
digit1_logits,
|
69 |
+
digit2_logits,
|
70 |
+
digit3_logits,
|
71 |
+
digit4_logits,
|
72 |
+
digit5_logits,
|
73 |
+
) = model.eval()(images)
|
74 |
|
75 |
length_prediction = length_logits.max(1)[1]
|
76 |
digit1_prediction = digit1_logits.max(1)[1]
|
|
|
84 |
digit2_prediction.item(),
|
85 |
digit3_prediction.item(),
|
86 |
digit4_prediction.item(),
|
87 |
+
digit5_prediction.item(),
|
88 |
]
|
89 |
|
90 |
for i in range(length_prediction.item()):
|
|
|
102 |
)
|
103 |
|
104 |
if not image_path:
|
105 |
+
image_path = "./examples/03.png"
|
106 |
|
107 |
return _infer(weight_path, image_path)
|
108 |
|
109 |
|
110 |
+
if __name__ == "__main__":
|
111 |
+
example_images = ["./examples/03.png", "./examples/457.png", "./examples/2003.png"]
|
|
|
|
|
|
|
|
|
112 |
|
113 |
iface = gr.Interface(
|
114 |
fn=inference,
|
115 |
+
inputs=gr.Image(type="filepath", label="Upload photo"),
|
116 |
+
outputs=gr.Textbox(label="Recognition result", show_copy_button=True),
|
117 |
+
examples=example_images,
|
118 |
)
|
119 |
|
120 |
iface.launch()
|