toolkit / aestethic-shadow-v2
k4d3's picture
organize a little
2cc9f55
#!/usr/bin/env python
import pillow_jxl # type: ignore
import torch
from PIL import Image
from transformers import AutoProcessor, ViTForImageClassification
# Hugging Face Hub repository from which both the preprocessor and the ViT
# image-classification weights are fetched; used by classify_image() below.
_MODEL_ID = "shadowlilac/aesthetic-shadow-v2"

# Loaded once at module import time, so every classify_image() call in this
# process reuses the same preprocessor and model instances.
processor = AutoProcessor.from_pretrained(_MODEL_ID)
model = ViTForImageClassification.from_pretrained(_MODEL_ID)
def classify_image(image_path):
    """Classify a single image with the aesthetic-shadow-v2 model.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The label string that ``model.config.id2label`` maps the
        highest-scoring logit to.
    """
    # Open inside a context manager so the source file handle is released as
    # soon as the pixels are decoded; convert() returns a new in-memory image,
    # so raw_image stays usable after the file is closed.
    with Image.open(image_path) as img:
        # Normalize any mode (e.g. RGBA, palette, grayscale) to 3-channel RGB.
        raw_image = img.convert("RGB")

    inputs = processor(raw_image, return_tensors="pt")

    # Pure inference: disable autograd so no computation graph is built or
    # retained for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    # One image in -> batch of one; .item() extracts the scalar class index.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
if __name__ == "__main__":
    import argparse

    # Minimal CLI: one positional image path in, one predicted label printed out.
    cli = argparse.ArgumentParser(
        description="Classify an image using the shadowlilac/aesthetic-shadow-v2 model."
    )
    cli.add_argument("image_path", type=str, help="Path to the image file")

    label = classify_image(cli.parse_args().image_path)
    print(f"Predicted Class: {label}")