|
#!/usr/bin/env python |
|
|
|
import os

import pillow_jxl  # noqa: F401  imported only for its side effect (presumably registers JPEG XL support with Pillow)
import torch
from PIL import Image
from transformers import AutoProcessor, ViTForImageClassification
|
|
|
# Load the processor and model once, at import time, so every call to
# classify_image() reuses the same objects instead of reloading them.
# The checkpoint directory defaults to the original local path; set the
# WAIFU_SCORER_MODEL_DIR environment variable to load it from elsewhere.
model_dir = os.environ.get(
    "WAIFU_SCORER_MODEL_DIR",
    "/home/kade/models/classifiers/waifu-scorer-v4-beta",
)

processor = AutoProcessor.from_pretrained(model_dir)

model = ViTForImageClassification.from_pretrained(model_dir)
|
|
|
def classify_image(image_path):
    """Classify a single image and return its predicted label.

    Uses the module-level ``processor`` and ``model`` loaded at import time.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        The label that ``model.config.id2label`` maps the highest-scoring
        logit index to.
    """
    # Open inside a context manager so the file handle is released as soon
    # as the pixels are decoded; convert() returns an independent copy.
    with Image.open(image_path) as img:
        raw_image = img.convert("RGB")  # Ensure the image is in RGB format

    # Pass the image by keyword: AutoProcessor may resolve to a multimodal
    # processor whose first positional argument is text, not images.
    inputs = processor(images=raw_image, return_tensors="pt")

    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line interface: one positional argument naming the image.
    cli = ArgumentParser(
        description="Classify an image using the waifu-scorer-v4-beta model."
    )
    cli.add_argument("image_path", type=str, help="Path to the image file")

    # Run the classifier on the requested image and report the label.
    label = classify_image(cli.parse_args().image_path)
    print(f"Predicted Class: {label}")
|
|
|
|