# toolkit / waifu-scorer-v4-beta
# Author: k4d3
# Commit message: "A lot of random stuff"
# Commit: 2fb8598
#!/usr/bin/env python
import os
import pillow_jxl
from PIL import Image
from transformers import AutoProcessor, ViTForImageClassification
# Load the processor and model once at import time so every call to
# classify_image() reuses them. The directory defaults to the original local
# path, but can be overridden with the WAIFU_SCORER_MODEL_DIR environment
# variable so the script is usable on machines with a different layout.
model_dir = os.environ.get(
    "WAIFU_SCORER_MODEL_DIR",
    "/home/kade/models/classifiers/waifu-scorer-v4-beta",
)
processor = AutoProcessor.from_pretrained(model_dir)
model = ViTForImageClassification.from_pretrained(model_dir)
def classify_image(image_path):
    """Classify a single image and return its predicted label.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.
    """
    import torch  # function-scope import; only needed for no_grad below

    # Open via a context manager so the file handle is closed promptly.
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as img:
        raw_image = img.convert("RGB")  # ensure 3-channel RGB input

    inputs = processor(raw_image, return_tensors="pt")

    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Pick the index of the highest logit along the last (class) dimension.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
if __name__ == "__main__":
    import argparse

    # Command-line interface: a single positional argument naming the image.
    cli = argparse.ArgumentParser(
        description="Classify an image using the waifu-scorer-v4-beta model."
    )
    cli.add_argument("image_path", type=str, help="Path to the image file")
    cli_args = cli.parse_args()

    # Run the classifier on the requested file and report the label.
    label = classify_image(cli_args.image_path)
    print(f"Predicted Class: {label}")