jw2yang's picture
Update app.py
753a0ae
raw
history blame
No virus
4.99 kB
import requests
import gradio as gr
import numpy as np
# OpenCV is needed by show_cam_on_image, whose default `colormap` argument
# (cv2.COLORMAP_JET) is evaluated when that function is defined, so the
# import must be unconditional.
# NOTE(review): this import used to be guarded by `if not INTERNAL:`, but
# INTERNAL is never defined in this file, so the guard raised NameError.
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timm.data.transforms import _pil_interp
from focalnet import FocalNet, build_transforms, build_transforms4display
# Download human-readable labels for ImageNet.
# NOTE(review): presumably one class name per line, ordered by class index
# (classify_image reads labels[0..999]) — confirm against the linked file.
response = requests.get("https://git.io/JJkYN", timeout=30)
# Surface HTTP failures at startup instead of silently using the text of an
# error page as the label list.
response.raise_for_status()
labels = response.text.split("\n")

# Build model: a single stage of 12 blocks, patch size 16, embedding dim 768,
# then load the released classification checkpoint onto the CPU first.
model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], use_layerscale=True, use_postln=True)
url = 'https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth'
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint["model"])
# Use the GPU when one is available; otherwise fall back to the CPU rather
# than failing at startup on a CPU-only host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Build data transforms: one for the model input, one for display; both are
# built with size 224 and center_crop=False.
eval_transforms = build_transforms(224, center_crop=False)
display_transforms = build_transforms4display(224, center_crop=False)

# show_cam_on_image below is adapted from:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.3) -> np.ndarray:
    """Overlay a CAM-style mask on an image as a colour heatmap.

    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, np.float32 with values
        no greater than 1.
    :param mask: The cam mask; it is scaled by 255 and cast to uint8 before
        colour mapping, so values are expected in [0, 1].
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if
        ``img`` is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: Weight of ``img`` in the blend; the heatmap gets
        ``1 - image_weight``. The default of 0.3 keeps the previous 0.7/0.3 mix.
    :returns: The blended image as np.uint8.
    :raises ValueError: If ``img`` has values above 1 or ``image_weight`` is
        outside [0, 1].
    """
    # Validate inputs before doing any colour-mapping work.
    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")
    if not 0 <= image_weight <= 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1], got {image_weight}")

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        # applyColorMap produces BGR; convert when the base image is RGB.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    return np.uint8(255 * cam)
# 0-based indices of the blocks in model.layers[0] whose modulators are
# visualised; the UI labels them 1-based as "layer 3/6/9/12".
_VIS_BLOCKS = (2, 5, 8, 11)


def _modulator_cam(block_idx, size, img_d):
    """Overlay the modulator of model.layers[0].blocks[block_idx] on img_d.

    Reads the ``modulation.modulator`` tensor that the block exposes after the
    most recent forward pass. The function takes its L2 norm over dim 1 and
    upsamples it bilinearly to ``size``. It then min-max normalises the result
    to [0, 1] and blends it onto ``img_d`` with show_cam_on_image.
    """
    modulator = model.layers[0].blocks[block_idx].modulation.modulator.norm(2, 1, keepdim=True)
    modulator = nn.Upsample(size=size, mode='bilinear')(modulator)
    modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
    lo, hi = modulator.min(), modulator.max()
    # A constant map would divide by zero and produce NaNs, which become
    # garbage after the uint8 cast in show_cam_on_image; render it as zeros.
    if hi > lo:
        modulator = (modulator - lo) / (hi - lo)
    else:
        modulator = np.zeros_like(modulator)
    return show_cam_on_image(img_d, modulator, use_rgb=True)


def classify_image(inp):
    """Classify ``inp`` and visualise FocalNet focal modulators.

    Returns four PIL overlays (for blocks 2, 5, 8 and 11 of model.layers[0],
    in that order), followed by a dict mapping ``labels[i]`` to the softmax
    probability of class ``i`` for ``i`` in ``range(1000)``.
    """
    img_t = eval_transforms(inp)
    # NOTE(review): assumes display_transforms yields a CHW float tensor with
    # values <= 1; show_cam_on_image raises otherwise.
    img_d = display_transforms(inp).permute(1, 2, 0).cpu().numpy()
    # Follow the model's device instead of hard-coding CUDA.
    device = next(model.parameters()).device
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        prediction = model(img_t.unsqueeze(0).to(device)).softmax(-1).flatten()
        cams = [_modulator_cam(idx, img_t.shape[1:], img_d) for idx in _VIS_BLOCKS]
    # A single device-to-host transfer instead of one per class.
    probs = prediction.tolist()
    scores = {labels[i]: probs[i] for i in range(1000)}
    return tuple(Image.fromarray(cam) for cam in cams) + (scores,)
# Gradio UI: a single input image. The outputs are the four modulator
# overlays returned by classify_image, in order, then the top-3 class scores.
image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=[
        *(
            gr.outputs.Image(type="pil", label=title)
            for title in (
                "Modulator at layer 3",
                "Modulator at layer 6",
                "Modulator at layer 9",
                "Modulator at layer 12",
            )
        ),
        label,
    ],
    # examples=[["images/aiko.jpg"], ["images/pencils.jpg"], ["images/donut.png"]],
).launch()