# Spaces status: Runtime error
import argparse | |
import requests | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from torchvision import transforms | |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
from timm.data import create_transform | |
from timmvit import timmvit | |
import json | |
from timm.models.hub import download_cached_file | |
from PIL import Image | |
def pil_loader(filepath):
    """Open the image at *filepath* and return an RGB copy of it.

    The underlying file is closed when the ``with`` block exits; ``convert``
    returns a separate image object, which is what gets returned.
    """
    with Image.open(filepath) as opened:
        rgb_copy = opened.convert('RGB')
    return rgb_copy
def build_transforms(input_size):
    """Build the evaluation-time preprocessing pipeline for the classifier.

    Steps:
    1. Resize the image to ``input_size * 8 // 7`` (256 for 224).
    2. Centre-crop it to ``input_size`` x ``input_size``.
    3. Convert it to a tensor.
    4. Normalise with the ImageNet mean/std.

    :param input_size: Side length in pixels of the square crop fed to the model.
    :returns: A ``transforms.Compose`` callable that maps an input image to a
        normalised tensor.
    """
    # Only the ``transforms`` submodule is imported at file level; the bare
    # name ``torchvision`` is not bound, so every step goes through ``transforms``.
    pipeline = transforms.Compose([
        transforms.Resize(input_size * 8 // 7),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        # timm's constants hold the same values as the former hard-coded
        # lists ([0.485, 0.456, 0.406] / [0.229, 0.224, 0.225]).
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    # Return the composed pipeline itself, not the ``transforms`` module.
    return pipeline
# Human-readable labels for Bamboo. recognize_image looks labels up as
# id2name[str(train_id)][0], so presumably this maps a stringified train id
# to a list of names. Verify against trainid2name.json.
# JSON text is UTF-8 (RFC 8259). Pin the encoding so non-ASCII label names do
# not depend on the platform's default locale encoding.
with open('./trainid2name.json', encoding='utf-8') as f:
    id2name = json.load(f)

# Build the model from the converted Bamboo ViT-B/16 checkpoint, then switch
# it to evaluation mode. This assumes timmvit returns a torch nn.Module.
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

# Build the evaluation data transform (224 x 224 centre crop).
eval_transforms = build_transforms(224)
# Borrowed from pytorch-grad-cam:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Blend a class-activation mask onto an image as a colour heatmap.

    OpenCV colours the mask in BGR order. Pass ``use_rgb=True`` when *img* is
    RGB so that the heatmap channels match.

    :param img: Base image (RGB or BGR). It must not contain values above 1,
        i.e. a float image in the range [0, 1].
    :param mask: CAM mask. It is multiplied by 255 and cast to uint8 before
        colouring, so it is presumably in the range [0, 1].
    :param use_rgb: Convert the heatmap from BGR to RGB before blending.
    :param colormap: OpenCV colormap id used to colour the mask.
    :returns: uint8 image made of 70% heatmap and 30% base image.
    :raises Exception: If any value in *img* is greater than 1.
    """
    mask_u8 = np.uint8(255 * mask)
    heat = cv2.applyColorMap(mask_u8, colormap)
    if use_rgb:
        heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)
    heat = heat.astype(np.float32) / 255

    # Only checked after the heatmap exists, matching the original order.
    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    heat_weight, image_weight = 0.7, 0.3
    blended = heat_weight * heat + image_weight * img
    return np.uint8(255 * blended)
def recognize_image(image, texts=None):
    """Classify *image* with the Bamboo ViT and return its top-5 labels.

    :param image: Picture to classify. Gradio's plain ``"image"`` input
        delivers an ``np.ndarray``; a PIL image is accepted too. ``None``
        (nothing uploaded) yields an empty result.
    :param texts: Unused. It is now optional because the Interface passes
        only the image; the former required argument made every call fail
        with TypeError. Existing two-argument callers still work.
    :returns: dict mapping class name to softmax probability for the (up to)
        five most probable classes, the format ``gr.outputs.Label`` displays.
    """
    if image is None:
        return {}
    # The torchvision pipeline (Resize/CenterCrop/ToTensor) works on PIL
    # images. Convert numpy arrays, and force 3 channels so that Normalize's
    # three-entry mean/std always matches (e.g. for RGBA or grayscale input).
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    img_t = eval_transforms(image)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(img_t.unsqueeze(0))
    prediction = output.softmax(-1).flatten()

    # Never ask topk for more entries than there are classes.
    k = min(5, prediction.numel())
    top_prob, top_idx = torch.topk(prediction, k)
    return {
        id2name[str(idx)][0]: float(prob)
        for idx, prob in zip(top_idx.tolist(), top_prob.tolist())
    }
# Gradio UI: one image in, top-5 label probabilities out.
# type="pil" gives recognize_image a PIL image, which the torchvision
# transform pipeline expects. The bare "image" shortcut delivers numpy arrays.
image = gr.inputs.Image(type="pil")
label = gr.outputs.Label(num_top_classes=5)
gr.Interface(
    description="Bamboo for Zero-shot Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo)",
    fn=recognize_image,
    inputs=[image],
    outputs=[
        label,
    ],
    # The earlier commented-out examples paired each image with text prompts
    # for a two-input zero-shot UI. This interface takes a single image, so
    # add examples here in the form examples=[["./elephants.png"], ...].
).launch()