svhn / app.py
admin
sync
79c4c0b
raw
history blame
No virus
3.01 kB
import os
import torch
import random
import warnings
import gradio as gr
from PIL import Image
from model import Model
from torchvision import transforms
from modelscope import snapshot_download
# Local directory of the "MuGeminorum/svhn" ModelScope repository, fetched at
# import time with its cache rooted at ./__pycache__ (presumably reused on later
# runs rather than re-downloaded — confirm against modelscope docs). The rest of
# this script reads *.pth checkpoints from here and example PNGs from
# MODEL_DIR/examples.
MODEL_DIR = snapshot_download("MuGeminorum/svhn", cache_dir="./__pycache__")
def _digits_to_string(length: int, digits) -> str:
    """Join the first ``length`` predicted digits into a string.

    ``length`` is clamped to ``[0, len(digits)]`` so a length prediction larger
    than the number of digit heads cannot raise ``IndexError``, and a
    non-positive length yields an empty string.

    Args:
        length: Digit count predicted by the length head.
        digits: Per-position digit predictions, as ints.

    Returns:
        The concatenated digits, e.g. ``"123"``.
    """
    return "".join(str(d) for d in digits[: max(length, 0)])


def infer(input_img: str, checkpoint_file: str):
    """Recognize the digit sequence in one image using the chosen checkpoint.

    Args:
        input_img: Path to the image file (the Gradio image input is declared
            with ``type="filepath"``).
        checkpoint_file: Name of a ``.pth`` checkpoint inside ``MODEL_DIR``.

    Returns:
        The predicted digit string, or the text of the exception if anything
        fails. Errors are returned so the UI can display them instead of the
        app crashing.
    """
    try:
        model = Model()
        model.restore(f"{MODEL_DIR}/{checkpoint_file}")
        transform = transforms.Compose(
            [
                transforms.Resize([64, 64]),
                transforms.CenterCrop([54, 54]),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5 per channel: [0, 1] tensors -> [-1, 1].
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        # The context manager closes the file handle PIL keeps open.
        with Image.open(input_img) as img:
            images = transform(img.convert("RGB")).unsqueeze(dim=0)  # batch of 1
        with torch.no_grad():
            # First output scores the sequence length; the remaining outputs
            # score each digit position (five heads in this model).
            length_logits, *digit_logits = model.eval()(images)
        length = length_logits.max(1)[1].item()
        # NOTE(review): if the digit heads carry an extra "no digit" class it
        # would render as "10" here — confirm against the Model definition.
        digits = [logits.max(1)[1].item() for logits in digit_logits]
        return _digits_to_string(length, digits)
    except Exception as e:
        # UI boundary: surface the error text in the output box.
        return f"{e}"
def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """Return the names of entries in ``dir_path`` whose names end with ``ext``.

    Names are sorted so the result does not depend on ``os.listdir``'s
    arbitrary, filesystem-dependent order. This keeps the default model
    (the first entry) and the example order stable across runs.

    Args:
        dir_path: Directory to scan (non-recursive). Defaults to ``MODEL_DIR``,
            bound when the function is defined.
        ext: Case-sensitive suffix to match, e.g. ``".pth"`` or ``".png"``.

    Returns:
        Sorted list of bare entry names (not full paths).

    Raises:
        OSError: From ``os.listdir`` if ``dir_path`` cannot be listed
            (e.g. ``FileNotFoundError``).
    """
    return sorted(name for name in os.listdir(dir_path) if name.endswith(ext))
if __name__ == "__main__":
    # Suppress all Python warnings for the demo process.
    warnings.filterwarnings("ignore")
    models = get_files()
    if not models:
        # Without this, models[0] below fails with an opaque IndexError.
        raise SystemExit(f"No .pth checkpoints found in {MODEL_DIR}")
    images = get_files(f"{MODEL_DIR}/examples", ".png")
    # Pair each bundled example image with a randomly chosen checkpoint.
    samples = [
        [f"{MODEL_DIR}/examples/{img}", random.choice(models)] for img in images
    ]
    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label="上传图片 Upload an image", type="filepath"),
            gr.Dropdown(
                label="选择权重 Select a model",
                choices=models,
                value=models[0],
            ),
        ],
        outputs=gr.Textbox(label="识别结果 Recognition result", show_copy_button=True),
        examples=samples,
        allow_flagging="never",
        cache_examples=False,
    ).launch()