izhan001 commited on
Commit
e06a9ce
·
verified ·
1 Parent(s): 4c5b858

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from datasets import load_dataset
4
+ from sklearn.model_selection import train_test_split
5
+ import torch
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+
9
+ # Step 1: Load the World Cuisines dataset
10
+ ds = load_dataset("worldcuisines/food-kb")
11
+
12
+ # Access the 'main' dataset
13
+ dataset = ds['main']
14
+
15
+ # Check the structure of the dataset
16
+ print(dataset)
17
+
18
+ # Converting dataset to a list of dictionaries for easier manipulation
19
+ data_list = dataset.to_dict()['image1'] # Accessing the first image column (you can access others like image2, etc.)
20
+
21
+ # Now split the dataset into train and test
22
+ train_data, test_data = train_test_split(data_list, test_size=0.2)
23
+
24
+ # Check the shapes of train_data and test_data
25
+ print(f"Training data size: {len(train_data)}")
26
+ print(f"Testing data size: {len(test_data)}")
27
+
28
+ # Define a custom dataset class for the image classification task
29
+ class FoodDataset(Dataset):
30
+ def __init__(self, dataset, processor, max_length=256):
31
+ self.dataset = dataset
32
+ self.processor = processor
33
+ self.max_length = max_length
34
+
35
+ def __len__(self):
36
+ return len(self.dataset)
37
+
38
+ def __getitem__(self, idx):
39
+ item = self.dataset[idx]
40
+ # For simplicity, let's use image1 for training and test
41
+ image = Image.open(item['image1']) # Assuming 'image1' has the food images
42
+ label = item['fine_categories'] # You can modify this based on the label
43
+
44
+ # Process the image
45
+ encoding = self.processor(images=image, return_tensors="pt", padding=True, truncation=True)
46
+
47
+ # Return the input and target labels
48
+ return {
49
+ 'input_ids': encoding['input_ids'].squeeze(),
50
+ 'attention_mask': encoding['attention_mask'].squeeze(),
51
+ 'labels': label # Assuming that 'fine_categories' is used as labels
52
+ }
53
+
54
+ # Step 2: Load the ViT model for image classification
55
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
56
+ vit_model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
57
+
58
+ # Step 3: Load the text generation model (Gemini) for nutrition breakdown and diet plan
59
+ tokenizer = AutoTokenizer.from_pretrained("describeai/gemini")
60
+ gemini_model = AutoModelForSeq2SeqLM.from_pretrained("describeai/gemini")
61
+
62
+ # Helper function to get nutritional breakdown and allergen information
63
+ def get_nutrition_and_allergens(food_name):
64
+ # Look for the food item in the dataset
65
+ result = None
66
+ try:
67
+ dataset = ds['main'] # Access the correct dataset split
68
+ for item in dataset:
69
+ if food_name.lower() in item['name'].lower():
70
+ result = item
71
+ break
72
+
73
+ if result:
74
+ nutrition_info = result.get('nutrition', 'Nutrition information not available')
75
+ allergens = result.get('allergens', 'Allergen information not available')
76
+ diet_plan = f"This item is suitable for a diet including {result.get('suitable_for', 'N/A')}."
77
+ else:
78
+ nutrition_info = "Food item not found in the database."
79
+ allergens = "Allergen information not available."
80
+ diet_plan = "Diet plan not available for this food item."
81
+
82
+ except KeyError as e:
83
+ nutrition_info = f"Key error: {e}"
84
+ allergens = "Allergen information not available."
85
+ diet_plan = "Diet plan not available."
86
+
87
+ except Exception as e:
88
+ nutrition_info = f"An error occurred: {str(e)}"
89
+ allergens = "Allergen information not available."
90
+ diet_plan = "Diet plan not available."
91
+
92
+ return nutrition_info, allergens, diet_plan
93
+
94
+ # Main prediction function for the image classification and text generation
95
+ def predict(image):
96
+ try:
97
+ # Step 1: Classify the food item in the image using ViT model
98
+ inputs = processor(images=image, return_tensors="pt")
99
+ outputs = vit_model(**inputs)
100
+
101
+ # Get the predicted label (food item)
102
+ predicted_label = outputs.logits.argmax(-1).item()
103
+
104
+ # Get the food name from the class labels (assuming the model has the food labels)
105
+ class_labels = vit_model.config.id2label # Get the class label mapping
106
+ food_item = class_labels[predicted_label]
107
+
108
+ # Step 2: Generate nutritional breakdown, allergens, and diet plan
109
+ nutrition_info, allergens, diet_plan = get_nutrition_and_allergens(food_item)
110
+
111
+ # Step 3: Generate a detailed description using the Gemini model
112
+ description_input = f"Nutritional breakdown and diet plan for {food_item}"
113
+ diet_plan_text = tokenizer(description_input, return_tensors="pt", padding=True, truncation=True)
114
+ diet_plan_output = gemini_model.generate(**diet_plan_text)
115
+ diet_plan_text = tokenizer.decode(diet_plan_output[0], skip_special_tokens=True)
116
+
117
+ # Combine results into a single output
118
+ response = f"**Detected Food:** {food_item}\n\n"
119
+ response += f"**Nutrition Info:** {nutrition_info}\n\n"
120
+ response += f"**Allergens:** {allergens}\n\n"
121
+ response += f"**Diet Plan:** {diet_plan}\n\n"
122
+ response += f"**Detailed Diet Plan and Breakdown:** {diet_plan_text}"
123
+
124
+ except Exception as e:
125
+ response = f"Error: {str(e)}"
126
+
127
+ return response
128
+
129
+ # Step 4: Gradio Interface
130
+ interface = gr.Interface(
131
+ fn=predict,
132
+ inputs=gr.Image(type="pil"),
133
+ outputs="text",
134
+ title="NutriScan: AI-Powered Food Analyzer",
135
+ description="Upload an image of food, and get a nutritional breakdown, allergen information, and diet plan recommendations.",
136
+ examples=[["path_to_example_image.jpg"]] # replace with paths to example images if needed
137
+ )
138
+
139
+ # Launch the Gradio interface
140
+ if __name__ == "__main__":
141
+ interface.launch()