File size: 1,146 Bytes
30ed572
 
 
 
 
 
326ae50
8b19113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e3f070
8b19113
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
# NOTE(review): installing packages when the script starts is fragile; declaring
# them in a requirements file is the usual alternative. Kept for compatibility.
import subprocess
import sys

# Argument lists for each `pip install` run, in the original order.
# NOTE(review): opencv-python is not imported anywhere in the visible code.
_PIP_INSTALLS = (
    ["--upgrade", "httpx"],
    ["--upgrade", "gradio"],
    ["opencv-python"],
    ["torch"],
    ["--upgrade", "pillow"],
    ["torchvision"],
)

for _pip_args in _PIP_INSTALLS:
    # `python -m pip` installs into the interpreter running this script (a bare
    # `pip` from PATH may belong to another one); a list avoids the shell.
    _returncode = subprocess.call([sys.executable, "-m", "pip", "install", *_pip_args])
    if _returncode != 0:
        # Stay best-effort like the original os.system calls, but make the failure
        # visible; a genuinely missing package will still fail at import time.
        print(
            f"warning: pip install {' '.join(_pip_args)} exited with {_returncode}",
            file=sys.stderr,
        )
import functools

import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

# Preprocessing applied to every input, built once at import instead of per call.
# Resize(256) + CenterCrop(224) + per-channel mean/std normalization is the
# recipe paired with the ImageNet-1k ResNet-50 weights loaded below.
_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@functools.lru_cache(maxsize=None)
def _load_model(device_type):
    """Return the pretrained ResNet-50 in eval mode on *device_type* ("cpu"/"cuda").

    Cached per device so the weights are loaded once, not on every request.
    """
    weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13: IMAGENET1K_V1 is what the deprecated
        # `pretrained=True` flag selected.
        model = torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision only understands the legacy flag.
        model = torchvision.models.resnet50(pretrained=True)
    model = model.to(torch.device(device_type))
    model.eval()
    return model


def _to_rgb_image(image):
    """Convert a file path, PIL image, or array-like input into an RGB PIL image.

    Raises:
        gr.Error: if no image was provided (Gradio passes None for an empty input).
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    if isinstance(image, (str, os.PathLike)):
        with Image.open(image) as img:
            # convert() returns a new, fully loaded image, so closing the file is safe.
            return img.convert("RGB")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    # Assumed to be an ndarray (as Gradio's type="numpy" delivers). Let Pillow
    # infer the mode from the shape (L / RGB / RGBA) and then normalise to RGB,
    # instead of forcing 'RGB' onto grayscale or 4-channel data.
    return Image.fromarray(image.astype("uint8")).convert("RGB")


def predict(image):
    """Classify *image* with a pretrained ResNet-50.

    Args:
        image: an image array (H, W[, C]), a PIL image, or a path to an image file.

    Returns:
        int: index of the highest-scoring ImageNet-1k class.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(device.type)
    batch = _TRANSFORM(_to_rgb_image(image)).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        output = model(batch)
    _, predicted = output.max(1)
    return predicted.item()

# Gradio exposes components at the top level (`gr.Image`, `gr.Textbox`):
# `gr.component` does not exist and `gr.outputs` was removed in Gradio 4, which
# the install step above upgrades to. type="numpy" hands predict() an ndarray,
# the input it feeds to Image.fromarray.
input_image = gr.Image(type="numpy", label="Input")
output_text = gr.Textbox(label="Predicted ImageNet class index")

demo = gr.Interface(fn=predict, inputs=input_image, outputs=output_text)
demo.launch()