import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
import gradio as gr
from PIL import Image
import os
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the class labels in the correct order as used during training
labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
logging.info(f"Labels: {labels}")

# Define the path to the uploaded model file
model_path = "best_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_93.90243902439025_2024-08-26.pth"
logging.info(f"Looking for model file: {model_path}")

if os.path.exists(model_path):
    logging.info(f"Model file found: {model_path}")
else:
    logging.error(f"Model file not found: {model_path}")
    raise FileNotFoundError(f"Model file not found: {model_path}")

# Create label mappings consistent with training
id2label = {str(i): label for i, label in enumerate(labels)}
label2id = {label: str(i) for i, label in enumerate(labels)}

# Create a configuration for the model
config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
config.num_labels = len(labels)
config.id2label = id2label
config.label2id = label2id

# Initialize the model with the configuration
model = ViTForImageClassification(config)

try:
    # Load the state dict of the fine-tuned model
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    logging.info("Fine-tuned model loaded successfully")
except Exception as e:
    logging.error(f"Error loading model: {str(e)}")
    raise

model.eval()
logging.info("Model set to evaluation mode")

# Load feature extractor
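# Note: in newer transformers releases ViTFeatureExtractor is deprecated in favor of ViTImageProcessor (same usage).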
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
logging.info("Feature extractor loaded")

# Define the prediction function
def predict(image):
    logging.info("Starting prediction")
    logging.info(f"Input image size: {image.size}")
    
    # Preprocess the image
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")
    
    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")
    
    # Prepare the output dictionary
    result = {labels[i]: float(probabilities[i]) for i in range(len(labels))}
    logging.info(f"Prediction result: {result}")
    
    return result

# Set up the Gradio Interface
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict, 
    inputs=gr.Image(type="pil"), 
    outputs=gr.Label(num_top_classes=6),
    title="Pants Shape Classifier"
)

# Launch the app
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()
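
# Optional local smoke test (a minimal sketch, not part of the deployed app;
# "sample_pants.jpg" is a hypothetical filename -- replace it with any local image):
#
#   from PIL import Image
#   probs = predict(Image.open("sample_pants.jpg").convert("RGB"))
#   print(max(probs, key=probs.get), probs)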