# Hugging Face Spaces status: Sleeping
import torch | |
from transformers import ViTForImageClassification, ViTFeatureExtractor | |
import gradio as gr | |
from PIL import Image | |
# Class names in the order of the classifier's output indices, as used during
# training. NOTE(review): confirm this order matches model.config.id2label of
# the "DumbledoreWiz/PantsShape" checkpoint.
labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']

# Fine-tuned ViT classifier, put into evaluation mode immediately
# (nn.Module.eval() returns the module itself, so it can be chained).
model = ViTForImageClassification.from_pretrained("DumbledoreWiz/PantsShape").eval()

# Image preprocessing loaded from the base ViT checkpoint rather than from the
# fine-tuned repo — presumably the model was fine-tuned from it; TODO confirm
# the preprocessing settings match.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
def predict(image):
    """Classify the pants shape shown in *image*.

    Args:
        image: A PIL image (``gr.Image(type="pil")`` delivers one), or
            ``None`` when the user submits without an uploaded image.

    Returns:
        A dict mapping each name in ``labels`` to its softmax probability,
        or an empty dict when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None for a cleared input; the feature extractor
        # would raise on it, so report "no prediction" instead.
        return {}
    # Normalise to 3 channels: RGBA, grayscale or palette images do not fit
    # the RGB preprocessing when predict() is called outside the Gradio input.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single image in the batch -> take row 0 and normalise over classes.
    probs = torch.nn.functional.softmax(logits[0], dim=0).tolist()
    # Pair each label with the probability at its index (labels define the order).
    return {label: float(probs[i]) for i, label in enumerate(labels)}
# Gradio UI: one image upload in, class probabilities out.
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    # Show every class; derived from `labels` so the count cannot drift
    # out of sync with the label list.
    outputs=gr.Label(num_top_classes=len(labels)),
    title="Pants Shape Classifier",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    gradio_app.launch()