Spaces:
Build error
Build error
Quyet
committed on
Commit
•
f30862a
1
Parent(s):
de337bd
add euc 100 200 to chat loop, fix dialog model
Browse files
README.md
CHANGED
@@ -11,3 +11,14 @@ license: gpl-3.0
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
14 |
+
|
15 |
+
For more information about this product, please visit this Notion [page](https://www.notion.so/AI-Consulting-Design-Scheme-0a9c5288820d4fec98ecc7cc1e84be51) (you need to have permission to access this page)
|
16 |
+
|
17 |
+
# Notes
|
18 |
+
|
19 |
+
### 2022/12/20
|
20 |
+
|
21 |
+
- Chat flow will trigger euc 200 when it detects a negative emotion with prob > threshold. Thus, only euc 100 and free chat make up the chat loop, while euc 200 will pop up occasionally. I set the trigger to NOT fire regularly (currently it triggers at most once during the conversation), because triggering too often would bother users
|
22 |
+
- Already fixed the problem with the dialog model. Now it is configured the same as it should be. Of course, that does not guarantee a good response
|
23 |
+
- TODO is written in the main file already
|
24 |
+
- Successfully converted plain euc 100 and 200 to chat flow
|
app.py
CHANGED
@@ -1,3 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
import re, time
|
3 |
import matplotlib.pyplot as plt
|
@@ -5,228 +27,275 @@ from threading import Timer
|
|
5 |
import gradio as gr
|
6 |
|
7 |
import torch
|
8 |
-
from transformers import
|
9 |
-
GPT2LMHeadModel, GPT2Tokenizer,
|
10 |
-
AutoModelForSequenceClassification, AutoTokenizer,
|
11 |
-
pipeline
|
12 |
-
)
|
13 |
-
# reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
|
14 |
-
# and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
|
15 |
-
# gradio vs streamlit https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
16 |
-
# https://gradio.app/interface_state/
|
17 |
-
|
18 |
-
def euc_100():
|
19 |
-
# 1,2,3. asks about the user's emotions and store data
|
20 |
-
print('How was your day?')
|
21 |
-
print('On the scale 1 to 10, how would you judge your emotion through the following categories:') # ~ Baymax :)
|
22 |
-
emotion_types = ['overall'] #, 'happiness', 'surprise', 'sadness', 'depression', 'anger', 'fear', 'anxiety']
|
23 |
-
emotion_degree = []
|
24 |
-
input_time = []
|
25 |
-
|
26 |
-
for e in emotion_types:
|
27 |
-
while True:
|
28 |
-
x = input(f'{e}: ')
|
29 |
-
if x.isnumeric() and (0 < int(x) < 11):
|
30 |
-
emotion_degree.append(int(x))
|
31 |
-
input_time.append(time.gmtime())
|
32 |
-
break
|
33 |
-
else:
|
34 |
-
print('invalid input, my friend :) plz input again')
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
'And what do you think about those feelings or emotions at that time?',
|
59 |
'Could you think of any evidence for your above-mentioned thought?',
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
y = 'No' # bad mood
|
64 |
-
while True:
|
65 |
-
x = input('Your answer (example of answer here): ')
|
66 |
-
if x == '': # need to change this part to waiting 10 seconds
|
67 |
-
print('Whether your bad mood is over?')
|
68 |
-
y = input('Your answer (Yes or No): ')
|
69 |
-
if y == 'Yes':
|
70 |
-
break
|
71 |
-
else:
|
72 |
-
break
|
73 |
-
if y == 'Yes':
|
74 |
-
print('Nice to hear that.')
|
75 |
-
break
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
return_all_scores=True, truncation=True)
|
89 |
-
return pipe
|
90 |
-
|
91 |
-
def plot_emotion_distribution(predictions):
|
92 |
-
fig, ax = plt.subplots()
|
93 |
-
ax.bar(x=[i for i, _ in enumerate(prediction)],
|
94 |
-
height=[p['score'] for p in prediction],
|
95 |
-
tick_label=[p['label'] for p in prediction])
|
96 |
-
ax.tick_params(rotation=90)
|
97 |
-
ax.set_ylim(0, 1)
|
98 |
-
plt.show()
|
99 |
-
|
100 |
-
def rulebase(text):
|
101 |
-
keywords = {
|
102 |
-
'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
|
103 |
-
'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
|
104 |
-
'manifestation': ['never stop', 'every moment', 'strong', 'very']
|
105 |
}
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
threshold = 0.3
|
136 |
-
emotion = {'label': 'sadness', 'score': 0.4} if testing else prediction[0]
|
137 |
-
# then judge
|
138 |
-
if emotion['label'] in ['surprise', 'sadness', 'anger', 'fear'] and emotion['score'] > threshold:
|
139 |
-
print(f'It has come to our attention that you may suffer from {emotion["label"]}')
|
140 |
-
print('If you want to know more about yourself, '
|
141 |
-
'some professional scales are provided to quantify your current status. '
|
142 |
-
'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
|
143 |
-
'you can fill out these scales again to see if you have improved.')
|
144 |
-
x = input('Fill in the form now (Okay or Later): ')
|
145 |
-
if x == 'Okay':
|
146 |
-
print('Display the form')
|
147 |
-
else:
|
148 |
-
print('Here are some reference articles about bad emotions. You can take a look :)')
|
149 |
-
|
150 |
-
# 4. If both of the above are not satisfied. What do u mean by 'satisfied' here?
|
151 |
-
questions = [
|
152 |
-
'What specific thing is bothering you the most right now?',
|
153 |
-
'Oh, I see. So when it is happening, what feelings or emotions have you got?',
|
154 |
-
'And what do you think about those feelings or emotions at that time?',
|
155 |
-
'Could you think of any evidence for your above-mentioned thought? #',
|
156 |
-
]
|
157 |
-
for q in questions:
|
158 |
-
print(q)
|
159 |
-
y = 'No' # bad mood
|
160 |
-
while True:
|
161 |
-
x = input('Your answer (example of answer here): ')
|
162 |
-
if x == '': # need to change this part to waiting 10 seconds
|
163 |
-
print('Whether your bad mood is over?')
|
164 |
-
y = input('Your answer (Yes or No): ')
|
165 |
if y == 'Yes':
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
else:
|
168 |
-
|
169 |
-
if y == 'Yes':
|
170 |
-
print('Nice to hear that.')
|
171 |
-
break
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
185 |
|
186 |
-
|
187 |
-
history['input_ids'] = torch.cat([history['input_ids'], message_ids], dim=-1)
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
print((message, response), bot_output_ids[0][-10:])
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
|
|
201 |
|
|
|
|
|
|
|
|
|
202 |
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
216 |
|
217 |
title = 'PsyPlus Empathetic Chatbot'
|
218 |
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
219 |
-
chatbot = gr.Chatbot(value=
|
220 |
iface = gr.Interface(
|
221 |
-
chat,
|
222 |
-
|
223 |
-
[chatbot, 'state'],
|
224 |
-
# css=".gradio-container {background-color: white}",
|
225 |
-
allow_flagging='never',
|
226 |
-
title=title,
|
227 |
-
description=description,
|
228 |
)
|
|
|
|
|
229 |
if args.run_on_own_server == 0:
|
230 |
iface.launch(debug=True)
|
231 |
else:
|
232 |
-
iface.launch(debug=True, server_name='0.0.0.0', server_port=2022
|
|
|
1 |
+
'''
|
2 |
+
Dialog System of PsyPlus (dvq)
|
3 |
+
|
4 |
+
reference:
|
5 |
+
https://huggingface.co/spaces/bentrevett/emotion-prediction
|
6 |
+
https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
|
7 |
+
https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues
|
8 |
+
|
9 |
+
gradio vs streamlit
|
10 |
+
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
11 |
+
https://gradio.app/interface_state/
|
12 |
+
|
13 |
+
TODO
|
14 |
+
Add diagram in Gradio Interface showing sentiment analysis
|
15 |
+
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
|
16 |
+
Personalize: create database, load and save data
|
17 |
+
|
18 |
+
Run command
|
19 |
+
python app.py --run_on_own_server 1 --initial_chat_state free_chat
|
20 |
+
'''
|
21 |
+
|
22 |
+
|
23 |
import argparse
|
24 |
import re, time
|
25 |
import matplotlib.pyplot as plt
|
|
|
27 |
import gradio as gr
|
28 |
|
29 |
import torch
|
30 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
|
33 |
+
def option():
    """Build and parse the command-line arguments for the demo.

    Returns:
        argparse.Namespace with deployment and model-selection options.
    """
    p = argparse.ArgumentParser()
    # deployment mode: share mode is needed when testing on one's own server
    p.add_argument('--run_on_own_server', type=int, default=0,
                   help='if test on own server, need to use share mode')
    # model checkpoints (Hugging Face hub ids)
    p.add_argument('--dialog_model', type=str, default='tareknaous/dialogpt-empathetic-dialogues')
    p.add_argument('--emotion_model', type=str, default='joeddav/distilbert-base-uncased-go-emotions-student')
    # optional user account for persisted chat history
    p.add_argument('--account', type=str, default=None)
    # which therapy segment the conversation starts in
    p.add_argument('--initial_chat_state', type=str, default='euc_100',
                   choices=['euc_100', 'euc_200', 'free_chat'])
    return p.parse_args()
|
42 |
+
|
43 |
+
|
44 |
+
class ChatHelper:
    """Canned messages shown during the rule-based therapy segments."""

    # generic prompts reused across segments
    invalid_input = 'Invalid input, my friend :) Plz input again'
    good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
    good_case = 'Nice to hear that!'
    bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
    not_answer = "It's okay, maybe you don't want to answer this question."

    # shown when a strong negative emotion is detected; '{}' is the emotion label
    fill_form = ('It has come to our attention that you may suffer from {}.\n'
                 'If you want to know more about yourself, some professional scales are provided to quantify your current status.\n'
                 'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
                 'you can fill out these scales again to see if you have improved.\n'
                 'Do you want to fill in the form now? (Okay or Later)')
    display_form = '<Display the form>.\n'
    reference = 'Here are some reference articles about bad emotions. You can take a look :) <Display references>\n'

    # euc_100 scripted flow: first the rating questions, then one of two branches
    emotion_types = ['Overall', 'Happiness', 'Anxiety']  # 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear',
    euc_100 = {
        'q': emotion_types,
        'good_mood': [
            'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
            'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
        ],
        'bad_mood': [
            'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
            'I see. So when it is happening, what feelings or emotions have you got?',
            'And what do you think about those feelings or emotions at that time?',
            'Could you think of any evidence for your above-mentioned thought?',
            'Here are some reference articles about bad emotions. You can take a look :)',
        ],
    }

    # labels that count as "negative" for the euc_200 trigger
    negative_emotions = ['remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear', 'disapproval',
                         'confusion', 'embarrassment', 'disgust', 'sadness', 'disappointment']
    # '{}' is the user's stashed last message
    euc_200 = 'Now go back to the last chat. You said that "{}".\n'

    # opening line for each possible initial chat state
    # euc_200 is only triggered when the user says something more negative than a
    # certain threshold, so its greeting here exists only for debugging euc_200
    greeting_template = {
        'euc_100': 'How was your day? On the scale 1 to 10, '
                   'how would you judge your emotion through the following categories:\nOverall',
        'euc_200': fill_form.format('anxiety'),
        'free_chat': 'Hi you! How is it going?',
    }
|
86 |
|
87 |
+
def plot_emotion_distribution(predictions):
    """Bar-plot the score of each predicted emotion label.

    Args:
        predictions: list of dicts with 'label' and 'score' keys, as returned
            by a transformers text-classification pipeline.
    """
    fig, ax = plt.subplots()
    # FIX: the original body read an undefined name `prediction`; it now uses
    # the `predictions` parameter it receives.
    ax.bar(x=[i for i, _ in enumerate(predictions)],
           height=[p['score'] for p in predictions],
           tick_label=[p['label'] for p in predictions])
    ax.tick_params(rotation=90)
    ax.set_ylim(0, 1)  # scores are probabilities
    plt.show()
|
95 |
+
|
96 |
+
def ed_rulebase(text):
    """Keyword-based emergency detection for a user message.

    Prompts the user (via stdin/stdout) when the text mentions a life-safety
    topic together with an immediacy or manifestation amplifier.
    """
    keywords = {
        'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
        'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
        'manifestation': ['never stop', 'every moment', 'strong', 'very']
    }

    def mentions(group):
        # True when any keyword of the group occurs as a substring of the text
        return re.search('|'.join(keywords[group]), text) is not None

    # dangerous topic plus at least one amplifier triggers the hotline prompt
    if mentions('life_safety') and (mentions('immediacy') or mentions('manifestation')):
        print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
              'The Hong Kong Lifeline number is (852) 2382 0000')
        choice = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
        if choice == '1':
            print('Let you connect to the office')
        else:
            print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
                  'Would you mind if we send this conversation to the cloud to finetune the model.')
            consent = input('Yes or No: ')
            if consent == 'Yes':
                pass  # do smt here
|
117 |
+
|
118 |
+
|
119 |
+
class TherapyChatBot:
    """Stateful chat engine combining rule-based CBT segments with DialoGPT free chat.

    The dialog position is tracked by `chat_state`: dotted strings like
    'euc_100.q.0' / 'euc_100.bad_mood.2', or the plain names 'euc_200' and
    'free_chat'. euc_200 interrupts the current segment at most once per
    conversation when a sufficiently negative emotion is detected.
    """

    def __init__(self, args):
        # state that controls the dialog
        self.chat_state = args.initial_chat_state  # name of the chat function/therapy segment the model is in
        self.message_prev = None
        self.chat_state_prev = None
        self.run_on_own_server = args.run_on_own_server
        self.account = args.account

        # additional attributes for euc_100
        self.euc_100_input_time = []
        self.euc_100_emotion_degree = []
        self.already_trigger_euc_200 = False  # euc_200 fires at most once per conversation

        # chat and emotion-detection models
        self.ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
        self.ed_threshold = 0.3  # minimum score for a negative emotion to trigger euc_200
        self.dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
        self.dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
        self.eos = self.dialog_tokenizer.eos_token
        # tokenizer.__call__ -> input_ids, attention_mask
        # tokenizer.encode -> only input_ids, which is required by model.generate

        # chat history.
        # TODO: if we want to personalize and save the conversation,
        # we can load data from a database
        self.greeting = ChatHelper.greeting_template[self.chat_state]
        self.history = {'input_ids': torch.tensor([[self.dialog_tokenizer.bos_token_id]]),
                        'text': [('', self.greeting)]} if not self.account else open(f'database/{hash(self.account)}', 'rb')
        if 'euc_100' in self.chat_state:
            self.chat_state = 'euc_100.q.0'

    def __call__(self, message, prefix=''):
        """Process one user turn and return the full text history for Gradio.

        Args:
            message: the user's message.
            prefix: non-empty only when re-entered from euc_200, in which case
                the negative-emotion check is skipped and `prefix` is prepended
                to the response.
        """
        # run emotion detection only on ordinary turns (not when re-entered
        # from euc_200, nor when euc_200 already fired this conversation)
        may_detect = (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200
        if may_detect:
            prediction = self.ed_pipe(message)[0]
            prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
            if self.run_on_own_server:
                print(prediction)
            # plot_emotion_distribution(prediction)
            emotion = prediction[0]

        # if the message is negative enough, interrupt with euc_200 immediately
        if may_detect and (emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > self.ed_threshold):
            self.chat_state_prev = self.chat_state
            self.chat_state = 'euc_200'
            self.message_prev = message
            self.already_trigger_euc_200 = True
            response = ChatHelper.fill_form.format(emotion['label'])

        # otherwise each dialog function updates the state itself
        elif self.chat_state.startswith('euc_100'):
            response = self.euc_100(message)
            if self.chat_state == 'free_chat':
                # euc_100 just finished: fold this turn into the generation context
                last_two_turns_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
                self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)

        elif self.chat_state.startswith('euc_200'):
            return self.euc_200(message)

        else:  # free_chat
            response = self.free_chat(message)

        if prefix:
            response = prefix + response
            self.history['text'].append((self.message_prev, response))
        else:
            self.history['text'].append((message, response))
        return self.history['text']

    def euc_100(self, x):
        """Scripted euc_100 segment: rating questions, then a mood branch."""
        _, subsection, entry = self.chat_state.split('.')
        entry = int(entry)

        if subsection == 'q':
            # collect one 1-10 rating per emotion type
            if x.isnumeric() and (0 < int(x) < 11):
                self.euc_100_emotion_degree.append(int(x))
                self.euc_100_input_time.append(time.gmtime())
                if entry == len(ChatHelper.euc_100['q']) - 1:
                    if self.run_on_own_server:
                        print(self.euc_100_emotion_degree)
                    # branch on the 'Overall' rating
                    mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
                    self.chat_state = f'euc_100.{mood}.0'
                    response = ChatHelper.euc_100[mood][0]
                else:
                    self.chat_state = f'euc_100.q.{entry+1}'
                    response = ChatHelper.euc_100['q'][entry+1]
            else:
                response = ChatHelper.invalid_input

        elif subsection == 'good_mood':
            if x == '':
                response = ChatHelper.good_mood_over
            else:
                response = ChatHelper.good_case
                response += '\n' + ChatHelper.euc_100['good_mood'][1]
                self.chat_state = 'free_chat'

        elif subsection == 'bad_mood':
            if entry == -1:
                # entry -1 means we just asked bad_mood_over
                if 'yes' in x.lower() or 'better' in x.lower():
                    # NOTE(review): state stays at 'euc_100.bad_mood.-1' here —
                    # presumably intentional, but confirm the segment can exit.
                    response = ChatHelper.good_case
                else:
                    # FIX: rsplit returns a list; the original passed it straight
                    # to int(), raising TypeError. Take the trailing entry index.
                    entry = int(self.chat_state_prev.rsplit('.', 1)[-1])
                    response = ChatHelper.not_answer + '\n' + ChatHelper.euc_100['bad_mood'][entry+1]
                    if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
                        self.chat_state = 'free_chat'
                    else:
                        self.chat_state = f'euc_100.bad_mood.{entry+1}'
            elif x == '':
                # silence: check whether the bad mood has passed, remember where we were
                response = ChatHelper.bad_mood_over
                self.chat_state_prev = self.chat_state
                self.chat_state = 'euc_100.bad_mood.-1'
            else:
                response = ChatHelper.euc_100['bad_mood'][entry+1]
                if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
                    self.chat_state = 'free_chat'
                else:
                    self.chat_state = f'euc_100.bad_mood.{entry+1}'

        return response

    def euc_200(self, x):
        """Handle the reply to the fill-the-form offer, then resume the interrupted segment."""
        # don't ask questions in euc_200, because they're similar to questions in euc_100
        if x.lower() == 'okay':
            response = ChatHelper.display_form
        else:
            response = ChatHelper.reference
        response += ChatHelper.euc_200.format(self.message_prev)

        # replay the stashed message through the previous segment, with this
        # segment's output as a prefix
        message = self.message_prev
        self.message_prev = x
        self.chat_state = self.chat_state_prev
        return self.__call__(message, response)

    def free_chat(self, message):
        """Generate a DialoGPT response conditioned on the accumulated history."""
        message_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
        self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
        input_ids = self.history['input_ids'].clone()

        while True:
            bot_output_ids = self.dialog_model.generate(input_ids, max_length=1000,
                                                        do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
                                                        pad_token_id=self.dialog_tokenizer.eos_token_id)
            response = self.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
                                                    skip_special_tokens=True)
            if response.strip() != '':
                break
            # empty generation: drop the oldest turn from the context and retry
            elif input_ids[0].tolist().count(self.dialog_tokenizer.eos_token_id) > 0:
                idx = input_ids[0].tolist().index(self.dialog_tokenizer.eos_token_id)
                input_ids = input_ids[:, (idx+1):]
            else:
                # nothing left to trim: fall back to just the current message
                input_ids = message_ids

        if self.run_on_own_server:
            print(input_ids)

        # append only the newly generated tokens to the stored history
        self.history['input_ids'] = torch.cat([self.history['input_ids'], bot_output_ids[0:1, input_ids.shape[-1]:]], dim=-1)
        if self.run_on_own_server == 1:
            print((message, response), '\n', self.history['input_ids'])

        return response
|
283 |
+
|
284 |
+
|
285 |
+
if __name__ == '__main__':
    # FIX: the UI setup and launch below referenced `args` and `chat`, which are
    # only defined under this guard; everything now lives inside it so importing
    # the module no longer raises NameError (and no server starts on import).
    args = option()
    chat = TherapyChatBot(args)

    title = 'PsyPlus Empathetic Chatbot'
    description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
    # seed the chat widget with the greeting already stored in the bot's history
    chatbot = gr.Chatbot(value=chat.history['text'])
    iface = gr.Interface(
        chat, 'text', chatbot,
        allow_flagging='never', title=title, description=description,
    )

    # iface.queue(concurrency_count=5)
    if args.run_on_own_server == 0:
        iface.launch(debug=True)
    else:
        iface.launch(debug=True, share=True)  # server_name='0.0.0.0', server_port=2022
|