lukiod commited on
Commit
e0bdeef
·
verified ·
1 Parent(s): 54c24d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -74
app.py CHANGED
@@ -6,65 +6,110 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
6
  import gc
7
  from typing import List, Dict
8
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  class ModelHandler:
11
  def __init__(self):
12
- self.model_name = "google/flan-t5-base"
13
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
- print(f"Using device: {self.device}")
 
 
15
  self.initialize_model()
16
 
17
  def initialize_model(self):
18
- try:
19
- print(f"Loading model: {self.model_name}")
20
- self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
21
- self.model = T5ForConditionalGeneration.from_pretrained(
22
- self.model_name,
23
- torch_dtype=torch.float32,
24
- low_cpu_mem_usage=True
25
- )
26
- self.model.to(self.device)
27
- print("Model loaded successfully")
28
- return True
29
- except Exception as e:
30
- print(f"Error initializing model: {str(e)}")
31
- return False
 
 
 
 
 
 
 
32
 
33
- def generate_response(self, prompt: str, max_length: int = 512) -> str:
34
- try:
35
- # Format prompt for T5
36
- formatted_prompt = f"Answer the health question: {prompt}"
37
 
38
- # Generate response
39
- input_ids = self.tokenizer(
40
- formatted_prompt,
41
- return_tensors="pt",
 
 
 
 
42
  truncation=True,
43
- max_length=512
44
- ).input_ids.to(self.device)
45
-
46
- outputs = self.model.generate(
47
- input_ids,
48
- max_length=max_length,
49
- min_length=20,
50
- num_beams=2,
51
- temperature=0.7,
52
- do_sample=True
53
  )
54
 
55
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Memory cleanup
58
- del outputs, input_ids
 
59
  gc.collect()
60
- if torch.cuda.is_available():
61
- torch.cuda.empty_cache()
62
-
63
- return response
64
-
65
  except Exception as e:
66
- print(f"Error in generate_response: {str(e)}")
67
- return "I apologize, but I encountered an error processing your request."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  class HealthData:
70
  def __init__(self):
@@ -78,14 +123,16 @@ class HealthData:
78
  **metrics
79
  })
80
  return True
81
- except:
 
82
  return False
83
 
84
  def add_medication(self, medication: Dict) -> bool:
85
  try:
86
  self.medications.append(medication)
87
  return True
88
- except:
 
89
  return False
90
 
91
  def get_health_context(self) -> str:
@@ -93,10 +140,12 @@ class HealthData:
93
 
94
  if self.metrics:
95
  latest = self.metrics[-1]
96
- context_parts.append(f"Recent Health Metrics (Date: {latest['Date']}):")
97
- context_parts.append(f"- Weight: {latest['Weight']} kg")
98
- context_parts.append(f"- Steps: {latest['Steps']}")
99
- context_parts.append(f"- Sleep: {latest['Sleep']} hours")
 
 
100
 
101
  if self.medications:
102
  context_parts.append("\nCurrent Medications:")
@@ -116,34 +165,30 @@ class HealthAssistant:
116
 
117
  def get_response(self, message: str, history: List = None) -> str:
118
  try:
 
 
119
  # Prepare context
120
  context = self.data.get_health_context()
121
 
122
- # Format prompt with context and history
123
- prompt = "Given the following context:\n"
124
- prompt += f"{context}\n\n"
125
-
126
- if history:
127
- prompt += "Previous conversation:\n"
128
- for user_msg, bot_msg in history[-3:]: # Last 3 exchanges
129
- prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
130
-
131
- prompt += f"Current question: {message}"
132
-
133
  # Get response
134
  response = self.model.generate_response(prompt)
135
 
136
- # Memory management
137
  if self.request_count % 5 == 0:
138
  gc.collect()
139
- if torch.cuda.is_available():
140
- torch.cuda.empty_cache()
141
-
142
  return response
143
 
144
  except Exception as e:
145
- print(f"Error in get_response: {str(e)}")
146
- return "I apologize, but I encountered an error. Please try again."
147
 
148
  class HealthAssistantUI:
149
  def __init__(self):
@@ -153,7 +198,7 @@ class HealthAssistantUI:
153
  if message.strip() == "":
154
  return "", history
155
 
156
- bot_message = self.assistant.get_response(message, history)
157
  history.append([message, bot_message])
158
  return "", history
159
 
@@ -183,8 +228,13 @@ class HealthAssistantUI:
183
  return "❌ Error adding medication", None
184
 
185
  def create_interface(self):
186
- with gr.Blocks(title="Virtual Health Assistant", theme=gr.themes.Soft()) as demo:
187
- gr.Markdown("# 🏥 Virtual Health Assistant")
 
 
 
 
 
188
 
189
  with gr.Tabs():
190
  # Chat Interface
@@ -251,12 +301,38 @@ class HealthAssistantUI:
251
  outputs=[med_status, meds_display]
252
  )
253
 
 
 
 
 
 
 
 
 
254
  return demo
255
 
 
 
 
 
 
256
  def main():
257
- ui = HealthAssistantUI()
258
- demo = ui.create_interface()
259
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  if __name__ == "__main__":
262
  main()
 
6
  import gc
7
  from typing import List, Dict
8
  import os
9
+ import time
10
+ import logging
11
+
12
+ # Setup logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Disable gradient computation and set memory efficient settings
17
+ torch.set_grad_enabled(False)
18
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
19
+
20
+ # Create cache directory
21
+ os.makedirs("model_cache", exist_ok=True)
22
 
23
  class ModelHandler:
24
  def __init__(self):
25
+ self.model_name = "google/flan-t5-small" # Small model for Spaces
26
+ self.device = "cpu"
27
+ self.initialized = False
28
+ self.load_attempts = 0
29
+ self.max_attempts = 3
30
  self.initialize_model()
31
 
32
  def initialize_model(self):
33
+ while not self.initialized and self.load_attempts < self.max_attempts:
34
+ try:
35
+ logger.info(f"Loading model attempt {self.load_attempts + 1}")
36
+ self.tokenizer = T5Tokenizer.from_pretrained(
37
+ self.model_name,
38
+ model_max_length=512,
39
+ cache_dir="model_cache"
40
+ )
41
+ self.model = T5ForConditionalGeneration.from_pretrained(
42
+ self.model_name,
43
+ low_cpu_mem_usage=True,
44
+ cache_dir="model_cache"
45
+ )
46
+ self.initialized = True
47
+ logger.info("Model loaded successfully")
48
+ return True
49
+ except Exception as e:
50
+ logger.error(f"Loading attempt failed: {str(e)}")
51
+ self.load_attempts += 1
52
+ time.sleep(1)
53
+ return False
54
 
55
+ def generate_response(self, prompt: str, max_length: int = 256) -> str:
56
+ if not self.initialized:
57
+ return "Model initialization failed. Using basic responses."
 
58
 
59
+ try:
60
+ clean_prompt = prompt.strip()
61
+ if len(clean_prompt) == 0:
62
+ return "Please provide a valid question."
63
+
64
+ inputs = self.tokenizer(
65
+ clean_prompt,
66
+ max_length=512,
67
  truncation=True,
68
+ padding=True,
69
+ return_tensors="pt"
 
 
 
 
 
 
 
 
70
  )
71
 
72
+ with torch.no_grad():
73
+ outputs = self.model.generate(
74
+ input_ids=inputs["input_ids"],
75
+ max_length=max_length,
76
+ min_length=10,
77
+ num_beams=1,
78
+ do_sample=True,
79
+ temperature=0.7,
80
+ top_k=50,
81
+ top_p=0.95,
82
+ )
83
 
84
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
85
+
86
+ del outputs, inputs
87
  gc.collect()
88
+
89
+ return response if response else "Could not generate a response."
90
+
 
 
91
  except Exception as e:
92
+ logger.error(f"Generation error: {str(e)}")
93
+ return self.get_fallback_response(prompt)
94
+
95
+ def get_fallback_response(self, query: str) -> str:
96
+ responses = {
97
+ "hello": "Hello! I'm your health assistant.",
98
+ "help": "I can help with health information and tracking.",
99
+ "health": "I provide general health information.",
100
+ "sleep": "Aim for 7-9 hours of sleep daily.",
101
+ "exercise": "Regular exercise is important for health.",
102
+ "diet": "Eat a balanced diet with plenty of vegetables.",
103
+ "medication": "Always follow prescribed medication schedules.",
104
+ "water": "Stay hydrated by drinking plenty of water daily.",
105
+ "stress": "Managing stress is important for overall health."
106
+ }
107
+
108
+ query = query.lower()
109
+ for key, response in responses.items():
110
+ if key in query:
111
+ return response
112
+ return "I understand you have a health question. Please try rephrasing it simply."
113
 
114
  class HealthData:
115
  def __init__(self):
 
123
  **metrics
124
  })
125
  return True
126
+ except Exception as e:
127
+ logger.error(f"Error adding metrics: {str(e)}")
128
  return False
129
 
130
  def add_medication(self, medication: Dict) -> bool:
131
  try:
132
  self.medications.append(medication)
133
  return True
134
+ except Exception as e:
135
+ logger.error(f"Error adding medication: {str(e)}")
136
  return False
137
 
138
  def get_health_context(self) -> str:
 
140
 
141
  if self.metrics:
142
  latest = self.metrics[-1]
143
+ context_parts.extend([
144
+ f"Recent Health Metrics (Date: {latest['Date']}):",
145
+ f"- Weight: {latest['Weight']} kg",
146
+ f"- Steps: {latest['Steps']}",
147
+ f"- Sleep: {latest['Sleep']} hours"
148
+ ])
149
 
150
  if self.medications:
151
  context_parts.append("\nCurrent Medications:")
 
165
 
166
  def get_response(self, message: str, history: List = None) -> str:
167
  try:
168
+ self.request_count += 1
169
+
170
  # Prepare context
171
  context = self.data.get_health_context()
172
 
173
+ # Format prompt
174
+ prompt = (
175
+ f"Context: {context}\n\n"
176
+ f"Question: {message}\n\n"
177
+ "Provide a helpful and accurate health-related response."
178
+ )
179
+
 
 
 
 
180
  # Get response
181
  response = self.model.generate_response(prompt)
182
 
183
+ # Periodic cleanup
184
  if self.request_count % 5 == 0:
185
  gc.collect()
186
+
 
 
187
  return response
188
 
189
  except Exception as e:
190
+ logger.error(f"Error in get_response: {str(e)}")
191
+ return self.model.get_fallback_response(message)
192
 
193
  class HealthAssistantUI:
194
  def __init__(self):
 
198
  if message.strip() == "":
199
  return "", history
200
 
201
+ bot_message = self.assistant.get_response(message)
202
  history.append([message, bot_message])
203
  return "", history
204
 
 
228
  return "❌ Error adding medication", None
229
 
230
  def create_interface(self):
231
+ with gr.Blocks(title="Health Assistant", theme=gr.themes.Soft()) as demo:
232
+ gr.Markdown(
233
+ """
234
+ # 🏥 Health Assistant
235
+ Your AI-powered health companion. Track metrics, manage medications, and get health information.
236
+ """
237
+ )
238
 
239
  with gr.Tabs():
240
  # Chat Interface
 
301
  outputs=[med_status, meds_display]
302
  )
303
 
304
+ gr.Markdown(
305
+ """
306
+ ### ⚠️ Important Note
307
+ This is an AI assistant for general health information only.
308
+ Always consult healthcare professionals for medical advice.
309
+ """
310
+ )
311
+
312
  return demo
313
 
314
+ def cleanup():
315
+ """Cleanup function for memory management"""
316
+ gc.collect()
317
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
318
+
319
  def main():
320
+ try:
321
+ logger.info("Starting Health Assistant")
322
+ ui = HealthAssistantUI()
323
+ demo = ui.create_interface()
324
+
325
+ # Register cleanup
326
+ demo.load(cleanup)
327
+
328
+ # Launch app
329
+ demo.launch(
330
+ share=False,
331
+ enable_queue=True,
332
+ max_threads=4
333
+ )
334
+ except Exception as e:
335
+ logger.error(f"Error starting app: {str(e)}")
336
 
337
  if __name__ == "__main__":
338
  main()