Dhahlan2000 commited on
Commit
5be0841
·
verified ·
1 Parent(s): 7e3e2aa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ # Check if GPU is available
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Load pre-trained ViT model from Hugging Face
10
+ model = ViTForImageClassification.from_pretrained('Dhahlan2000/ripeness_detection', num_labels=20)
11
+ model.to(device)
12
+ model.eval()
13
+
14
+ # Load ViT feature extractor
15
+ feature_extractor = ViTFeatureExtractor.from_pretrained('Dhahlan2000/ripeness_detection')
16
+
17
+ # Class labels
18
+ predicted_classes = [
19
+ 'FreshApple', 'FreshBanana', 'FreshBellpepper', 'FreshCarrot', 'FreshCucumber', 'FreshMango', 'FreshOrange',
20
+ 'FreshPotato', 'FreshStrawberry', 'FreshTomato', 'RottenApple', 'RottenBanana', 'RottenBellpepper', 'RottenCarrot',
21
+ 'RottenCucumber', 'RottenMango', 'RottenOrange', 'RottenPotato', 'RottenStrawberry', 'RottenTomato']
22
+
23
+ # Function for inference
24
+ def classify_fruit(image):
25
+ inputs = feature_extractor(images=image, return_tensors="pt").to(device)
26
+ with torch.no_grad():
27
+ outputs = model(**inputs)
28
+ logits = outputs.logits
29
+ predicted_class_idx = logits.argmax(-1).item()
30
+ return predicted_classes[predicted_class_idx]
31
+
32
+ # Gradio UI
33
+ demo = gr.Interface(
34
+ fn=classify_fruit,
35
+ inputs=gr.Image(type="pil"),
36
+ outputs=gr.Label(),
37
+ title="Fruit Ripeness Detection",
38
+ description="Upload an image of a fruit to determine whether it's fresh or rotten."
39
+ )
40
+
41
+ demo.launch()