"""Gradio demo: top-5 image classification with the Bamboo ViT-B/16 model.

At import time the script loads the class-id -> name mapping and the
pretrained checkpoint from fixed relative paths, builds the evaluation
transform and launches a Gradio interface that serves ``recognize_image``.
"""
import argparse
import json

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.hub import download_cached_file
from torchvision import transforms

from timmvit import timmvit


def pil_loader(filepath):
    """Open the image at ``filepath`` and return it as an RGB PIL image.

    The file handle is released when the ``with`` block exits; ``convert``
    returns a new image object, so the result stays usable afterwards.
    """
    with Image.open(filepath) as img:
        img = img.convert('RGB')
    return img


def build_transforms(input_size):
    """Build the evaluation preprocessing pipeline.

    Args:
        input_size: Side length of the square crop fed to the model.

    Returns:
        A ``transforms.Compose`` that resizes to ``input_size * 8 // 7``
        (256 for a 224 crop), center-crops to ``input_size``, converts to a
        tensor and normalizes each channel.
    """
    # Use the ``transforms`` name that is actually imported (the original
    # referenced the un-imported ``torchvision`` package) and return the
    # composed pipeline rather than the ``transforms`` module itself.
    transform = transforms.Compose([
        transforms.Resize(input_size * 8 // 7),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform


# Load the human-readable labels for Bamboo (train id -> list of names).
with open('./Bamboo_ViT-B16_demo/trainid2name.json') as f:
    id2name = json.load(f)

# Build the model and switch it to inference mode.
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

# Build the data transform.
eval_transforms = build_transforms(224)

# borrow code from here:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py


def recognize_image(image, texts=None):
    """Classify ``image`` and return the top-5 labels with probabilities.

    Args:
        image: Input picture, either a PIL image or a NumPy array such as
            the one Gradio's ``"image"`` input passes by default.
        texts: Unused. Kept (now optional) for backward compatibility with
            callers that pass a second argument; the interface below
            declares a single image input, so Gradio calls this function
            with one argument only.

    Returns:
        dict mapping the first name of each predicted class to its softmax
        probability (a Python float), for up to five classes.
    """
    # The preprocessing pipeline is written for PIL images, so convert
    # array input first and make sure there are exactly three channels.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    img_t = eval_transforms(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(img_t.unsqueeze(0))

    prediction = output.softmax(-1).flatten()
    # Never request more entries than there are classes.
    k = min(5, prediction.numel())
    _, top5_idx = torch.topk(prediction, k)
    return {
        id2name[str(i)][0]: float(prediction[i])
        for i in top5_idx.tolist()
    }


image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=5)

gr.Interface(
    description="Bamboo for Zero-shot Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo)",
    fn=recognize_image,
    inputs=["image"],
    outputs=[
        label,
    ],
    # examples=[
    #     ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
    #     ["./apple_with_ipod.jpg", "an ipod; an apple with a write note 'ipod'; an apple"],
    #     ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
    #     ["./zebras.png", "three zebras on the grass; two zebras on the grass; one zebra on the grass; no zebra on the grass; four zebras on the grass"],
    # ],
).launch()