import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoConfig
import gradio as gr
from PIL import Image
import os
import logging
from safetensors.torch import load_file  # Import safetensors loading function

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the directory containing the model files
model_dir = "."  # Use current directory

# Define paths to the specific model files
model_path = os.path.join(model_dir, "model.safetensors")
config_path = os.path.join(model_dir, "config.json")
preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")

# Check if all required files exist
for path in [model_path, config_path, preprocessor_path]:
    if not os.path.exists(path):
        logging.error(f"File not found: {path}")
        raise FileNotFoundError(f"Required file not found: {path}")
    else:
        logging.info(f"Found file: {path}")

# Load the configuration
config = AutoConfig.from_pretrained(config_path)

# Extract the class labels from the model's config
labels = list(config.id2label.values())
logging.info(f"Labels: {labels}")

# Load the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)

# Load the model using the safetensors file
state_dict = load_file(model_path)  # Use safetensors to load the model weights
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict
)

# Ensure the model is in evaluation mode
model.eval()
logging.info("Model set to evaluation mode")

# Define the prediction function
def predict(image):
    logging.info("Starting prediction")
    logging.info(f"Input image shape: {image.size}")
    
    # Preprocess the image
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")
    
    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")
    
    # Prepare the output dictionary
    result = {labels[i]: float(probabilities[i]) for i in range(len(labels))}
    logging.info(f"Prediction result: {result}")
    
    return result

# Set up the Gradio Interface
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict, 
    inputs=gr.Image(type="pil"), 
    outputs=gr.Label(num_top_classes=6),
    title="Sleeves Length Classifier"
)

# Launch the app
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()