narugo commited on
Commit
da7e7c0
1 Parent(s): 3be3994

dev(narugo): more models added

Browse files
Files changed (1) hide show
  1. tagger/model.py +5 -8
tagger/model.py CHANGED
@@ -1,25 +1,20 @@
1
- import json
2
  import math
3
- from dataclasses import dataclass, field
4
- from os import PathLike, cpu_count
5
  from pathlib import Path
6
- from typing import Any, Optional, TypeAlias
7
 
8
  import colorcet as cc
9
  import cv2
10
  import numpy as np
11
- import pandas as pd
12
  import timm
13
  import torch
14
- from matplotlib.colors import LinearSegmentedColormap
15
  from PIL import Image
 
16
  from timm.data import create_transform, resolve_data_config
17
  from timm.models import VisionTransformer
18
  from torch import Tensor, nn
19
  from torch.nn import functional as F
20
  from torchvision import transforms as T
21
 
22
- from .common import Heatmap, ImageLabels, LabelData, load_labels_hf, pil_ensure_rgb, pil_make_grid
23
 
24
  # working dir, either file parent dir or cwd if interactive
25
  work_dir = (Path(__file__).parent if "__file__" in locals() else Path.cwd()).resolve()
@@ -129,7 +124,9 @@ def render_heatmap(
129
 
130
  image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
131
  hmap_dim = int(math.sqrt(image_hmaps.mean(-1).numel() / len(image_labels)))
132
- image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)
 
 
133
  image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
134
 
135
  image_hmaps /= image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
 
 
1
  import math
 
 
2
  from pathlib import Path
 
3
 
4
  import colorcet as cc
5
  import cv2
6
  import numpy as np
 
7
  import timm
8
  import torch
 
9
  from PIL import Image
10
+ from matplotlib.colors import LinearSegmentedColormap
11
  from timm.data import create_transform, resolve_data_config
12
  from timm.models import VisionTransformer
13
  from torch import Tensor, nn
14
  from torch.nn import functional as F
15
  from torchvision import transforms as T
16
 
17
+ from .common import Heatmap, ImageLabels, LabelData, pil_make_grid
18
 
19
  # working dir, either file parent dir or cwd if interactive
20
  work_dir = (Path(__file__).parent if "__file__" in locals() else Path.cwd()).resolve()
 
124
 
125
  image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
126
  hmap_dim = int(math.sqrt(image_hmaps.mean(-1).numel() / len(image_labels)))
127
+ image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), -1)
128
+ image_hmaps = image_hmaps[..., -hmap_dim ** 2:]
129
+ image_hmaps = image_hmaps.reshape(len(image_labels), hmap_dim, hmap_dim)
130
  image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
131
 
132
  image_hmaps /= image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)