import os
import json

import torch
import torchvision.transforms as T
from PIL import Image
from timm import create_model
import gradio as gr
model_name = "convnext_xlarge_in22k"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create a pretrained ConvNeXt model:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py
model = create_model(model_name, pretrained=True).to(device)
model.eval()  # disable training-time behavior (e.g. stochastic depth) for inference
# Define transforms for test
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

NORMALIZE_MEAN = IMAGENET_DEFAULT_MEAN
NORMALIZE_STD = IMAGENET_DEFAULT_STD
SIZE = 256

# Resize the smaller edge to 256; no center cropping
transforms = T.Compose([
    T.Resize(SIZE, interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])
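# Note: with a single int, T.Resize matches the image's *smaller* edge to 256 and
# preserves the aspect ratio, so input tensors have variable spatial size; the
# ConvNeXt classifier handles this via its global average pooling head.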
# Download the class-index-to-name mapping used by this in22k model (skip if cached)
if not os.path.exists('label_to_words.json'):
    os.system("wget https://dl.fbaipublicfiles.com/convnext/label_to_words.json")
with open('label_to_words.json') as f:
    imagenet_labels = json.load(f)
def inference(img):
    img_tensor = transforms(img).unsqueeze(0).to(device)
    # Forward pass without gradient tracking
    with torch.no_grad():
        output = torch.softmax(model(img_tensor), dim=1)
    top5 = torch.topk(output, k=5)
    top5_prob = top5.values[0]
    top5_indices = top5.indices[0]

    # Map the top-5 class indices to human-readable labels
    result = {}
    for i in range(5):
        label = imagenet_labels[str(int(top5_indices[i]))]
        prob = float(top5_prob[i])
        result[label] = prob
    return result
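# A quick local sanity check of inference() outside Gradio; a minimal sketch,
# assuming the example image shipped with this demo is in the working directory:
#
#   img = Image.open('Tortoise-on-ground-surrounded-by-plants.jpeg').convert('RGB')
#   print(inference(img))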
# Gradio interface components (legacy gr.inputs / gr.outputs API)
inputs = gr.inputs.Image(type='pil')
outputs = gr.outputs.Label(type="confidences", num_top_classes=5)
title = "ConvNeXt"
description = "Gradio demo for ConvNeXt for image classification. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.03545' target='_blank'>A ConvNet for the 2020s</a> | <a href='https://github.com/facebookresearch/ConvNeXt' target='_blank'>Github Repo</a> | <a href='https://github.com/leondgarse/keras_cv_attention_models' target='_blank'>pretrained ConvNeXt model from keras_cv_attention_models</a> | <a href='https://github.com/stanislavfort/adversaries_to_convnext' target='_blank'>examples usage from adversaries_to_convnext</a></p>"
examples = ['Tortoise-on-ground-surrounded-by-plants.jpeg']
gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
    examples=examples,
).launch(enable_queue=True)