# TopLength / app.py
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
import gradio as gr
from PIL import Image
import os
import logging
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Define the class labels in the correct order as used during training
labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
logging.info(f"Labels: {labels}")
# Define the path to the uploaded model file
model_path = "best_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_93.90243902439025_2024-08-26.pth"
logging.info(f"Looking for model file: {model_path}")
if os.path.exists(model_path):
    logging.info(f"Model file found: {model_path}")
else:
    logging.error(f"Model file not found: {model_path}")
    raise FileNotFoundError(f"Model file not found: {model_path}")
# Create label mappings consistent with training
id2label = {str(i): label for i, label in enumerate(labels)}
label2id = {label: str(i) for i, label in enumerate(labels)}
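# (The predict function below indexes the `labels` list directly, so the exact
# key type used in id2label/label2id does not affect inference results.)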
# Create a configuration for the model
config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
config.num_labels = len(labels)
config.id2label = id2label
config.label2id = label2id
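# num_labels must match the fine-tuned checkpoint (6 classes) so that the
# classification head created below has the right shape for load_state_dict.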
# Initialize the model with the configuration
model = ViTForImageClassification(config)
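# Note: since this checkpoint is a plain state_dict, recent PyTorch versions
# also accept torch.load(..., weights_only=True), which avoids unpickling
# arbitrary objects.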
try:
    # Load the state dict of the fine-tuned model
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    logging.info("Fine-tuned model loaded successfully")
except Exception as e:
    logging.error(f"Error loading model: {str(e)}")
    raise
model.eval()
logging.info("Model set to evaluation mode")
# Load feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
logging.info("Feature extractor loaded")
# Define the prediction function
def predict(image):
    logging.info("Starting prediction")
    logging.info(f"Input image size: {image.size}")
    # Preprocess the image
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")
    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")
    # Prepare the output dictionary
    result = {labels[i]: float(probabilities[i]) for i in range(len(labels))}
    logging.info(f"Prediction result: {result}")
    return result
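# Optional local sanity check (hypothetical example; "sample_pants.jpg" is a
# placeholder path, not part of this repo):
#   result = predict(Image.open("sample_pants.jpg"))
#   print(max(result, key=result.get), result)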
# Set up the Gradio Interface
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=6),
    title="Pants Shape Classifier"
)
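# gr.Label renders the dict returned by predict() as the predicted class with
# confidence bars for the top num_top_classes labels.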
# Launch the app
if __name__ == "__main__":
logging.info("Launching the app")
gradio_app.launch()