#!/usr/bin/env python
"""Classify a single image with a locally stored ViT image-classification model.

The processor and model are loaded once, at import time, from ``model_dir``.
Run as a script with the path of an image file as the only argument to print
the predicted class label.
"""

import os  # noqa: F401  (unused here; kept so existing behavior is untouched)

# Imported for its side effects only; never referenced by name.
# NOTE(review): presumably registers JPEG XL support with Pillow - confirm.
import pillow_jxl  # noqa: F401
import torch
from PIL import Image
from transformers import AutoProcessor, ViTForImageClassification

# Load the processor and model from the local directory
model_dir = "/home/kade/models/classifiers/waifu-scorer-v4-beta"
processor = AutoProcessor.from_pretrained(model_dir)
model = ViTForImageClassification.from_pretrained(model_dir)


def classify_image(image_path):
    """Return the model's predicted class label for the image at *image_path*.

    The image is opened with Pillow, converted to RGB, preprocessed with the
    module-level ``processor`` and passed through the module-level ``model``.
    The index of the largest logit is looked up in ``model.config.id2label``.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring class.

    Raises:
        Whatever ``PIL.Image.open`` raises when the file is missing or cannot
        be read as an image.
    """
    # Close the underlying file as soon as the pixels are loaded; convert()
    # returns an independent RGB copy, so it stays valid after the file closes.
    with Image.open(image_path) as image_file:
        raw_image = image_file.convert("RGB")

    inputs = processor(raw_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # .item() requires a single-element result, i.e. exactly one input image.
    predicted_class_idx = logits.argmax(-1).item()
    class_labels = model.config.id2label
    return class_labels[predicted_class_idx]


if __name__ == "__main__":
    import argparse

    # Set up argument parser
    parser = argparse.ArgumentParser(description="Classify an image using the waifu-scorer-v4-beta model.")
    parser.add_argument("image_path", type=str, help="Path to the image file")

    # Parse arguments
    args = parser.parse_args()

    # Classify and print the result
    predicted_class = classify_image(args.image_path)
    print(f"Predicted Class: {predicted_class}")