JustKiddo committed · Commit 4911187 · verified · 1 Parent(s): 47b843d

Update app.py

Files changed (1):
  1. app.py +83 -29
app.py CHANGED
@@ -10,22 +10,65 @@ from datetime import datetime
 import json
 from collections import deque
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-import torch # Import torch
 
 class BERTopicChatbot:
 
+    #Initialize chatbot with a Hugging Face dataset
+    #dataset_name: name of the dataset on Hugging Face (e.g., 'vietnam/legal')
+    #text_column: name of the column containing the text data
+    #split: which split of the dataset to use ('train', 'test', 'validation')
+    #max_samples: maximum number of samples to use (to manage memory)
+
     def __init__(self, dataset_name, text_column, split="train", max_samples=10000):
         # Initialize BERT sentence transformer
         self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
 
-        #Initialize BARTpho model and tokenizer
-        self.bartpho_model_name = "vinai/bartpho-syllable"
-
-        # Load tokenizer only once
-        self.tokenizer = AutoTokenizer.from_pretrained(self.bartpho_model_name)
+        # Add label mapping
+        self.label_mapping = {
+            0: 'BPD',
+            1: 'bipolar',
+            2: 'depression',
+            3: 'Anxiety',
+            4: 'schizophrenia',
+            5: 'mentalillness'
+        }
+
+        # Add comfort responses
+        self.comfort_responses = {
+            'BPD': [
+                "I understand BPD can be overwhelming. You're not alone in this journey.",
+                "Your feelings are valid. BPD is challenging, but there are people who understand.",
+                "Taking things one day at a time with BPD is okay. You're showing great strength."
+            ],
+            'bipolar': [
+                "Bipolar disorder can feel like a roller coaster. Remember, stability is possible.",
+                "You're so strong for managing bipolar disorder. Take it one day at a time.",
+                "Both the highs and lows are temporary. You've gotten through them before."
+            ],
+            'depression': [
+                "Depression is heavy, but you don't have to carry it alone.",
+                "Even small steps forward are progress. You're doing better than you think.",
+                "This feeling won't last forever. You've made it through difficult times before."
+            ],
+            'Anxiety': [
+                "Your anxiety doesn't define you. You're stronger than your fears.",
+                "Remember to breathe. You're safe, and this feeling will pass.",
+                "It's okay to take things at your own pace. You're handling this well."
+            ],
+            'schizophrenia': [
+                "You're not your diagnosis. You're a person first, and you matter.",
+                "Managing schizophrenia takes incredible strength. You're doing well.",
+                "There's support available, and you deserve all the help you need."
+            ],
+            'mentalillness': [
+                "Mental health challenges don't define your worth. You are valuable.",
+                "Recovery isn't linear, and that's okay. Every step counts.",
+                "You're not alone in this journey. There's a community that understands."
+            ]
+        }
 
-        # Load Dataset and set other variables
+
+        # Load dataset from Hugging Face
         try:
             dataset = load_dataset(dataset_name, split=split)
             # Convert to pandas DataFrame and sample if necessary
@@ -62,13 +105,10 @@ class BERTopicChatbot:
                 'total_documents': len(self.documents),
                 'topics_found': len(set(self.topics))
             }
+
         except Exception as e:
             st.error(f"Error loading dataset: {str(e)}")
             raise
-
-        #Load fine-tuned BARTpho model
-        self.bartpho_model = AutoModelForSeq2SeqLM.from_pretrained("./bartpho_chatbot").to("cuda" if torch.cuda.is_available() else "cpu")
-        self.bartpho_model.eval()
 
     def get_metrics_visualizations(self):
         """Generate visualizations for chatbot metrics"""
@@ -142,34 +182,48 @@ class BERTopicChatbot:
     def get_response(self, user_query):
         try:
             start_time = datetime.now()
-
-            # Generate response with BARTpho
-            input_ids = self.tokenizer(user_query, return_tensors="pt").input_ids.to(self.bartpho_model.device) #Send the tensor to the same device as the model.
-
-            with torch.no_grad():
-                outputs = self.bartpho_model.generate(input_ids, max_length=100, num_beams=5, early_stopping=True) # Tune max_length, num_beams
-
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
+
+            # Get most similar documents
+            similar_docs, similarities = self.get_most_similar_document(user_query)
+
+            # Get the label from the most similar document
+            most_similar_index = similarities.argmax()
+            label_index = int(self.df['label'].iloc[most_similar_index]) # Convert to int
+            condition = self.label_mapping[label_index] # Map the integer label to condition name
+
+            # Get comfort response
+            comfort_messages = self.comfort_responses[condition]
+            comfort_response = np.random.choice(comfort_messages)
+
+            # Calculate query topic for metrics
+            query_topic, _ = self.topic_model.transform([user_query])
+
+            # Combine information and comfort response
+            if max(similarities) < 0.5:
+                response = f"I sense you might be dealing with {condition}. {comfort_response}"
+            else:
+                response = f"{similar_docs[0]}\n\n{comfort_response}"
+
+            # Track metrics
             end_time = datetime.now()
             metrics = {
-                'similarity': 0.0, # Remove original implementation
+                'similarity': float(max(similarities)),
                 'response_time': (end_time - start_time).total_seconds(),
                 'tokens': len(response.split()),
-                'topic': "N/A", # Remove original implementation
-                'detected_condition': "N/A" # Remove original implementation
+                'topic': str(query_topic[0]),
+                'detected_condition': condition
             }
-
+
             # Update metrics history
            self.metrics_history['similarities'].append(metrics['similarity'])
            self.metrics_history['response_times'].append(metrics['response_time'])
            self.metrics_history['token_counts'].append(metrics['tokens'])
-            topic_id = "N/A" # Remove original implementation
+            topic_id = str(query_topic[0])
            self.metrics_history['topics_accessed'][topic_id] = \
                self.metrics_history['topics_accessed'].get(topic_id, 0) + 1
-
+
             return response, metrics
-
+
         except Exception as e:
             return f"Error processing query: {str(e)}", {'error': str(e)}
 
@@ -191,7 +245,7 @@ class BERTopicChatbot:
         'dataset_info': None,
         'metrics': None
     }
-
+
 @st.cache_resource
 def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000):
     return BERTopicChatbot(dataset_name, text_column, split, max_samples)
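For reference, a minimal sketch (not part of this commit) of how the cached initialize_chatbot and the new get_response might be called from the Streamlit side of app.py; the dataset id and text column below are placeholders, not values taken from the repository:

# Hypothetical usage sketch: wire the cached chatbot into a simple Streamlit flow.
import streamlit as st

chatbot = initialize_chatbot(
    dataset_name="your-username/your-dataset",  # placeholder dataset id
    text_column="text",                         # placeholder column name
    split="train",
    max_samples=10000,
)

user_query = st.text_input("How are you feeling today?")
if user_query:
    # get_response returns (response_text, metrics_dict) per the diff above
    response, metrics = chatbot.get_response(user_query)
    st.write(response)
    st.caption(
        f"Detected condition: {metrics.get('detected_condition', 'N/A')} · "
        f"similarity: {metrics.get('similarity', 0.0):.2f}"
    )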