reverty
Browse filesSigned-off-by: Balazs Horvath <acsipont@gmail.com>
jtp2
CHANGED
@@ -7,15 +7,18 @@ This script implements a multi-label classifier for furry images using the
|
|
7 |
PILOT2 model. It processes images, generates tags, and saves the results. The
|
8 |
model is based on a Vision Transformer architecture and uses a custom GatedHead
|
9 |
for classification.
|
|
|
10 |
Key features:
|
11 |
- Image preprocessing and transformation
|
12 |
- Model inference using PILOT2
|
13 |
- Tag generation with customizable threshold
|
14 |
- Batch processing of image directories
|
15 |
- Saving results as text files alongside images
|
|
|
16 |
Usage:
|
17 |
python jtp2.py <directory> [--threshold <float>]
|
18 |
"""
|
|
|
19 |
import os
|
20 |
import json
|
21 |
import argparse
|
@@ -27,13 +30,11 @@ import torch
|
|
27 |
from torchvision.transforms import transforms
|
28 |
from torchvision.transforms import InterpolationMode
|
29 |
import torchvision.transforms.functional as TF
|
30 |
-
import pillow_jxl
|
31 |
-
from itertools import islice
|
32 |
-
import gettext
|
33 |
-
import locale
|
34 |
|
35 |
torch.set_grad_enabled(False)
|
36 |
|
|
|
37 |
class Fit(torch.nn.Module):
|
38 |
"""
|
39 |
A custom transform module for resizing and padding images.
|
@@ -193,14 +194,14 @@ safetensors.torch.load_model(
|
|
193 |
|
194 |
# Create argument parser first
|
195 |
parser = argparse.ArgumentParser(
|
196 |
-
description=
|
197 |
)
|
198 |
-
parser.add_argument("directory", type=str, help=
|
199 |
parser.add_argument(
|
200 |
-
"--threshold", type=float, default=0.2, help=
|
201 |
)
|
202 |
parser.add_argument(
|
203 |
-
"--cpu", action="store_true", help=
|
204 |
)
|
205 |
args = parser.parse_args()
|
206 |
|
@@ -218,140 +219,6 @@ for idx, tag in enumerate(allowed_tags):
|
|
218 |
sorted_tag_score = {}
|
219 |
|
220 |
|
221 |
-
<<<<<<< HEAD
|
222 |
-
def batch_iterator(iterable, batch_size):
|
223 |
-
"""
|
224 |
-
Creates batches from an iterable.
|
225 |
-
Args:
|
226 |
-
iterable: The source iterable to batch
|
227 |
-
batch_size (int): Size of each batch
|
228 |
-
"""
|
229 |
-
iterator = iter(iterable)
|
230 |
-
while batch := list(islice(iterator, batch_size)):
|
231 |
-
yield batch
|
232 |
-
|
233 |
-
def setup_model():
|
234 |
-
"""Initialize model and move to appropriate device"""
|
235 |
-
model = timm.create_model(
|
236 |
-
"vit_so400m_patch14_siglip_384",
|
237 |
-
pretrained=False,
|
238 |
-
num_classes=9083
|
239 |
-
)
|
240 |
-
model.head = GatedHead(min(model.head.weight.shape), 9083)
|
241 |
-
safetensors.torch.load_model(
|
242 |
-
model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
|
243 |
-
)
|
244 |
-
|
245 |
-
if torch.cuda.is_available() and not args.cpu:
|
246 |
-
model.cuda()
|
247 |
-
if torch.cuda.get_device_capability()[0] >= 7:
|
248 |
-
model.to(dtype=torch.float16, memory_format=torch.channels_last)
|
249 |
-
model.eval()
|
250 |
-
return model
|
251 |
-
|
252 |
-
def process_batch(args):
|
253 |
-
"""
|
254 |
-
Processes a batch of images with the model.
|
255 |
-
Args:
|
256 |
-
args (tuple): Tuple containing (image_paths, threshold)
|
257 |
-
"""
|
258 |
-
batch_paths, threshold = args
|
259 |
-
|
260 |
-
# Initialize model and CUDA settings for this process
|
261 |
-
if torch.cuda.is_available() and not args.cpu:
|
262 |
-
model.cuda()
|
263 |
-
if torch.cuda.get_device_capability()[0] >= 7:
|
264 |
-
model.to(dtype=torch.float16, memory_format=torch.channels_last)
|
265 |
-
model.eval()
|
266 |
-
|
267 |
-
for image_path in batch_paths:
|
268 |
-
try:
|
269 |
-
text_file_path = os.path.splitext(image_path)[0] + ".tags"
|
270 |
-
|
271 |
-
# Skip if a corresponding .txt file already exists
|
272 |
-
if os.path.exists(text_file_path):
|
273 |
-
print(_("Skipping {}: {} already exists").format(image_path, text_file_path))
|
274 |
-
continue
|
275 |
-
|
276 |
-
image = Image.open(image_path)
|
277 |
-
tags, _ = run_classifier(image, threshold)
|
278 |
-
|
279 |
-
# Save tags to a text file
|
280 |
-
with open(text_file_path, "w", encoding="utf-8") as text_file:
|
281 |
-
text_file.write(tags)
|
282 |
-
|
283 |
-
print(f"{image_path}: {tags}")
|
284 |
-
|
285 |
-
except Exception as e:
|
286 |
-
print(f"Error processing {image_path}: {e}")
|
287 |
-
|
288 |
-
def run_classifier(image, model, threshold):
|
289 |
-
||||||| ef62f54 (multiproc)
|
290 |
-
def batch_iterator(iterable, batch_size):
|
291 |
-
"""
|
292 |
-
Creates batches from an iterable.
|
293 |
-
Args:
|
294 |
-
iterable: The source iterable to batch
|
295 |
-
batch_size (int): Size of each batch
|
296 |
-
"""
|
297 |
-
iterator = iter(iterable)
|
298 |
-
while batch := list(islice(iterator, batch_size)):
|
299 |
-
yield batch
|
300 |
-
|
301 |
-
def setup_model():
|
302 |
-
"""Initialize model and move to appropriate device"""
|
303 |
-
model = timm.create_model(
|
304 |
-
"vit_so400m_patch14_siglip_384",
|
305 |
-
pretrained=False,
|
306 |
-
num_classes=9083
|
307 |
-
)
|
308 |
-
model.head = GatedHead(min(model.head.weight.shape), 9083)
|
309 |
-
safetensors.torch.load_model(
|
310 |
-
model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
|
311 |
-
)
|
312 |
-
|
313 |
-
if torch.cuda.is_available() and not args.cpu:
|
314 |
-
model.cuda()
|
315 |
-
if torch.cuda.get_device_capability()[0] >= 7:
|
316 |
-
model.to(dtype=torch.float16, memory_format=torch.channels_last)
|
317 |
-
model.eval()
|
318 |
-
return model
|
319 |
-
|
320 |
-
def process_batch(args):
|
321 |
-
"""
|
322 |
-
Processes a batch of images with the model.
|
323 |
-
Args:
|
324 |
-
args (tuple): Tuple containing (image_paths, threshold)
|
325 |
-
"""
|
326 |
-
batch_paths, threshold = args
|
327 |
-
|
328 |
-
# Initialize model and CUDA settings for this process
|
329 |
-
if torch.cuda.is_available() and not args.cpu:
|
330 |
-
model.cuda()
|
331 |
-
if torch.cuda.get_device_capability()[0] >= 7:
|
332 |
-
model.to(dtype=torch.float16, memory_format=torch.channels_last)
|
333 |
-
model.eval()
|
334 |
-
|
335 |
-
for image_path in batch_paths:
|
336 |
-
try:
|
337 |
-
text_file_path = os.path.splitext(image_path)[0] + ".tags"
|
338 |
-
|
339 |
-
# Skip if a corresponding .txt file already exists
|
340 |
-
if os.path.exists(text_file_path):
|
341 |
-
continue
|
342 |
-
|
343 |
-
image = Image.open(image_path)
|
344 |
-
tags, _ = run_classifier(image, threshold)
|
345 |
-
|
346 |
-
# Save tags to a text file
|
347 |
-
with open(text_file_path, "w", encoding="utf-8") as text_file:
|
348 |
-
text_file.write(tags)
|
349 |
-
|
350 |
-
print(f"{image_path}: {tags}")
|
351 |
-
|
352 |
-
except Exception as e:
|
353 |
-
print(f"Error processing {image_path}: {e}")
|
354 |
-
|
355 |
def run_classifier(image, threshold):
|
356 |
"""
|
357 |
Runs the classifier on a single image and returns tags based on the threshold.
|
@@ -432,3 +299,5 @@ if __name__ == "__main__":
|
|
432 |
results = process_directory(args.directory, args.threshold)
|
433 |
for image_path, tags in results.items():
|
434 |
print(f"{image_path}: {tags}")
|
|
|
|
|
|
7 |
PILOT2 model. It processes images, generates tags, and saves the results. The
|
8 |
model is based on a Vision Transformer architecture and uses a custom GatedHead
|
9 |
for classification.
|
10 |
+
|
11 |
Key features:
|
12 |
- Image preprocessing and transformation
|
13 |
- Model inference using PILOT2
|
14 |
- Tag generation with customizable threshold
|
15 |
- Batch processing of image directories
|
16 |
- Saving results as text files alongside images
|
17 |
+
|
18 |
Usage:
|
19 |
python jtp2.py <directory> [--threshold <float>]
|
20 |
"""
|
21 |
+
|
22 |
import os
|
23 |
import json
|
24 |
import argparse
|
|
|
30 |
from torchvision.transforms import transforms
|
31 |
from torchvision.transforms import InterpolationMode
|
32 |
import torchvision.transforms.functional as TF
|
33 |
+
import pillow_jxl
|
|
|
|
|
|
|
34 |
|
35 |
torch.set_grad_enabled(False)
|
36 |
|
37 |
+
|
38 |
class Fit(torch.nn.Module):
|
39 |
"""
|
40 |
A custom transform module for resizing and padding images.
|
|
|
194 |
|
195 |
# Create argument parser first
|
196 |
parser = argparse.ArgumentParser(
|
197 |
+
description="Run inference on a directory of images."
|
198 |
)
|
199 |
+
parser.add_argument("directory", type=str, help="Target directory containing images.")
|
200 |
parser.add_argument(
|
201 |
+
"--threshold", type=float, default=0.2, help="Threshold for tag filtering."
|
202 |
)
|
203 |
parser.add_argument(
|
204 |
+
"--cpu", action="store_true", help="Force CPU inference instead of CUDA"
|
205 |
)
|
206 |
args = parser.parse_args()
|
207 |
|
|
|
219 |
sorted_tag_score = {}
|
220 |
|
221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
def run_classifier(image, threshold):
|
223 |
"""
|
224 |
Runs the classifier on a single image and returns tags based on the threshold.
|
|
|
299 |
results = process_directory(args.directory, args.threshold)
|
300 |
for image_path, tags in results.items():
|
301 |
print(f"{image_path}: {tags}")
|
302 |
+
|
303 |
+
|