lukiod committed on
Commit f973312 · verified · 1 Parent(s): 2bf0817

Update app.py

Files changed (1)
  1. app.py +66 -185
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import logging
 from typing import List, Dict
 import gc
@@ -13,232 +13,113 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-# Set torch threads
-torch.set_num_threads(4)
+# Set random seed for reproducibility
+torch.random.manual_seed(0)
 
 class HealthAssistant:
-    def __init__(self, use_smaller_model=True):
-        if use_smaller_model:
-            self.model_name = "Qwen/Qwen2-VL-7B-Instruct"
-        else:
-            self.model_name = "Qwen/Qwen2-VL-7B-Instruct"
-
+    def __init__(self):
+        self.model_id = "microsoft/Phi-3-small-128k-instruct"
         self.model = None
         self.tokenizer = None
+        self.pipe = None
         self.metrics = []
         self.medications = []
         self.initialize_model()
 
     def initialize_model(self):
         try:
-            logger.info(f"Starting model initialization: {self.model_name}")
+            logger.info(f"Loading model: {self.model_id}")
 
+            # Initialize tokenizer
             self.tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name,
+                self.model_id,
                 trust_remote_code=True
             )
             logger.info("Tokenizer loaded")
 
+            # Initialize model
             self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                torch_dtype=torch.float32,
-                low_cpu_mem_usage=True,
+                self.model_id,
+                torch_dtype="auto",
                 trust_remote_code=True
             )
 
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            self.model = self.model.to("cpu")
-            logger.info("Model loaded successfully")
+            # Set device
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.model = self.model.to(self.device)
+            logger.info(f"Model loaded on {self.device}")
+
+            # Setup pipeline
+            self.pipe = pipeline(
+                "text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                device=self.device
+            )
+            logger.info("Pipeline created successfully")
 
             return True
 
         except Exception as e:
             logger.error(f"Error in model initialization: {str(e)}")
             raise
 
-    def _detect_query_type(self, message: str) -> str:
-        """Detect type of medical query"""
-        message_lower = message.lower()
-
-        emergency_keywords = ["emergency", "severe pain", "chest pain", "can't breathe",
-                              "unconscious", "stroke", "heart attack"]
-        if any(keyword in message_lower for keyword in emergency_keywords):
-            return "emergency_guidance"
-
-        symptom_keywords = ["symptom", "feel", "pain", "ache", "suffering", "experiencing"]
-        if any(keyword in message_lower for keyword in symptom_keywords):
-            return "symptom_check"
-
-        medication_keywords = ["medicine", "drug", "pill", "prescription", "medication", "dose"]
-        if any(keyword in message_lower for keyword in medication_keywords):
-            return "medication_info"
-
-        lifestyle_keywords = ["exercise", "diet", "sleep", "stress", "healthy", "lifestyle"]
-        if any(keyword in message_lower for keyword in lifestyle_keywords):
-            return "lifestyle_advice"
-
-        return "general"
-
-    def _prepare_medical_prompt(self, message: str, query_type: str) -> str:
-        """Prepare medical prompt based on query type"""
-        base_context = self._get_health_context()
-
-        prompts = {
-            "symptom_check": f"""You are a medical AI assistant. Based on the following health context and symptoms, provide a careful analysis.
-
-Current Health Context:
-{base_context}
-
-Patient's Symptoms: {message}
-
-Provide a structured response covering:
-1. Key symptoms identified
-2. Possible common causes
-3. General recommendations
-4. Warning signs to watch for
-5. When to seek medical care
-
-Remember to maintain a professional and careful tone.""",
-
-            "medication_info": f"""You are a medical AI assistant. Provide information about the medication inquiry while noting you cannot give prescription advice.
-
-Current Health Context:
-{base_context}
-
-Medication Query: {message}
-
-Provide general information about:
-1. Basic medication category/purpose
-2. General usage patterns
-3. Common considerations
-4. Important precautions
-5. When to consult a healthcare provider
-
-Remember to emphasize this is general information only.""",
-
-            "emergency_guidance": f"""You are a medical AI assistant. This appears to be an urgent situation.
-
-Current Health Context:
-{base_context}
-
-Urgent Situation: {message}
-
-Provide immediate guidance:
-1. Severity assessment
-2. Immediate actions needed
-3. Emergency warning signs
-4. Whether to call emergency services
-5. Precautions while waiting
-
-Always emphasize seeking immediate medical care for emergencies.""",
-
-            "general": f"""You are a medical AI assistant. Provide helpful health information based on the query.
-
-Current Health Context:
-{base_context}
-
-Health Query: {message}
-
-Provide a structured response covering:
-1. Understanding of the question
-2. Relevant health information
-3. General guidance
-4. Important considerations
-5. Additional recommendations"""
-        }
-
-        return prompts.get(query_type, prompts["general"])
+    def _prepare_prompt(self, message: str, history: List = None) -> str:
+        """Prepare prompt with context and history"""
+        prompt_parts = [
+            "You are a medical AI assistant providing healthcare information and guidance.",
+            "Always be professional and include appropriate medical disclaimers.",
+            "\nCurrent Health Information:",
+            self._get_health_context(),
+            "\nConversation:"
+        ]
+
+        if history:
+            for prev_msg, prev_response in history[-3:]:
+                prompt_parts.extend([
+                    f"Human: {prev_msg}",
+                    f"Assistant: {prev_response}"
+                ])
+
+        prompt_parts.extend([
+            f"Human: {message}",
+            "Assistant:"
+        ])
+
+        return "\n".join(prompt_parts)
 
     def generate_response(self, message: str, history: List = None) -> str:
         try:
-            if not hasattr(self, 'model') or self.model is None:
-                return "System is initializing. Please try again in a moment."
-
-            # Detect query type
-            query_type = self._detect_query_type(message)
-
             # Prepare prompt
-            prompt = self._prepare_medical_prompt(message, query_type)
-
-            # Add conversation history if available
-            if history:
-                prompt += "\n\nRecent conversation context:"
-                for prev_msg, prev_response in history[-2:]:
-                    prompt += f"\nQ: {prev_msg}\nA: {prev_response}\n"
-
-            # Tokenize
-            inputs = self.tokenizer(
-                prompt,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=512
-            )
-
-            # Generate
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    inputs["input_ids"],
-                    max_new_tokens=150,
-                    num_beams=1,
-                    temperature=0.7,
-                    top_p=0.9,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id
-                )
-
-            # Decode
-            response = self.tokenizer.decode(
-                outputs[0][inputs["input_ids"].shape[1]:],
-                skip_special_tokens=True
-            )
-
-            # Format response
-            response = self._format_response(response, query_type)
+            prompt = self._prepare_prompt(message, history)
+
+            # Generation configuration
+            generation_args = {
+                "max_new_tokens": 500,
+                "return_full_text": False,
+                "temperature": 0.7,
+                "do_sample": True,
+                "top_k": 50,
+                "top_p": 0.9,
+                "repetition_penalty": 1.1
+            }
+
+            # Generate response
+            output = self.pipe(prompt, **generation_args)
+            response = output[0]['generated_text']
 
             # Cleanup
-            del outputs, inputs
             gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
             return response.strip()
 
         except Exception as e:
             logger.error(f"Error generating response: {str(e)}")
-            return "I apologize, but I encountered an error. Please try rephrasing your question."
-
-    def _format_response(self, response: str, query_type: str) -> str:
-        """Format and clean the response"""
-        # Remove repeated headers
-        lines = [line.strip() for line in response.split('\n') if line.strip()]
-        clean_lines = []
-        seen = set()
-
-        for line in lines:
-            if line not in seen:
-                seen.add(line)
-                clean_lines.append(line)
-
-        # Add appropriate prefix based on query type
-        prefixes = {
-            "emergency_guidance": "🚨 URGENT: ",
-            "symptom_check": "🔍 Analysis: ",
-            "medication_info": "💊 Medication Info: ",
-            "lifestyle_advice": "💡 Health Advice: ",
-            "general": "ℹ️ "
-        }
-
-        prefix = prefixes.get(query_type, "ℹ️ ")
-        formatted_response = prefix + "\n".join(clean_lines)
-
-        # Add disclaimer for certain types
-        if query_type in ["emergency_guidance", "medication_info"]:
-            formatted_response += "\n\n⚠️ Note: This is general information only. Always consult healthcare professionals."
-
-        return formatted_response
+            return "I apologize, but I encountered an error. Please try again."
 
     def _get_health_context(self) -> str:
-        """Get user's health context"""
         context_parts = []
 
         if self.metrics:
@@ -289,7 +170,7 @@ class GradioInterface:
     def __init__(self):
         try:
             logger.info("Initializing Health Assistant...")
-            self.assistant = HealthAssistant(use_smaller_model=True)
+            self.assistant = HealthAssistant()
             logger.info("Health Assistant initialized successfully")
         except Exception as e:
             logger.error(f"Failed to initialize Health Assistant: {e}")
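For reference, the pipeline pattern this commit switches to can be exercised end to end as follows. This is a minimal sketch, not the app itself: the model id and generation settings are taken from the diff, the prompt string is an illustrative stand-in for what `_prepare_prompt` builds, and running Phi-3-small-128k-instruct in practice assumes a GPU and `trust_remote_code=True`.

```python
# Minimal sketch of the new generation path: one transformers text-generation
# pipeline instead of manual tokenize/generate/decode. Model id and generation
# settings mirror the diff; the prompt below is an illustrative stand-in.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "microsoft/Phi-3-small-128k-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",          # let transformers pick the checkpoint dtype
    trust_remote_code=True,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,   # return only the completion, not the prompt
    "temperature": 0.7,
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
}

prompt = "Human: What are common causes of fatigue?\nAssistant:"  # stand-in prompt
output = pipe(prompt, **generation_args)
print(output[0]["generated_text"].strip())
```

Because `generation_args` is passed straight through `pipe(...)`, the app no longer needs the explicit `torch.no_grad()` block or the manual decode-and-slice step of the old version.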
 
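One design note on the new `_prepare_prompt`: it assembles a plain `Human:`/`Assistant:` transcript by hand, while instruct-tuned checkpoints such as Phi-3 are generally prompted through their own chat template. A hedged alternative sketch, continuing from the snippet above (the message contents are illustrative only, and whether a separate system role is honored depends on the checkpoint's template):

```python
# Alternative: let the tokenizer's chat template format the conversation
# instead of hand-rolled "Human:/Assistant:" markers. Continues from the
# sketch above (tokenizer, pipe, generation_args already defined).
messages = [
    {"role": "system", "content": "You are a medical AI assistant. Include appropriate disclaimers."},
    {"role": "user", "content": "What are common causes of fatigue?"},  # illustrative
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the formatted string, not token ids
    add_generation_prompt=True,  # append the assistant-turn marker
)
output = pipe(prompt, **generation_args)
print(output[0]["generated_text"].strip())
```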