# Author: ishworrsubedii
# Commit 875a318: "add: body classification"
import gradio as gr
import torch
from PIL import Image
import os
import time
from transformers import ResNetForImageClassification, AutoImageProcessor
# Load the image processor and the ResNet classifier from one Hugging Face
# Hub checkpoint; both must come from the same id so preprocessing matches
# what the model was trained with.
MODEL_ID = "glazzova/body_type"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = ResNetForImageClassification.from_pretrained(MODEL_ID)
# Load example images from the "template" folder
def list_example_images(folder="template", extensions=(".png", ".jpg", ".jpeg")):
    """Return sorted paths of the image files directly inside *folder*.

    Args:
        folder: Directory to scan (relative paths resolve against the CWD).
        extensions: Lower-case filename suffixes that count as images; the
            match is case-insensitive.

    Returns:
        A list of ``os.path.join(folder, name)`` paths, sorted by file name.
        A missing folder yields an empty list instead of raising, so the app
        still starts (just without examples).
    """
    if not os.path.isdir(folder):
        return []
    return [
        os.path.join(folder, name)
        # os.listdir order is arbitrary; sort so the examples appear in a stable order.
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(extensions)
        # Skip sub-directories that happen to be named like images.
        and os.path.isfile(os.path.join(folder, name))
    ]


example_images = list_example_images()
# Define the classification function
def body_classification(image):
    """Predict the body type shown in *image*.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicks "Classify" without uploading anything.

    Returns:
        A ``(label, elapsed)`` pair: the class name looked up in
        ``model.config.id2label`` and the processing time formatted as
        ``"<seconds> seconds"``. When no image is given, a prompt message and
        an empty time string are returned instead of raising.
    """
    if image is None:
        # Gradio hands us None for an empty input; the processor would raise on it.
        return "Please upload an image first.", ""

    start_time = time.perf_counter()  # monotonic clock, meant for measuring durations
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # RGBA (transparent PNG) or grayscale uploads: the processor presumably
        # expects 3-channel RGB input, so normalize the mode first.
        image = image.convert("RGB")
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits
    predicted_label = logits.argmax(-1).item()
    label = model.config.id2label[predicted_label]
    elapsed_time = time.perf_counter() - start_time
    return label, f"{elapsed_time:.2f} seconds"
# Create the Gradio interface: image input on the left, prediction and
# timing on the right, a Classify button below, then optional examples.
with gr.Blocks() as demo:
    gr.Markdown("# Body Type Classifier")
    gr.Markdown(
        """
Upload an image or use the example images to predict the body type.
The app uses a pre-trained ResNet model fine-tuned for body type classification.
**by Ishwor Subedi**
GitHub: [@ishworrsubedii](https://github.com/ishworrsubedii)
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
        with gr.Column():
            label_output = gr.Textbox(label="Predicted Body Type")
            time_output = gr.Textbox(label="Processing Time (s)")
    classify_button = gr.Button("Classify")
    classify_button.click(
        body_classification,
        inputs=image_input,
        outputs=[label_output, time_output],
    )
    # Only render the examples section when the "template" folder supplied
    # images; otherwise there would be an empty header and an empty
    # gr.Examples component.
    if example_images:
        gr.Markdown("### Example Images")
        # Each path becomes one clickable example that fills image_input.
        gr.Examples(examples=example_images, inputs=image_input, label="Template Images")

# Run the app
if __name__ == "__main__":
    demo.launch(debug=True)