narugo's picture
dev(narugo): more models added
history blame
7.49 kB
import math
from pathlib import Path
import colorcet as cc
import cv2
import numpy as np
import timm
import torch
from PIL import Image
from matplotlib.colors import LinearSegmentedColormap
from import create_transform, resolve_data_config
from timm.models import VisionTransformer
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision import transforms as T
from .common import Heatmap, ImageLabels, LabelData, pil_make_grid
# Base directory for this module's files: the directory containing this file,
# or the current working directory when `__file__` is not defined (e.g. an
# interactive session).
work_dir = (Path(__file__).parent if "__file__" in locals() else Path.cwd()).resolve()
# Scratch directory under work_dir; created (with parents) as an import-time
# side effect if it does not exist yet.
temp_dir = work_dir.joinpath("temp")
temp_dir.mkdir(exist_ok=True, parents=True)
# Process-wide caches keyed by repo id, filled on first use by load_model /
# load_model_and_transform so each model and transform is only built once.
model_cache: dict[str, VisionTransformer] = {}
transform_cache: dict[str, T.Compose] = {}
# Default device: CUDA when torch reports it as available, otherwise CPU.
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class RGBtoBGR(nn.Module):
    """Reorder the channel axis of a channels-first image tensor from RGB to BGR.

    The channel axis is the third-from-last dimension, so ``(C, H, W)``,
    ``(N, C, H, W)`` and inputs with further leading dimensions are all
    handled. Only channels 2, 1 and 0 are selected, in that order.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with channels 2, 1, 0 picked along dimension ``-3``."""
        # The ellipsis absorbs any number of leading (batch) dimensions, so
        # there is no need to branch on x.ndim.
        return x[..., [2, 1, 0], :, :]
def model_device(model: nn.Module) -> torch.device:
    """Return the device that ``model`` lives on.

    The device of the first parameter is used. A module without parameters
    falls back to its first buffer, and a module with neither is reported as
    CPU instead of raising ``StopIteration``.
    """
    first = next(model.parameters(), None)
    if first is None:
        # Parameter-free modules (e.g. nn.Identity) may still carry buffers.
        first = next(model.buffers(), None)
    if first is None:
        return torch.device("cpu")
    return first.device
def load_model(repo_id: str) -> VisionTransformer:
    """Return the model for ``repo_id``, building and caching it on first use.

    On a cache miss the model is created through ``timm.create_model`` from
    the ``hf-hub:`` source with pretrained weights, put into eval mode, moved
    to ``torch_device`` and stored in ``model_cache``.
    """
    cached = model_cache.get(repo_id)
    if cached is not None:
        return cached

    hub_name = "hf-hub:" + repo_id
    fresh = timm.create_model(hub_name, pretrained=True).eval().to(torch_device)
    model_cache[repo_id] = fresh
    return fresh
def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
    """Return the cached ``(model, transform)`` pair for ``repo_id``.

    The model is created through ``timm.create_model`` from the ``hf-hub:``
    source on a cache miss and put into eval mode; no ``.to(...)`` call is
    made here. The transform is built from the model's pretrained data config
    with an ``RGBtoBGR`` step appended, and is cached separately.
    """
    model = model_cache.get(repo_id)
    if model is None:
        model = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval()
        model_cache[repo_id] = model

    transform = transform_cache.get(repo_id)
    if transform is None:
        data_config = resolve_data_config(model.pretrained_cfg, model=model)
        base_pipeline = create_transform(**data_config)
        # Append the channel swap after timm's own preprocessing steps.
        transform = T.Compose(base_pipeline.transforms + [RGBtoBGR()])
        transform_cache[repo_id] = transform

    return model, transform
def _rank_above_threshold(pairs, threshold):
    """Keep ``(name, score)`` pairs scoring above ``threshold``, highest first."""
    kept = {name: score for name, score in pairs if score > threshold}
    return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))


def get_tags(
    probs: Tensor,
    labels: "LabelData",
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-label scores into caption strings and per-category score dicts.

    Args:
        probs: One score per entry of ``labels.names``. Tensors that require
            grad or live on another device are accepted; they are detached
            and copied to CPU before conversion to NumPy.
        labels: Provides ``names`` plus the index lists ``rating``,
            ``general`` and ``character``.
        gen_threshold: General labels are kept when their score is strictly
            greater than this value.
        char_threshold: Same cut-off for character labels.

    Returns:
        ``(caption, booru, rating_labels, char_labels, gen_labels)``.
        ``caption`` joins the kept general names followed by the kept
        character names with ``", "`` and backslash-escapes parentheses.
        ``booru`` is ``caption`` with underscores replaced by spaces. The
        three dicts map label name to score; the general and character dicts
        are ordered by descending score, and ratings are not thresholded.
    """
    # detach()/cpu() are no-ops for a plain CPU tensor, but without them
    # .numpy() raises for tensors that require grad or are on another device.
    scored = list(zip(labels.names, probs.detach().cpu().numpy()))

    # Rating entries are whichever indices labels.rating lists; all are kept.
    rating_labels = dict(scored[i] for i in labels.rating)

    gen_labels = _rank_above_threshold(
        (scored[i] for i in labels.general), gen_threshold
    )
    char_labels = _rank_above_threshold(
        (scored[i] for i in labels.character), char_threshold
    )

    # General names first, then character names. Each group is sorted by
    # score, but the combined list is not re-sorted.
    combined_names = [*gen_labels, *char_labels]

    # Escape parentheses so the caption can be used as a training caption.
    caption = ", ".join(combined_names).replace("(", "\\(").replace(")", "\\)")
    booru = caption.replace("_", " ")

    return caption, booru, rating_labels, char_labels, gen_labels
# Build one blended heatmap image per label plus a grid image of all of them.
#
# Returns (hmap_imgs, hmap_grid): the Heatmap entries sorted by score, highest
# first, and the image produced by pil_make_grid from their images.
#
# NOTE(review): in this copy the block has no indentation and the `font_args`
# default below is missing its closing `},`, so it does not parse as shown —
# presumably lines were lost before this file was captured; restore them from
# the upstream source before relying on this function.
def render_heatmap(
image: Tensor,
gradients: Tensor,
image_feats: Tensor,
image_probs: Tensor,
image_labels: list[str],
cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
# pos_embed_dim is only referenced by the commented-out line in the body.
pos_embed_dim: int = 784,
image_size: tuple[int, int] = (448, 448),
# Keyword arguments forwarded to every cv2.putText call below.
# NOTE(review): no "fontFace" entry is visible here although cv2.putText
# requires one — check whether that line was lost as well.
font_args: dict = {
"fontScale": 1,
"color": (255, 255, 255),
"thickness": 2,
"lineType": cv2.LINE_AA,
partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image]:
# hmap_dim = int(math.sqrt(pos_embed_dim))
# Average the gradients over dim 2 (kept as size 1), multiply elementwise with
# the features (given a new leading dim), then drop all singleton dims.
image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
# Side length of the square map: sqrt of the per-label element count that
# remains after averaging over the last dim (truncated to an int).
hmap_dim = int(math.sqrt(image_hmaps.mean(-1).numel() / len(image_labels)))
# Average over the last dim and flatten to one row per label.
image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), -1)
# Keep only the trailing hmap_dim**2 values of each row — presumably this
# drops leading non-patch tokens; TODO confirm against the model.
image_hmaps = image_hmaps[..., -hmap_dim ** 2:]
image_hmaps = image_hmaps.reshape(len(image_labels), hmap_dim, hmap_dim)
# Elementwise max with zeros, i.e. negative values are clamped to zero.
image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
# Divide every map by its own maximum.
# NOTE(review): a map that is all zeros after clamping divides by zero here.
image_hmaps /= image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
# normalize to 0-1
# (per-map min-max rescale; unsqueeze(1) adds the channel dim F.interpolate expects)
image_hmaps = torch.stack([(x - x.min()) / (x.max() - x.min()) for x in image_hmaps]).unsqueeze(1)
# interpolate to input image size
image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)
hmap_imgs: list[Heatmap] = []
for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
# Convert the input tensor to HxWxC uint8 pixels; the add(1).mul(127.5) mapping
# assumes values in [-1, 1] — TODO confirm. This does not depend on the loop
# variables but is recomputed on every iteration.
image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
# Colour-map the heatmap to bytes and drop the alpha channel.
hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
# Blend image and heatmap with equal weights.
hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
# NOTE(review): with indentation lost it is unclear whether only the tag text
# or both putText calls sit under this `if` — confirm against upstream.
if tag is not None:
cv2.putText(hmap_image, tag, (10, 30), **font_args)
cv2.putText(hmap_image, f"{score:.3f}", org=(10, 60), **font_args)
# Convert from BGR to RGB channel order before handing the array to PIL.
hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))
# Highest-scoring labels first.
hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)
return hmap_imgs, hmap_grid
# Run the model on one image, pick the labels whose probability exceeds
# `threshold`, compute gradients for them, and return the rendered heatmaps,
# the heatmap grid and the tag data wrapped in an ImageLabels.
#
# NOTE(review): in this copy the block has no indentation and the argument
# lines of four calls (model.forward_features, torch.autograd.grad,
# render_heatmap, get_tags) are missing, so it does not parse as shown —
# restore them from the upstream source.
# NOTE(review): the return annotation lists tuple[float, str, Image.Image]
# items, while render_heatmap is annotated as returning list[Heatmap] —
# confirm which one is intended.
def process_heatmap(
model: VisionTransformer,
image: Tensor,
labels: LabelData,
threshold: float = 0.5,
partial_rows: bool = True,
) -> tuple[list[tuple[float, str, Image.Image]], Image.Image, ImageLabels]:
# Local name shadows the module-level torch_device; here it is the device the
# model itself is on.
torch_device = model_device(model)
# Gradients are explicitly enabled for the forward pass and the grad call.
with torch.set_grad_enabled(True):
# NOTE(review): the argument(s) to forward_features are missing in this copy.
features = model.forward_features(
probs = model.forward_head(features)
# squeeze(0) drops the leading dim if it has size 1 — assumes a batch of one
# image; TODO confirm.
probs = F.sigmoid(probs).squeeze(0)
# Boolean mask and probabilities of the labels strictly above the threshold.
probs_mask = probs > threshold
heatmap_probs = probs[probs_mask]
# Indices of the selected labels, mapped to their names.
label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
image_labels = [labels.names[label_indices[i]] for i in range(len(label_indices))]
# Identity matrix with one row per selected label, built on the model's device.
# Presumably passed to the grad call below; its arguments are not visible here.
eye = torch.eye(heatmap_probs.shape[0], device=torch_device)
# NOTE(review): the arguments to torch.autograd.grad are missing in this copy.
grads = torch.autograd.grad(
# Take the first returned gradient, detach it, keep index 0 along dim 1 and
# re-add that dim with size 1.
grads = grads[0].detach().requires_grad_(False)[:, 0, :, :].unsqueeze(1)
# Rendering and tag extraction run with gradient tracking disabled.
with torch.set_grad_enabled(False):
# NOTE(review): the arguments to render_heatmap are missing in this copy.
hmap_imgs, hmap_grid = render_heatmap(
# NOTE(review): the arguments to get_tags are missing in this copy.
caption, booru, ratings, character, general = get_tags(
# Rebinds the `labels` parameter to the result object; note the argument order
# (general before character) differs from get_tags' return order.
labels = ImageLabels(caption, booru, ratings, general, character)
return hmap_imgs, hmap_grid, labels