SleevesLength / app.py
DumbledoreWiz's picture
Update app.py
a07e008 verified
raw
history blame
2.85 kB
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoConfig
import gradio as gr
from PIL import Image
import os
import logging
from safetensors.torch import load_file # Import safetensors loading function
# Configure root logging once for the whole app (timestamp, level, message).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Directory containing the model files. "." is resolved against the current
# working directory, so the app must be started from the folder holding them.
model_dir = "."

# Paths to the individual artefacts the app needs.
model_path = os.path.join(model_dir, "model.safetensors")
config_path = os.path.join(model_dir, "config.json")
preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")

# Fail fast with a clear message if any artefact is missing, rather than
# letting a later loader raise a less specific error.
for path in [model_path, config_path, preprocessor_path]:
    if not os.path.exists(path):
        logging.error(f"File not found: {path}")
        raise FileNotFoundError(f"Required file not found: {path}")
    logging.info(f"Found file: {path}")

# Load the model configuration (architecture + label mapping).
config = AutoConfig.from_pretrained(config_path)

# Build the label list ordered by class id so that labels[i] always names the
# class behind logit i. Taking id2label.values() directly would depend on the
# key order in config.json, which is not guaranteed to be sorted.
labels = [config.id2label[class_id] for class_id in sorted(config.id2label)]
logging.info(f"Labels: {labels}")

# Load the image preprocessor from its JSON config.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor - consider migrating (needs an
# import change at the top of the file).
feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)

# Read the weights from the safetensors file and instantiate the model from
# config + state dict only (no hub or directory lookup, hence the None path).
state_dict = load_file(model_path)
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict,
)

# Switch to evaluation mode for inference.
model.eval()
logging.info("Model set to evaluation mode")
# Define the prediction function
def predict(image):
    """Classify a single image and return per-label probabilities.

    Args:
        image: A PIL image (the Gradio input is declared with ``type="pil"``),
            or ``None`` when the user submits without providing an image.

    Returns:
        dict: Maps each label name to its softmax probability as a Python
        float, suitable for a ``gr.Label`` output.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio shows the message to the
            user instead of a generic failure.
    """
    # Gradio passes None when the form is submitted empty; without this guard
    # the first attribute access below would raise AttributeError.
    if image is None:
        logging.warning("Prediction requested without an image")
        raise gr.Error("Please provide an image before submitting.")

    logging.info("Starting prediction")
    logging.info(f"Input image shape: {image.size}")

    # Normalise to 3-channel RGB so grayscale / RGBA / palette images do not
    # break preprocessing (presumably the preprocessor expects 3 channels -
    # confirm against preprocessor_config.json).
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image into a batched tensor.
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")

    # Inference only: no gradients needed.
    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)

    # Take the first (only) item of the batch and normalise over classes.
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")

    # Pair each label with its probability; tolist() yields plain floats.
    result = dict(zip(labels, probabilities.tolist()))
    logging.info(f"Prediction result: {result}")
    return result
# Build the Gradio UI: a PIL image goes in, a ranked label widget comes out.
logging.info("Setting up Gradio interface")
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=6)
gradio_app = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="Sleeves Length Classifier",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()