Image Classification
timm
drhead committed on
Commit fda5480 · verified · 1 Parent(s): 3b61449

Delete inference_gradio.py

Files changed (1)
inference_gradio.py +0 -179
inference_gradio.py DELETED
@@ -1,179 +0,0 @@
import json

import gradio as gr
from PIL import Image
import safetensors.torch
import timm
from timm.models import VisionTransformer
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)

class Fit(torch.nn.Module):
    """Resize an image to fit within `bounds` while preserving aspect ratio.

    With `grow=False`, images already inside the bounds are left untouched.
    With `pad` set, the result is padded out to exactly `bounds` using that
    fill value; otherwise the resized image is returned as-is.
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation = InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None
    ):
        super().__init__()

        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        wimg, himg = img.size
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg

        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)

        scale = min(hscale, wscale)
        if scale == 1.0:
            return img

        hnew = min(round(himg * scale), hbound)
        wnew = min(round(wimg * scale), wbound)

        img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        # Split the leftover space evenly between the two sides of each axis.
        hpad = hbound - hnew
        wpad = wbound - wnew

        tpad = hpad // 2
        bpad = hpad - tpad

        lpad = wpad // 2
        rpad = wpad - lpad

        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(" +
            f"bounds={self.bounds}, " +
            f"interpolation={self.interpolation.value}, " +
            f"grow={self.grow}, " +
            f"pad={self.pad})"
        )

class CompositeAlpha(torch.nn.Module):
    """Composite an RGBA tensor over a solid background color, yielding RGB."""

    def __init__(
        self,
        background: tuple[float, float, float] | float,
    ):
        super().__init__()

        self.background = (background, background, background) if isinstance(background, float) else background
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img  # already RGB, nothing to composite

        alpha = img[..., 3, None, :, :]

        img[..., :3, :, :] *= alpha

        # (3, 1, 1) -> (3, H, W); the stored background is always 3-D, so no
        # further reshaping is needed before broadcasting.
        background = self.background.expand(-1, img.shape[-2], img.shape[-1])

        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(" +
            f"background={self.background})"
        )

transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])

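# Worked example of the pipeline geometry above (illustrative numbers): a
# 1024x768 (WxH) input gives hscale = 384/768 = 0.5 and wscale = 384/1024 =
# 0.375, so Fit resizes by 0.375 to 384x288. Because Fit is used here without
# `pad`, CenterCrop then pads the short side back out to 384 with zeros, which
# after Normalize(mean=0.5, std=0.5) corresponds to mid-gray in pixel space.
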
model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
)  # type: VisionTransformer

safetensors.torch.load_model(model, "JTP_PILOT/JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
model.eval()

if torch.cuda.is_available():
    model.cuda()
    if torch.cuda.get_device_capability()[0] >= 7:  # tensor cores
        model.to(dtype=torch.float16, memory_format=torch.channels_last)

with open("JTP_PILOT/tags.json", "r") as file:
    tags = json.load(file)  # type: dict
allowed_tags = list(tags.keys())

# Tags are stored underscore-separated; display them with spaces instead.
for idx, tag in enumerate(allowed_tags):
    allowed_tags[idx] = tag.replace("_", " ")

def create_tags(image, threshold):
    img = image.convert('RGB')
    tensor = transform(img).unsqueeze(0)  # type: torch.Tensor

    if torch.cuda.is_available():
        tensor = tensor.cuda()
        if torch.cuda.get_device_capability()[0] >= 7:
            tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)

    with torch.no_grad():
        logits = model(tensor)
        # Multi-label head: an independent sigmoid per tag, not a softmax.
        probabilities = torch.sigmoid(logits[0])
        indices = torch.where(probabilities > threshold)[0]

    tag_score = {allowed_tags[i]: probabilities[i].item() for i in indices}
    text_no_impl = ", ".join(tag_score.keys())
    return text_no_impl, tag_score

with gr.Blocks() as demo:
    gr.Markdown("""
    ## Joint Tagger Project: PILOT
    This tagger is designed for use on furry images (though it may well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also produce some hallucinated tags.

    This tagger is the result of joint efforts between members of the RedRocket team.

    Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
    """)
    gr.Interface(
        create_tags,
        inputs=[
            gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
            gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold"),
        ],
        outputs=[
            gr.Textbox(label="Tag String"),
            gr.Label(label="Tag Predictions", num_top_classes=200),
        ],
        allow_flagging="never",
    )

if __name__ == "__main__":
    demo.launch()
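
Note: with the Gradio demo removed by this commit, the tagging step itself can still be reproduced headlessly. A minimal sketch, assuming a copy of this file is kept from the parent revision (3b61449) as an importable module named inference_gradio, and that the JTP_PILOT checkpoint and tags.json are present at the paths above; the input filename example.jpg is illustrative:

from PIL import Image

from inference_gradio import create_tags  # the module this commit deletes

image = Image.open("example.jpg")  # hypothetical input image
tag_string, tag_score = create_tags(image, threshold=0.2)

print(tag_string)
# Print the ten highest-confidence tags as a quick sanity check.
for tag, score in sorted(tag_score.items(), key=lambda kv: kv[1], reverse=True)[:10]:
    print(f"{tag}: {score:.3f}")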