# XXXX / app.py
# Nvd's picture
# Update app.py
# 496591d
# raw
# history blame
# 1.07 kB
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
from model import SimpleCNN
def preprocess_image(image):
    """Convert a raw image array into a normalized single-image batch.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray``, or ``None``
            when no image has been supplied (a live Gradio interface may
            invoke its callback with an empty image input).

    Returns:
        A float tensor of shape ``(1, 3, 32, 32)``. For ``None`` input an
        all-zero tensor of that shape is returned; ``predict_image`` treats
        an all-zero input as "no image" and answers with its 404 sentinel.
    """
    if image is None:
        # Nothing to decode: hand back the all-zero "no image" batch instead
        # of crashing inside Image.fromarray.
        return torch.zeros(1, 3, 32, 32)

    # NOTE(review): these mean/std values are the widely used ImageNet channel
    # statistics - confirm they match what the model was trained with.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Force three channels: a grayscale or RGBA array would otherwise reach
    # the 3-channel Normalize step with the wrong channel count and fail.
    pil_image = Image.fromarray(image).convert("RGB")

    # Add the leading batch dimension expected by the model.
    return transform(pil_image).unsqueeze(0)
def predict_image(model, image):
    """Return the predicted class index for a preprocessed image batch.

    Args:
        model: Network exposing ``eval()`` and callable as ``model(image)``;
            its output is reduced along dimension 1.
        image: Input tensor holding a single-image batch.

    Returns:
        The index of the largest output value along dimension 1 as a Python
        int, or the sentinel ``404`` when ``image`` contains no non-zero
        element (including an empty tensor).

    Note:
        The final ``.item()`` call only succeeds for a one-element result,
        so batches with more than one image are not supported.
    """
    # An all-zero tensor means "no image". Count non-zero elements instead of
    # comparing the sum with zero: normalized inputs hold signed values, so a
    # real image whose values happen to cancel must not be rejected.
    if torch.count_nonzero(image).item() == 0:
        return 404

    model.eval()
    with torch.no_grad():
        output = model(image)
        # Gradients are already disabled here, so no detaching via `.data`
        # is needed before taking the arg-max over the class dimension.
        _, predicted = torch.max(output, 1)
    return predicted.item()
def main():
    """Load the trained weights and serve the classifier in a live Gradio UI.

    Builds a ``SimpleCNN``, restores its parameters from
    ``cifar10_model.pth`` in the working directory, and launches a Gradio
    interface whose callback preprocesses the submitted image and returns
    the result of ``predict_image``.
    """
    model = SimpleCNN()

    # map_location="cpu": the freshly constructed model lives on the CPU, and
    # a checkpoint saved from a GPU would otherwise fail to load on a
    # CPU-only host.
    # NOTE: torch.load unpickles the file - only load checkpoints from a
    # trusted source.
    state_dict = torch.load('cifar10_model.pth', map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    def classify(img):
        """Gradio callback: preprocess the submitted image and classify it."""
        return predict_image(model, preprocess_image(img))

    iface = gr.Interface(
        fn=classify,
        inputs=gr.Image(),
        outputs="label",
        live=True,
    )
    iface.launch()
# Start the demo only when this file is executed as a script, so importing
# the module never launches the web interface as a side effect.
if __name__ == "__main__":
    main()