tahirsher committed
Commit 7fe9a0d · verified · 1 Parent(s): b102079

Update app.py

Files changed (1)
  1. app.py +24 -66
app.py CHANGED
@@ -8,67 +8,41 @@ import pickle
 import gradio as gr
 from nltk.stem.lancaster import LancasterStemmer
 
-# Ensure nltk downloads
 nltk.download('punkt')
 stemmer = LancasterStemmer()
 
-# Load intents and check for errors
-try:
-    with open("intents.json") as file:
-        data = json.load(file)
-except FileNotFoundError:
-    raise FileNotFoundError("The file 'intents.json' was not found.")
+# Load intents
+with open("intents.json") as file:
+    data = json.load(file)
 
-# Load or regenerate the data
+# Load or regenerate data.pickle
 try:
     with open("data.pickle", "rb") as f:
         words, labels, training, output = pickle.load(f)
 except FileNotFoundError:
-    # Regenerate the data.pickle if not found
-    words = []
-    labels = []
-    docs_x = []
-    docs_y = []
-
+    words, labels, docs_x, docs_y = [], [], [], []
     for intent in data["intents"]:
         for pattern in intent["patterns"]:
             wrds = nltk.word_tokenize(pattern)
             words.extend(wrds)
             docs_x.append(wrds)
             docs_y.append(intent["tag"])
-
         if intent["tag"] not in labels:
             labels.append(intent["tag"])
 
-    words = [stemmer.stem(w.lower()) for w in words if w not in ["?", "!", ".", ","]]
-    words = sorted(list(set(words)))
+    words = sorted(set(stemmer.stem(w.lower()) for w in words if w not in ["?", ".", ",", "!"]))
     labels = sorted(labels)
 
-    training = []
-    output = []
-
-    out_empty = [0 for _ in range(len(labels))]
-
+    training, output = [], []
+    out_empty = [0] * len(labels)
     for x, doc in enumerate(docs_x):
-        bag = []
-
-        wrds = [stemmer.stem(w.lower()) for w in doc]
-
-        for w in words:
-            if w in wrds:
-                bag.append(1)
-            else:
-                bag.append(0)
-
+        bag = [1 if stemmer.stem(w.lower()) in [stemmer.stem(word) for word in doc] else 0 for w in words]
         output_row = out_empty[:]
         output_row[labels.index(docs_y[x])] = 1
-
         training.append(bag)
         output.append(output_row)
 
-    training = np.array(training)
-    output = np.array(output)
-
+    training, output = np.array(training), np.array(output)
     with open("data.pickle", "wb") as f:
         pickle.dump((words, labels, training, output), f)
 
@@ -82,60 +56,44 @@ net = tflearn.regression(net)
 model = tflearn.DNN(net)
 try:
     model.load("MentalHealthChatBotmodel.tflearn")
-except Exception as e:
-    raise FileNotFoundError("Model file 'MentalHealthChatBotmodel.tflearn' could not be loaded.") from e
+except FileNotFoundError:
+    model.fit(training, output, n_epoch=1000, batch_size=8, show_metric=True)
+    model.save("MentalHealthChatBotmodel.tflearn")
 
-# Function to convert user input into a bag-of-words representation
+# Define chat function
 def bag_of_words(s, words):
-    bag = [0 for _ in range(len(words))]
-    s_words = nltk.word_tokenize(s)
-    s_words = [stemmer.stem(word.lower()) for word in s_words]
-
+    bag = [0] * len(words)
+    s_words = [stemmer.stem(w.lower()) for w in nltk.word_tokenize(s)]
     for se in s_words:
         for i, w in enumerate(words):
             if w == se:
                 bag[i] = 1
     return np.array(bag)
 
-# Chat function
 def chat(message, history=None):
     history = history or []
-    message = message.lower()
-
     try:
         bag = bag_of_words(message, words)
         results = model.predict([bag])
         results_index = np.argmax(results)
         tag = labels[results_index]
+        for tg in data["intents"]:
+            if tg['tag'] == tag:
+                response = random.choice(tg['responses'])
+                break
+        else:
+            response = "I'm sorry, I don't have a response for that."
     except Exception as e:
-        print(f"Error during processing: {e}")  # Debugging
         response = "I'm sorry, I couldn't understand your message."
-        history.append((message, response))
-        return history, history
-
-    for tg in data["intents"]:
-        if tg['tag'] == tag:
-            responses = tg['responses']
-            response = random.choice(responses)
-            break
-    else:
-        response = "I'm sorry, I don't have a response for that."
-
     history.append((message, response))
     return history, history
 
 # Gradio Interface
-css = """
-footer {display:none !important}
-div[data-testid="user"] {background-color: #253885 !important;}
-"""
 demo = gr.Interface(
     fn=chat,
-    inputs=[gr.Textbox(lines=1, label="Message"), gr.State([])],
+    inputs=[gr.Textbox(lines=1, label="Message"), gr.State()],
     outputs=[gr.Chatbot(label="Chat"), gr.State()],
-    allow_flagging="never",
-    title="Wellbeing Chatbot",
-    css=css
+    allow_flagging="never"
 )
 
 if __name__ == "__main__":
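
For reviewers who want to sanity-check the updated chat() flow without going through the Gradio UI, a minimal sketch along these lines can be run in a REPL; the import, the example message, and the direct call are illustrative assumptions and not part of this commit, and they presume intents.json, data.pickle, and the saved tflearn model resolve exactly as in the diff above.

    # Hypothetical smoke test for the updated chat(); not part of this commit.
    # chat() returns (history, history), where history is a list of (message, response) pairs.
    from app import chat

    history = []
    history, _ = chat("I have been feeling anxious lately", history)
    for user_msg, bot_reply in history:
        print("user:", user_msg)
        print("bot: ", bot_reply)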