# NOTE: the three lines below are residue from a web file viewer (size banner,
# blame-commit column, line-number gutter). They are kept as comments so the
# script still parses as Python.
# File size: 2,267 Bytes
# 98d4fbe dba7cbf 5066aaa 98d4fbe 5066aaa 98d4fbe 5066aaa fa0722e 5066aaa 98d4fbe 5066aaa e0f0523 98d4fbe fa0722e 98d4fbe fa0722e 98d4fbe 88ed06b 94a816d fa0722e 98d4fbe fa0722e dc28d34 e92a4ce 0e9f6ff dc28d34 e92a4ce fa0722e e92a4ce fa0722e 0ca4068 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import json
import os
import sys
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import PIL
from PIL import Image
from timm import create_model
import torch
import torchvision
import torchvision.transforms as T
model_name = "convnext_xlarge_in22k"
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Create a ConvNeXt model with pretrained weights (timm implementation:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py)
# and switch it to evaluation mode. nn.Module instances start in training mode
# and create_model() does not change that. Without .eval(), any train-only
# behaviour (dropout, stochastic depth, ...) would stay active while serving
# predictions. Module.eval() returns the module itself, so it chains.
model = create_model(model_name, pretrained=True).to(device).eval()
# Test-time preprocessing for the demo:
#   1. bicubic resize so the image's shorter edge becomes SIZE pixels
#      (aspect ratio kept; deliberately no center crop),
#   2. PIL image -> float tensor,
#   3. per-channel normalization with timm's default ImageNet mean/std.
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

NORMALIZE_MEAN = IMAGENET_DEFAULT_MEAN
NORMALIZE_STD = IMAGENET_DEFAULT_STD
SIZE = 256

transforms = T.Compose(
    [
        T.Resize(SIZE, interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)
# Class-index -> human-readable label table published alongside ConvNeXt.
# Keys are stringified class indices (inference() looks them up via str(idx)).
# Download it only when it is missing: the previous unconditional `wget` saved
# a numbered duplicate (label_to_words.json.1, .2, ...) on every restart and
# ignored download failures. urlretrieve raises on network errors instead.
LABELS_URL = "https://dl.fbaipublicfiles.com/convnext/label_to_words.json"
LABELS_PATH = 'label_to_words.json'
if not os.path.exists(LABELS_PATH):
    urllib.request.urlretrieve(LABELS_URL, LABELS_PATH)
# Close the file deterministically instead of leaking the handle.
with open(LABELS_PATH, encoding="utf-8") as labels_file:
    imagenet_labels = json.load(labels_file)
def inference(img):
    """Classify an image and return its five most probable labels.

    Args:
        img: the uploaded picture as a PIL image (the Gradio input is declared
            with ``type='pil'``). Images in any other mode (RGBA, grayscale,
            palette, ...) are converted to RGB first, because the
            normalization step expects exactly three channels.

    Returns:
        dict mapping label text -> softmax probability (Python float) for the
        top-5 classes, in descending order of probability. If two of the
        top-5 indices share the same label text, the later (lower)
        probability overwrites the earlier one, as before.
    """
    # Force 3 channels: ToTensor on an RGBA/L image yields 4/1 channels, which
    # the 3-element mean/std in the Normalize transform cannot handle.
    img_tensor = transforms(img.convert("RGB")).unsqueeze(0).to(device)
    # Pure inference: disable autograd so no graph or saved activations are
    # kept for this very large model on every request.
    with torch.no_grad():
        probs = torch.softmax(model(img_tensor), dim=1)
    top5 = torch.topk(probs, k=5)
    # .tolist() moves the 5 values to Python in one transfer instead of one
    # device sync per element.
    return {
        imagenet_labels[str(idx)]: prob
        for prob, idx in zip(top5.values[0].tolist(), top5.indices[0].tolist())
    }
# Gradio UI: one PIL image in, a top-5 label/confidence widget out.
# NOTE(review): gr.inputs.*, gr.outputs.* and launch(enable_queue=...) are the
# Gradio 2.x API (deprecated in 3.x, removed in 4.x). If the pinned Gradio is
# upgraded, switch to gr.Image(type='pil'), gr.Label(num_top_classes=5) and
# demo.queue().launch().
inputs = gr.inputs.Image(type='pil')
outputs = gr.outputs.Label(type="confidences", num_top_classes=5)
title = "ConvNeXt"
description = "Gradio demo for ConvNeXt for image classification. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.03545' target='_blank'>A ConvNet for the 2020s</a> | <a href='https://github.com/facebookresearch/ConvNeXt' target='_blank'>Github Repo</a></p>"
# Assumes test.jpeg ships next to this script — confirm it exists in the repo.
examples = ['test.jpeg']
demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
    examples=examples,
)
# The stray " |" that trailed this call (scrape residue) was a SyntaxError.
demo.launch(enable_queue=True)