import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
import gradio as gr
from PIL import Image
import os
import logging
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Define the class labels in the correct order as used during training
labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
logging.info(f"Labels: {labels}")
# Define the path to the uploaded model file
model_path = "best_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_93.90243902439025_2024-08-26.pth"
logging.info(f"Looking for model file: {model_path}")
if os.path.exists(model_path):
    logging.info(f"Model file found: {model_path}")
else:
    logging.error(f"Model file not found: {model_path}")
    raise FileNotFoundError(f"Model file not found: {model_path}")
# Create label mappings consistent with training
id2label = {str(i): label for i, label in enumerate(labels)}
label2id = {label: str(i) for i, label in enumerate(labels)}
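# For these labels this yields id2label = {"0": "Leggings", "1": "Jogger", ..., "5": "Chinos"}
# and the inverse label2id; the index order must match the order used during training.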
# Create a configuration for the model
config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
config.num_labels = len(labels)
config.id2label = id2label
config.label2id = label2id
# Initialize the model with the configuration
model = ViTForImageClassification(config)
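# Note: the classification head created here is randomly initialized; its weights are
# replaced below when the fine-tuned state dict is loaded from model_path.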
try:
    # Load the state dict of the fine-tuned model
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    logging.info("Fine-tuned model loaded successfully")
except Exception as e:
    logging.error(f"Error loading model: {str(e)}")
    raise
model.eval()
logging.info("Model set to evaluation mode")
# Load feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
logging.info("Feature extractor loaded")
# Define the prediction function
def predict(image):
    logging.info("Starting prediction")
    logging.info(f"Input image size: {image.size}")
    # Preprocess the image
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")
    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")
    # Prepare the output dictionary in the {label: probability} form expected by gr.Label
    result = {labels[i]: float(probabilities[i]) for i in range(len(labels))}
    logging.info(f"Prediction result: {result}")
    return result
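# Quick local sanity check (hypothetical, not part of the Space): the function can be
# exercised without Gradio, e.g.
#   sample = Image.open("example.jpg").convert("RGB")  # "example.jpg" is an assumed filename
#   print(predict(sample))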
# Set up the Gradio Interface
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=6),
    title="Pants Shape Classifier"
)
# Launch the app
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()