import os
import torch
import requests
import gradio as gr
from tqdm import tqdm
from PIL import Image
from model import Model
from torchvision import transforms
import warnings
warnings.filterwarnings("ignore")
def download_model(url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth", local_path="model-122000.pth"):
    # Download only if the checkpoint does not already exist locally
    if not os.path.exists(local_path):
        print(f"Downloading file from {url}...")
        # Request the URL with streaming so the file is read in chunks
        response = requests.get(url, stream=True)
        # Get the total file size in bytes (0 if the server omits Content-Length)
        total_size = int(response.headers.get('content-length', 0))
        # Initialize the tqdm progress bar
        progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
        # Open a local file in write-binary mode and save each chunk
        with open(local_path, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                # Update the progress bar, then write the chunk to disk
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()
        print("Download completed.")

def _infer(path_to_checkpoint_file, path_to_input_image):
    # Restore the SVHN model from the downloaded checkpoint
    model = Model()
    model.restore(path_to_checkpoint_file)
    # model.cuda()
    outstr = ''
    with torch.no_grad():
        # Resize, center-crop to 54x54, and normalize to [-1, 1]
        transform = transforms.Compose([
            transforms.Resize([64, 64]),
            transforms.CenterCrop([54, 54]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        image = Image.open(path_to_input_image)
        image = image.convert('RGB')
        image = transform(image)
        images = image.unsqueeze(dim=0)  # .cuda()
        # The model has six heads: sequence length plus up to five digits
        length_logits, digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits = model.eval()(images)
        # Take the argmax of each head as its prediction
        length_prediction = length_logits.max(1)[1]
        digit1_prediction = digit1_logits.max(1)[1]
        digit2_prediction = digit2_logits.max(1)[1]
        digit3_prediction = digit3_logits.max(1)[1]
        digit4_prediction = digit4_logits.max(1)[1]
        digit5_prediction = digit5_logits.max(1)[1]
        output = [
            digit1_prediction.item(),
            digit2_prediction.item(),
            digit3_prediction.item(),
            digit4_prediction.item(),
            digit5_prediction.item()
        ]
        # Keep only as many digits as the predicted sequence length
        for i in range(length_prediction.item()):
            outstr += str(output[i])

    return outstr

def inference(image_path, weight_path="model-122000.pth"):
    # Make sure the checkpoint is available before running inference
    download_model()
    # Fall back to a bundled example image when no input is provided
    if not image_path:
        image_path = './images/03.png'

    return _infer(weight_path, image_path)

if __name__ == '__main__':
    example_images = [
        './images/03.png',
        './images/457.png',
        './images/2003.png'
    ]
    iface = gr.Interface(
        fn=inference,
        inputs=gr.Image(type='filepath'),
        outputs=gr.Textbox(),
        examples=example_images
    )
    iface.launch()