|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
import os |
|
import time |
|
from transformers import ResNetForImageClassification, AutoImageProcessor |
|
|
|
|
|
# Hugging Face Hub repository id shared by the image processor and the
# classification model, so both are always loaded from the same checkpoint.
# NOTE(review): no `revision=` is pinned, so a new push to this repo changes
# what the app loads — consider pinning a commit hash for reproducibility.
_MODEL_ID = "glazzova/body_type"

# Loaded once at import time; every request handled by the app reuses them.
processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = ResNetForImageClassification.from_pretrained(_MODEL_ID)
|
|
|
|
|
example_images = [ |
|
os.path.join("template", x) for x in os.listdir("template") if x.lower().endswith((".png", ".jpg", ".jpeg")) |
|
] |
|
|
|
|
|
def body_classification(image):
    """Predict the body type shown in ``image``.

    Args:
        image: The value of the ``gr.Image(type="pil")`` input. This is a PIL
            image, or ``None`` when "Classify" is clicked before an image has
            been uploaded. Other inputs the processor accepts (e.g. arrays)
            are passed through unchanged.

    Returns:
        A ``(label, elapsed)`` tuple:
        - ``label`` is the class name found in ``model.config.id2label`` for
          the highest-scoring logit.
        - ``elapsed`` is the preprocessing plus inference time, formatted
          like ``"0.12 seconds"``.

    Raises:
        gr.Error: if no image was provided. Gradio shows this message in the
            UI instead of a stack trace from the processor.
    """
    if image is None:
        raise gr.Error("Please upload an image before clicking Classify.")

    # perf_counter is monotonic and high-resolution. time.time() can jump if
    # the system clock is adjusted, which would corrupt the measurement.
    start_time = time.perf_counter()

    # Defensive: make PIL inputs 3-channel (e.g. RGBA PNGs or grayscale images
    # passed in directly). This presumably matches the RGB data the model was
    # fine-tuned on — confirm against the processor config.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_label = logits.argmax(-1).item()
    label = model.config.id2label[predicted_label]

    elapsed_time = time.perf_counter() - start_time
    return label, f"{elapsed_time:.2f} seconds"
|
|
|
|
|
# Gradio UI: image upload on the left, prediction and timing on the right,
# a Classify button, and (when available) clickable template examples.
with gr.Blocks() as demo:
    gr.Markdown("# Body Type Classifier")
    gr.Markdown(
        """
        Upload an image or use the example images to predict the body type.
        The app uses a pre-trained ResNet model fine-tuned for body type classification.

        **by Ishwor Subedi**
        GitHub: [@ishworrsubedii](https://github.com/ishworrsubedii)
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
        with gr.Column():
            label_output = gr.Textbox(label="Predicted Body Type")
            time_output = gr.Textbox(label="Processing Time (s)")

    classify_button = gr.Button("Classify")
    # body_classification returns (label, elapsed), mapped onto the two textboxes in order.
    classify_button.click(body_classification, inputs=image_input, outputs=[label_output, time_output])

    # Only render the examples section when template images were found, so
    # an empty template folder does not leave a bare heading and empty gallery.
    if example_images:
        gr.Markdown("### Example Images")
        gr.Examples(examples=example_images, inputs=image_input, label="Template Images")
|
|
|
|
|
def main():
    """Start the Gradio server for ``demo`` with ``debug=True``."""
    demo.launch(debug=True)


if __name__ == "__main__":
    main()
|
|