File size: 2,100 Bytes
875a318 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
import torch
from PIL import Image
import os
import time
from transformers import ResNetForImageClassification, AutoImageProcessor
# Hugging Face Hub checkpoint that provides both the image preprocessing
# configuration and the ResNet image-classification weights.
MODEL_ID = "glazzova/body_type"

processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = ResNetForImageClassification.from_pretrained(MODEL_ID)
# Example images offered below the classifier, read from the "template" folder.
EXAMPLE_DIR = "template"
EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def list_example_images(folder=EXAMPLE_DIR, extensions=EXAMPLE_EXTENSIONS):
    """Return the paths of image files found directly inside *folder*.

    Args:
        folder: Directory to scan (not recursive).
        extensions: Tuple of lowercase filename suffixes to accept; the match
            is case-insensitive on the file name.

    Returns:
        A list of ``os.path.join(folder, name)`` paths, sorted by file name.
        An empty list is returned when *folder* does not exist, so a missing
        examples directory no longer crashes the app at start-up.
    """
    try:
        names = os.listdir(folder)
    except (FileNotFoundError, NotADirectoryError):
        return []
    # os.listdir returns entries in arbitrary order; sort so the example
    # gallery is stable across platforms and runs.
    return [
        os.path.join(folder, name)
        for name in sorted(names)
        if name.lower().endswith(extensions)
    ]


example_images = list_example_images()
# Define the classification function
def body_classification(image):
    """Predict the body type shown in *image*.

    Args:
        image: The picture to classify. The UI passes a ``PIL.Image.Image``
            (the input component uses ``type="pil"``), or ``None`` when the
            user clicks "Classify" without uploading anything. Non-PIL inputs
            accepted by ``processor`` are passed through unchanged.

    Returns:
        A ``(label, elapsed)`` tuple: the class name looked up in
        ``model.config.id2label`` for the highest-scoring logit, and the
        inference time formatted as ``"<seconds> seconds"`` (two decimals).

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of an internal stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image or select an example first.")

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted mid-measurement.
    start_time = time.perf_counter()

    # Uploaded PNGs are often RGBA or palette images and some photos are
    # grayscale; normalise to 3-channel RGB before preprocessing.
    # NOTE(review): assumes the checkpoint was trained on RGB input.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(image, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = logits.argmax(-1).item()
    label = model.config.id2label[predicted_label]

    elapsed_time = time.perf_counter() - start_time
    return label, f"{elapsed_time:.2f} seconds"
# Create the Gradio interface: an upload panel, two result boxes, a
# "Classify" button wired to body_classification, and an example gallery.
with gr.Blocks() as demo:
    gr.Markdown("# Body Type Classifier")
    # Blank lines separate Markdown paragraphs; single newlines are rendered
    # as spaces, which would merge the description and credits into one line.
    gr.Markdown(
        """
Upload an image or use the example images to predict the body type.

The app uses a pre-trained ResNet model fine-tuned for body type classification.

**by Ishwor Subedi**

GitHub: [@ishworrsubedii](https://github.com/ishworrsubedii)
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
        with gr.Column():
            label_output = gr.Textbox(label="Predicted Body Type")
            time_output = gr.Textbox(label="Processing Time (s)")
    classify_button = gr.Button("Classify")
    classify_button.click(
        body_classification,
        inputs=image_input,
        outputs=[label_output, time_output],
    )
    # Only show the examples section when there is something to show;
    # otherwise the page would carry an empty "Example Images" heading.
    if example_images:
        gr.Markdown("### Example Images")
        # Add example images as inputs
        gr.Examples(examples=example_images, inputs=image_input, label="Template Images")
def _main():
    """Serve the Gradio app, blocking the calling thread (debug mode)."""
    demo.launch(debug=True)


# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    _main()
|