KhadijaAsehnoune12 commited on
Commit
bc18618
·
verified ·
1 Parent(s): 7391848

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -9
app.py CHANGED
@@ -1,10 +1,46 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
-
4
- pipe = pipeline(task="image-classification",
5
- model="KhadijaAsehnoune12/OrangeLeafDiseaseDetector")
6
- gr.Interface.from_pipeline(pipe,
7
- title="Orange Disease Image Classification",
8
- description="Detect diseases in orange leaves and fruits.",
9
- examples = ['MoucheB.jpg', 'verdissement.jpg',],
10
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
4
+ from PIL import Image
5
+
6
+ # Define the model and feature extractor
7
+ model_name = "KhadijaAsehnoune12/OrangeLeafDiseaseDetector"
8
+ model = ViTForImageClassification.from_pretrained(model_name)
9
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
10
+
11
+ # Define the label mapping
12
+ id2label = {
13
+ "0": "Aleurocanthus spiniferus",
14
+ "1": "Chancre citrique",
15
+ "2": "Cochenille blanche",
16
+ "3": "Dépérissement des agrumes",
17
+ "4": "Feuille saine",
18
+ "5": "Jaunissement des feuilles",
19
+ "6": "Maladie de l'oïdium",
20
+ "7": "Maladie du dragon jaune",
21
+ "8": "Mineuse des agrumes",
22
+ "9": "Trou de balle"
23
+ }
24
+
25
+ def predict(image):
26
+ # Preprocess the image
27
+ inputs = feature_extractor(images=image, return_tensors="pt")
28
+
29
+ # Forward pass through the model
30
+ outputs = model(**inputs)
31
+
32
+ # Get the predicted label
33
+ logits = outputs.logits
34
+ predicted_class_idx = logits.argmax(-1).item()
35
+
36
+ # Get the label name
37
+ predicted_label = id2label[str(predicted_class_idx)]
38
+
39
+ return predicted_label
40
+
41
+ # Create the Gradio interface
42
+ image = gr.Image(type="pil")
43
+ label = gr.Label(num_top_classes=3)
44
+
45
+ gr.Interface(fn=predict, inputs=image, outputs=label, title="Citrus Disease Classification",
46
+ description="Upload an image of a citrus leaf or fruit to classify its disease.").launch()