lukiod committed
Commit 54c24d5 · verified · 1 Parent(s): 6ed2a87

Update app.py

Files changed (1)
  1. app.py +57 -59
app.py CHANGED
@@ -2,28 +2,29 @@ import gradio as gr
 import pandas as pd
 from datetime import datetime
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import T5Tokenizer, T5ForConditionalGeneration
 import gc
-from typing import List, Dict, Optional
+from typing import List, Dict
 import os
 
 class ModelHandler:
     def __init__(self):
-        self.model_name = "google/flan-t5-large"
+        self.model_name = "google/flan-t5-base"
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.tokenizer = None
-        self.model = None
+        print(f"Using device: {self.device}")
         self.initialize_model()
-
+
     def initialize_model(self):
         try:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            self.model = AutoModelForSeq2SeqLM.from_pretrained(
+            print(f"Loading model: {self.model_name}")
+            self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
+            self.model = T5ForConditionalGeneration.from_pretrained(
                 self.model_name,
                 torch_dtype=torch.float32,
                 low_cpu_mem_usage=True
             )
             self.model.to(self.device)
+            print("Model loaded successfully")
             return True
         except Exception as e:
             print(f"Error initializing model: {str(e)}")
@@ -31,42 +32,39 @@ class ModelHandler:
 
     def generate_response(self, prompt: str, max_length: int = 512) -> str:
         try:
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            inputs = self.tokenizer(
-                prompt,
+            # Format prompt for T5
+            formatted_prompt = f"Answer the health question: {prompt}"
+
+            # Generate response
+            input_ids = self.tokenizer(
+                formatted_prompt,
                 return_tensors="pt",
                 truncation=True,
                 max_length=512
-            ).to(self.device)
+            ).input_ids.to(self.device)
 
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    inputs.input_ids,
-                    max_length=max_length,
-                    num_beams=2,
-                    temperature=0.7,
-                    no_repeat_ngram_size=3,
-                    length_penalty=1.0
-                )
+            outputs = self.model.generate(
+                input_ids,
+                max_length=max_length,
+                min_length=20,
+                num_beams=2,
+                temperature=0.7,
+                do_sample=True
+            )
 
             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            del outputs, inputs
+            # Memory cleanup
+            del outputs, input_ids
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
             return response
-        except Exception as e:
-            return f"Error generating response: {str(e)}"
 
-    def clear_memory(self):
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        except Exception as e:
+            print(f"Error in generate_response: {str(e)}")
+            return "I apologize, but I encountered an error processing your request."
 
 class HealthData:
     def __init__(self):
@@ -116,36 +114,36 @@ class HealthAssistant:
         self.data = HealthData()
         self.request_count = 0
 
-    def _create_prompt(self, message: str, history: List = None) -> str:
-        prompt_parts = ["You are a helpful healthcare assistant."]
-
-        # Add health context
-        health_context = self.data.get_health_context()
-        if health_context != "No health data available.":
-            prompt_parts.append(f"Current health information:\n{health_context}")
-
-        # Add conversation history
-        if history:
-            prompt_parts.append("Previous conversation:")
-            for user_msg, bot_msg in history[-3:]:
-                prompt_parts.append(f"User: {user_msg}")
-                prompt_parts.append(f"Assistant: {bot_msg}")
-
-        # Add current question
-        prompt_parts.append(f"User: {message}")
-        prompt_parts.append("Assistant:")
-
-        return "\n\n".join(prompt_parts)
-
     def get_response(self, message: str, history: List = None) -> str:
-        self.request_count += 1
-        prompt = self._create_prompt(message, history)
-        response = self.model.generate_response(prompt)
-
-        if self.request_count % 5 == 0:
-            self.model.clear_memory()
+        try:
+            # Prepare context
+            context = self.data.get_health_context()
+
+            # Format prompt with context and history
+            prompt = "Given the following context:\n"
+            prompt += f"{context}\n\n"
 
-        return response
+            if history:
+                prompt += "Previous conversation:\n"
+                for user_msg, bot_msg in history[-3:]:  # Last 3 exchanges
+                    prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
+
+            prompt += f"Current question: {message}"
+
+            # Get response
+            response = self.model.generate_response(prompt)
+
+            # Memory management
+            if self.request_count % 5 == 0:
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+            return response
+
+        except Exception as e:
+            print(f"Error in get_response: {str(e)}")
+            return "I apologize, but I encountered an error. Please try again."
 
 class HealthAssistantUI:
     def __init__(self):
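
For context, here is a minimal, self-contained sketch of the generation path this commit switches to: google/flan-t5-base loaded through T5Tokenizer and T5ForConditionalGeneration, using the same prompt wrapper and decoding settings as the new generate_response(). It assumes torch, transformers, and sentencepiece (required by T5Tokenizer) are installed; the sample question is made up for illustration.

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

# The commit downsizes from flan-t5-large to flan-t5-base and keeps
# float32 weights, matching ModelHandler.initialize_model().
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-base",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(device)

# Same prompt wrapper as the new generate_response(); the question
# itself is a hypothetical example.
prompt = "Answer the health question: What are common causes of headaches?"
input_ids = tokenizer(
    prompt, return_tensors="pt", truncation=True, max_length=512
).input_ids.to(device)

# num_beams=2 with do_sample=True selects beam-search multinomial
# sampling; min_length=20 keeps answers from stopping after a few tokens.
outputs = model.generate(
    input_ids,
    max_length=512,
    min_length=20,
    num_beams=2,
    temperature=0.7,
    do_sample=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Two details of the committed version worth noting: get_response() keeps the request_count % 5 == 0 cleanup check, but the commit removes the self.request_count += 1 line, so the counter stays at 0 and the cache cleanup now runs on every call; and dropping the torch.no_grad() context around model.generate() is harmless, since generate() already runs without gradient tracking internally.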