ishworrsubedii commited on
Commit
875a318
1 Parent(s): f04a4e9

add: body classification

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+ import time
6
+ from transformers import ResNetForImageClassification, AutoImageProcessor
7
+
8
+ # Load model and processor
9
+ processor = AutoImageProcessor.from_pretrained("glazzova/body_type")
10
+ model = ResNetForImageClassification.from_pretrained("glazzova/body_type")
11
+
12
+ # Load example images from the "template" folder
13
+ example_images = [
14
+ os.path.join("template", x) for x in os.listdir("template") if x.lower().endswith((".png", ".jpg", ".jpeg"))
15
+ ]
16
+
17
+ # Define the classification function
18
+ def body_classification(image):
19
+ start_time = time.time() # Record start time
20
+ inputs = processor(image, return_tensors="pt") # Process the image
21
+
22
+ # Get predictions
23
+ with torch.no_grad():
24
+ logits = model(**inputs).logits
25
+
26
+ predicted_label = logits.argmax(-1).item()
27
+ label = model.config.id2label[predicted_label]
28
+ elapsed_time = time.time() - start_time # Calculate elapsed time
29
+
30
+ return label, f"{elapsed_time:.2f} seconds"
31
+
32
+ # Create the Gradio interface
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("# Body Type Classifier")
35
+ gr.Markdown(
36
+ """
37
+ Upload an image or use the example images to predict the body type.
38
+ The app uses a pre-trained ResNet model fine-tuned for body type classification.
39
+
40
+ **by Ishwor Subedi**
41
+ GitHub: [@ishworrsubedii](https://github.com/ishworrsubedii)
42
+ """
43
+ )
44
+
45
+ with gr.Row():
46
+ with gr.Column():
47
+ image_input = gr.Image(type="pil", label="Upload Image")
48
+ with gr.Column():
49
+ label_output = gr.Textbox(label="Predicted Body Type")
50
+ time_output = gr.Textbox(label="Processing Time (s)")
51
+
52
+ classify_button = gr.Button("Classify")
53
+ classify_button.click(body_classification, inputs=image_input, outputs=[label_output, time_output])
54
+
55
+ gr.Markdown("### Example Images")
56
+ # Add example images as inputs
57
+ gr.Examples(examples=example_images, inputs=image_input, label="Template Images")
58
+
59
+ # Run the app
60
+ if __name__ == "__main__":
61
+ demo.launch(debug=True)