k4d3 committed on
Commit d521fbd
1 Parent(s): 2460268

Signed-off-by: Balazs Horvath <acsipont@gmail.com>

Files changed (1)
  1. jtp2 +11 -142
jtp2 CHANGED
@@ -7,15 +7,18 @@ This script implements a multi-label classifier for furry images using the
 PILOT2 model. It processes images, generates tags, and saves the results. The
 model is based on a Vision Transformer architecture and uses a custom GatedHead
 for classification.
+
 Key features:
 - Image preprocessing and transformation
 - Model inference using PILOT2
 - Tag generation with customizable threshold
 - Batch processing of image directories
 - Saving results as text files alongside images
+
 Usage:
 python jtp2.py <directory> [--threshold <float>]
 """
+
 import os
 import json
 import argparse
@@ -27,13 +30,11 @@ import torch
 from torchvision.transforms import transforms
 from torchvision.transforms import InterpolationMode
 import torchvision.transforms.functional as TF
-import pillow_jxl # type: ignore
-from itertools import islice
-import gettext
-import locale
+import pillow_jxl
 
 torch.set_grad_enabled(False)
 
+
 class Fit(torch.nn.Module):
     """
     A custom transform module for resizing and padding images.
@@ -193,14 +194,14 @@ safetensors.torch.load_model(
 
 # Create argument parser first
 parser = argparse.ArgumentParser(
-    description=_("Run inference on a directory of images.")
+    description="Run inference on a directory of images."
 )
-parser.add_argument("directory", type=str, help=_("Target directory containing images."))
+parser.add_argument("directory", type=str, help="Target directory containing images.")
 parser.add_argument(
-    "--threshold", type=float, default=0.2, help=_("Threshold for tag filtering.")
+    "--threshold", type=float, default=0.2, help="Threshold for tag filtering."
 )
 parser.add_argument(
-    "--cpu", action="store_true", help=_("Force CPU inference instead of CUDA")
+    "--cpu", action="store_true", help="Force CPU inference instead of CUDA"
 )
 args = parser.parse_args()
 
@@ -218,140 +219,6 @@ for idx, tag in enumerate(allowed_tags):
 sorted_tag_score = {}
 
 
-<<<<<<< HEAD
-def batch_iterator(iterable, batch_size):
-    """
-    Creates batches from an iterable.
-    Args:
-        iterable: The source iterable to batch
-        batch_size (int): Size of each batch
-    """
-    iterator = iter(iterable)
-    while batch := list(islice(iterator, batch_size)):
-        yield batch
-
-def setup_model():
-    """Initialize model and move to appropriate device"""
-    model = timm.create_model(
-        "vit_so400m_patch14_siglip_384",
-        pretrained=False,
-        num_classes=9083
-    )
-    model.head = GatedHead(min(model.head.weight.shape), 9083)
-    safetensors.torch.load_model(
-        model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
-    )
-
-    if torch.cuda.is_available() and not args.cpu:
-        model.cuda()
-        if torch.cuda.get_device_capability()[0] >= 7:
-            model.to(dtype=torch.float16, memory_format=torch.channels_last)
-    model.eval()
-    return model
-
-def process_batch(args):
-    """
-    Processes a batch of images with the model.
-    Args:
-        args (tuple): Tuple containing (image_paths, threshold)
-    """
-    batch_paths, threshold = args
-
-    # Initialize model and CUDA settings for this process
-    if torch.cuda.is_available() and not args.cpu:
-        model.cuda()
-        if torch.cuda.get_device_capability()[0] >= 7:
-            model.to(dtype=torch.float16, memory_format=torch.channels_last)
-    model.eval()
-
-    for image_path in batch_paths:
-        try:
-            text_file_path = os.path.splitext(image_path)[0] + ".tags"
-
-            # Skip if a corresponding .txt file already exists
-            if os.path.exists(text_file_path):
-                print(_("Skipping {}: {} already exists").format(image_path, text_file_path))
-                continue
-
-            image = Image.open(image_path)
-            tags, _ = run_classifier(image, threshold)
-
-            # Save tags to a text file
-            with open(text_file_path, "w", encoding="utf-8") as text_file:
-                text_file.write(tags)
-
-            print(f"{image_path}: {tags}")
-
-        except Exception as e:
-            print(f"Error processing {image_path}: {e}")
-
-def run_classifier(image, model, threshold):
-||||||| ef62f54 (multiproc)
-def batch_iterator(iterable, batch_size):
-    """
-    Creates batches from an iterable.
-    Args:
-        iterable: The source iterable to batch
-        batch_size (int): Size of each batch
-    """
-    iterator = iter(iterable)
-    while batch := list(islice(iterator, batch_size)):
-        yield batch
-
-def setup_model():
-    """Initialize model and move to appropriate device"""
-    model = timm.create_model(
-        "vit_so400m_patch14_siglip_384",
-        pretrained=False,
-        num_classes=9083
-    )
-    model.head = GatedHead(min(model.head.weight.shape), 9083)
-    safetensors.torch.load_model(
-        model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
-    )
-
-    if torch.cuda.is_available() and not args.cpu:
-        model.cuda()
-        if torch.cuda.get_device_capability()[0] >= 7:
-            model.to(dtype=torch.float16, memory_format=torch.channels_last)
-    model.eval()
-    return model
-
-def process_batch(args):
-    """
-    Processes a batch of images with the model.
-    Args:
-        args (tuple): Tuple containing (image_paths, threshold)
-    """
-    batch_paths, threshold = args
-
-    # Initialize model and CUDA settings for this process
-    if torch.cuda.is_available() and not args.cpu:
-        model.cuda()
-        if torch.cuda.get_device_capability()[0] >= 7:
-            model.to(dtype=torch.float16, memory_format=torch.channels_last)
-    model.eval()
-
-    for image_path in batch_paths:
-        try:
-            text_file_path = os.path.splitext(image_path)[0] + ".tags"
-
-            # Skip if a corresponding .txt file already exists
-            if os.path.exists(text_file_path):
-                continue
-
-            image = Image.open(image_path)
-            tags, _ = run_classifier(image, threshold)
-
-            # Save tags to a text file
-            with open(text_file_path, "w", encoding="utf-8") as text_file:
-                text_file.write(tags)
-
-            print(f"{image_path}: {tags}")
-
-        except Exception as e:
-            print(f"Error processing {image_path}: {e}")
-
 def run_classifier(image, threshold):
     """
     Runs the classifier on a single image and returns tags based on the threshold.
@@ -432,3 +299,5 @@ if __name__ == "__main__":
     results = process_directory(args.directory, args.threshold)
     for image_path, tags in results.items():
         print(f"{image_path}: {tags}")
+
+
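
For quick reference, the command-line interface kept by this revision can be exercised as follows. This is a minimal invocation sketch based on the arguments defined above (directory, --threshold, --cpu); the image directory path is a placeholder:

    # Tag all images in a directory with the default threshold of 0.2;
    # results are saved as text files alongside the images.
    python jtp2.py /path/to/images

    # Raise the tag-filtering threshold and force CPU inference instead of CUDA.
    python jtp2.py /path/to/images --threshold 0.3 --cpu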