# SVHN multi-digit recognition demo (Gradio app).
# Source: MuGeminorum, commit 718237e ("use cpu"), 2.93 kB.
import os
import torch
import requests
import gradio as gr
from tqdm import tqdm
from PIL import Image
from model import Model
from torchvision import transforms
import warnings
# Suppress every warning, for all categories and all modules, for the whole process.
# NOTE(review): this also hides deprecation warnings from torch/torchvision/
# gradio; consider narrowing with category=/module= filters.
warnings.filterwarnings("ignore")
def download_model(url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth", local_path="model-122000.pth", timeout=60):
    """Download the model checkpoint to *local_path* unless it already exists.

    The data is streamed into a temporary ``<local_path>.part`` file. That file
    is moved into place only after the transfer succeeds. An interrupted or
    failed download therefore never leaves a truncated checkpoint behind that
    later calls would mistake for a complete one.

    Args:
        url: Address of the checkpoint to fetch.
        local_path: Destination file; if it already exists nothing is done.
        timeout: Seconds to wait on the server (connect/read) before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On other network failures.
    """
    if os.path.exists(local_path):
        return
    print(f"Downloading file from {url}...")
    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail loudly instead of saving an HTML error page as the checkpoint.
            response.raise_for_status()
            # content-length may be missing; None gives tqdm an open-ended counter.
            total_size = int(response.headers.get('content-length', 0)) or None
            with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar, \
                    open(tmp_path, 'wb') as file:
                for data in response.iter_content(chunk_size=64 * 1024):
                    file.write(data)
                    progress_bar.update(len(data))
        # Atomic rename: local_path only ever appears fully written.
        os.replace(tmp_path, local_path)
    finally:
        # After a successful replace the .part file is gone; otherwise
        # (error or Ctrl-C) remove the partial file so the next call retries.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download completed.")
def _infer(path_to_checkpoint_file, path_to_input_image):
    """Recognise the digit sequence in a single image.

    Args:
        path_to_checkpoint_file: Checkpoint handed to ``Model.restore``.
        path_to_input_image: Path of the image to classify.

    Returns:
        The predicted digits concatenated into a string. The string is
        truncated to the predicted sequence length and may be empty.
    """
    model = Model()
    model.restore(path_to_checkpoint_file)
    # Nothing is moved to a GPU: model and inputs stay on the default device.

    transform = transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.CenterCrop([54, 54]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # convert() returns a new, fully loaded image, so the file handle
    # opened by Image.open can be released right away.
    with Image.open(path_to_input_image) as img:
        image = img.convert('RGB')
    images = transform(image).unsqueeze(dim=0)  # add a batch dimension of 1

    with torch.no_grad():
        # The model returns the length logits followed by one logits tensor per digit head.
        length_logits, *digit_logits = model.eval()(images)

    length = length_logits.argmax(dim=1).item()
    digits = [logits.argmax(dim=1).item() for logits in digit_logits]
    # Slicing (not indexing) means a predicted length larger than the number
    # of digit heads cannot raise IndexError.
    # NOTE(review): digit classes are stringified as-is; confirm the model
    # never emits a non-digit class index within the predicted length.
    return ''.join(str(d) for d in digits[:length])
def inference(image_path, weight_path="model-122000.pth"):
    """Gradio callback: return the digit string recognised in *image_path*.

    Args:
        image_path: Path of the uploaded image. When empty or None, the sample
            image ``457.png`` (relative to the working directory) is used.
        weight_path: Checkpoint to load. If it does not exist yet, the default
            checkpoint is downloaded to this path first.

    Returns:
        The recognised digits as a string.
    """
    # Download to the path that is actually loaded, so a custom weight_path
    # is honoured instead of always fetching model-122000.pth.
    download_model(local_path=weight_path)
    if not image_path:
        image_path = '457.png'
    return _infer(weight_path, image_path)
if __name__ == '__main__':
    # Web UI: one image upload in, one text box with the recognised digits out.
    upload = gr.Image(type='filepath')  # inference() receives a file path
    result_box = gr.Textbox()
    gr.Interface(fn=inference, inputs=upload, outputs=result_box).launch()