"""
Dialog System of PsyPlus (dvq)

reference:
    https://huggingface.co/spaces/bentrevett/emotion-prediction
    https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
    https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues

gradio vs streamlit
    https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
    https://gradio.app/interface_state/
        -> global and local variables affect the separation of sessions

TODO
    Add command to reset/jump to a function, e.g >reset, >euc_100
    Add diagram in Gradio Interface showing sentiment analysis
    Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
    Personalize: create database, load and save data

Run command
    python app.py --run_on_own_server 1 --initial_chat_state free_chat
"""

import argparse
import re
import time
import matplotlib.pyplot as plt
from threading import Timer

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline


def option() -> argparse.Namespace:
    """Parse the command-line options of the app and return them."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_on_own_server', type=int, default=0,
                        help='if test on own server, need to use share mode')
    parser.add_argument('--dialog_model', type=str, default='tareknaous/dialogpt-empathetic-dialogues')
    parser.add_argument('--emotion_model', type=str, default='joeddav/distilbert-base-uncased-go-emotions-student')
    parser.add_argument('--account', type=str, default=None)
    parser.add_argument('--initial_chat_state', type=str, default='euc_100',
                        choices=['euc_100', 'euc_200', 'free_chat'])
    args = parser.parse_args()
    return args


# Parsed once at import time; ChatHelper's class body below reads it.
args = option()


# store the list of messages that are shown in therapies and models as global variables
# let all chat-session-wise variables be placed in TherapyChatBot
class ChatHelper:
    """Session-independent resources: the loaded models and every canned message."""

    # chat and emotion-detection models (loaded when the class body executes)
    ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
    # minimum score the top emotion must exceed to count as a negative message
    ed_threshold = 0.3
    dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
    dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
    eos = dialog_tokenizer.eos_token
    # tokenizer.__call__ -> input_ids, attention_mask
    # tokenizer.encode -> only inputs_ids, which is required by model.generate function

    invalid_input = 'Invalid input, my friend :) Plz input again'
    good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
    good_case = 'Nice to hear that!'
    bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
    not_answer = "It's okay, maybe you don't want to answer this question."
    fill_form = ('It has come to our attention that you may suffer from {}.\n'
                 'If you want to know more about yourself, some professional scales are provided to quantify your current status.\n'
                 'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
                 'you can fill out these scales again to see if you have improved.\n'
                 'Do you want to fill in the form now? (Okay or Later)')
    display_form = '.\n'
    reference = 'Here are some reference articles about bad emotions. You can take a look :) \n'

    emotion_types = ['Overall', 'Happiness', 'Anxiety']  # 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear',

    # Script of the euc_100 segment: 'q' holds the categories rated 1-10,
    # 'good_mood' / 'bad_mood' hold the follow-up prompts, in order.
    euc_100 = {
        'q': emotion_types,
        'good_mood': [
            'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
            'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
        ],
        'bad_mood': [
            'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
            'I see. So when it is happening, what feelings or emotions have you got?',
            'And what do you think about those feelings or emotions at that time?',
            'Could you think of any evidence for your above-mentioned thought?',
            'Here are some reference articles about bad emotions. You can take a look :)',
        ],
    }

    # labels of the emotion model that switch the session to euc_200
    negative_emotions = ['remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear', 'disapproval',
                         'confusion', 'embarrassment', 'disgust', 'sadness', 'disappointment']
    euc_200 = 'Now go back to the last chat. You said that "{}".\n'

    # first bot message, keyed by --initial_chat_state
    greeting_template = {
        'euc_100': 'How was your day? On the scale 1 to 10, '
                   'how would you judge your emotion through the following categories:\nOverall',
        # euc_200 is only triggered when you say something more negative than a certain threshold,
        # thus the greeting here is only for debugging euc_200
        'euc_200': fill_form.format('anxiety'),
        'free_chat': 'Hi you! How is it going?',
    }


def plot_emotion_distribution(predictions):
    """Show a bar chart of emotion scores.

    Args:
        predictions: sequence of dicts, each with a 'label' and a 'score' key.
    """
    labels = [p['label'] for p in predictions]
    scores = [p['score'] for p in predictions]

    _, ax = plt.subplots()
    ax.bar(x=list(range(len(predictions))), height=scores, tick_label=labels)
    ax.tick_params(rotation=90)
    ax.set_ylim(0, 1)
    plt.show()


def ed_rulebase(text):
    """Keyword-based check for messages that may need immediate professional help.

    The rule fires when ``text`` matches a life-safety keyword and at least one
    immediacy or manifestation keyword. It then talks to the user on the
    console (``print`` / ``input``). Returns None.
    """
    keywords = {
        'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
        'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
        'manifestation': ['never stop', 'every moment', 'strong', 'very'],
    }

    def mentions(group):
        # keywords are plain words, so joining them with '|' gives a valid alternation
        return re.search('|'.join(keywords[group]), text) is not None

    # if found dangerous kw/topics
    if mentions('life_safety') and any(mentions(g) for g in ('immediacy', 'manifestation')):
        print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
              'The Hong Kong Lifeline number is (852) 2382 0000')
        x = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
        if x == '1':
            print('Let you connect to the office')
        else:
            print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
                  'Would you mind if we send this conversation to the cloud to finetune the model.')
            y = input('Yes or No: ')
            if y == 'Yes':
                pass  # do smt here


class TherapyChatBot:
    """State machine for one chat session.

    ``chat_state`` names the segment the session is in:
      * ``'euc_100.<subsection>.<entry>'`` - scripted check-in
        (subsection is 'q', 'good_mood' or 'bad_mood');
      * ``'euc_200'`` - offered once, after a negative message is detected;
      * ``'free_chat'`` - open conversation with the dialog model.
    """

    def __init__(self, args):
        # check state to control the dialog
        self.chat_state = args.initial_chat_state  # name of the chat function/therapy segment the model is in
        self.message_prev = None
        self.chat_state_prev = None
        self.run_on_own_server = args.run_on_own_server
        self.account = args.account

        # additional attribute for euc_100
        self.euc_100_input_time = []
        self.euc_100_emotion_degree = []
        self.already_trigger_euc_200 = False

        # chat history.
        # TODO: if we want to personalize and save the conversation,
        # we can load data from database
        self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
        # NOTE(review): the --account path is unfinished. It stores an open file
        # object (never closed) where a dict is expected, so later
        # self.history['text'] lookups fail; hash() of a str is also salted per
        # process, so the file name is not stable between runs.
        self.history = {'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
                        'text': self.greeting} if not self.account else open(f'database/{hash(self.account)}', 'rb')
        if 'euc_100' in self.chat_state:
            self.chat_state = 'euc_100.q.0'
        # NOTE(review): starting in 'euc_200' (debug only) leaves chat_state_prev
        # and message_prev as None, which euc_200() then restores/replays.

    def __call__(self, message, prefix=''):
        """Handle one user message and append the (user, bot) turn to the history.

        Args:
            message: the text typed by the user.
            prefix: text to put in front of the reply. euc_200() passes it when
                it replays the previous message; a non-empty prefix also skips
                emotion detection, because the negative emotion was already
                detected for that message.
        """
        # Detection runs only for fresh messages, outside euc_200, and only
        # until euc_200 has been triggered once in this session.
        should_detect = (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200
        emotion = None
        if should_detect:
            prediction = ChatHelper.ed_pipe(message)[0]
            prediction = sorted(prediction, key=lambda p: p['score'], reverse=True)
            if self.run_on_own_server:
                print(prediction)
            # plot_emotion_distribution(prediction)
            emotion = prediction[0]

        is_negative = (emotion is not None
                       and emotion['label'] in ChatHelper.negative_emotions
                       and emotion['score'] > ChatHelper.ed_threshold)

        # if message is negative, change state immediately
        if is_negative:
            self.chat_state_prev = self.chat_state
            self.chat_state = 'euc_200'
            self.message_prev = message
            self.already_trigger_euc_200 = True
            response = ChatHelper.fill_form.format(emotion['label'])

        # set up rule to update state inside each dialog function
        elif self.chat_state.startswith('euc_100'):
            response = self.euc_100(message)
            if self.chat_state == 'free_chat':
                # the script just ended: add this last user turn to the model context
                last_two_turns_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
                self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)

        elif self.chat_state.startswith('euc_200'):
            # euc_200 re-enters __call__ with a prefix; that call records the turn
            return self.euc_200(message)

        else:  # free_chat
            response = self.free_chat(message)

        if prefix:
            response = prefix + response
            # during a replay, message_prev holds what the user typed this turn
            self.history['text'].append((self.message_prev, response))
        else:
            self.history['text'].append((message, response))

    def _advance_bad_mood(self, entry):
        """Update chat_state after bad-mood prompt ``entry + 1`` has been shown."""
        if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
            # prompt entry + 1 was the last one, so the script is over
            self.chat_state = 'free_chat'
        else:
            self.chat_state = f'euc_100.bad_mood.{entry+1}'

    def euc_100(self, x):
        """Run one step of the scripted check-in and return the bot reply.

        Args:
            x: the user's answer to the prompt identified by ``self.chat_state``.

        Returns:
            The next prompt (or an invalid-input notice). ``self.chat_state`` is
            updated to the state that expects the answer to that prompt.
        """
        _, subsection, entry = self.chat_state.split('.')
        entry = int(entry)

        if subsection == 'q':
            answer = x.strip()
            # isdecimal() only accepts strings that int() can parse;
            # isnumeric() also accepts e.g. '½', which int() rejects.
            if answer.isdecimal() and (0 < int(answer) < 11):
                self.euc_100_emotion_degree.append(int(answer))
                self.euc_100_input_time.append(time.gmtime())
                if entry == len(ChatHelper.euc_100['q']) - 1:
                    if self.run_on_own_server:
                        print(self.euc_100_emotion_degree)
                    # the first rating ('Overall') picks the follow-up branch
                    mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
                    self.chat_state = f'euc_100.{mood}.0'
                    response = ChatHelper.euc_100[mood][0]
                else:
                    self.chat_state = f'euc_100.q.{entry+1}'
                    response = ChatHelper.euc_100['q'][entry+1]
            else:
                response = ChatHelper.invalid_input

        elif subsection == 'good_mood':
            if x == '':
                response = ChatHelper.good_mood_over
            else:
                response = ChatHelper.good_case
            response += '\n' + ChatHelper.euc_100['good_mood'][1]
            self.chat_state = 'free_chat'

        elif subsection == 'bad_mood':
            bad_mood_prompts = ChatHelper.euc_100['bad_mood']
            if entry == -1:
                # entry -1 means x answers the "is your bad mood over?" follow-up
                if 'yes' in x.lower() or 'better' in x.lower():
                    response = ChatHelper.good_case
                    # leave the script; otherwise the state would stay on this follow-up
                    self.chat_state = 'free_chat'
                else:
                    # resume after the prompt the user skipped; its index is the
                    # last component of the state saved in chat_state_prev
                    entry = int(self.chat_state_prev.rsplit('.', 1)[-1])
                    response = ChatHelper.not_answer + '\n' + bad_mood_prompts[entry+1]
                    self._advance_bad_mood(entry)
            elif x == '':
                # empty answer: ask whether the bad mood is over and remember where we were
                response = ChatHelper.bad_mood_over
                self.chat_state_prev = self.chat_state
                self.chat_state = 'euc_100.bad_mood.-1'
            else:
                response = bad_mood_prompts[entry+1]
                self._advance_bad_mood(entry)

        return response

    def euc_200(self, x):
        """Answer the fill-in-the-form offer, then replay the message that triggered it.

        Args:
            x: the user's answer to the offer; 'okay' (any case) shows the form,
                anything else shows the reference text.

        Returns:
            Whatever ``__call__`` returns for the replayed message.
        """
        # don't ask question in euc_200, because they're similar to question in euc_100
        if x.lower() == 'okay':
            response = ChatHelper.display_form
        else:
            response = ChatHelper.reference
        response += ChatHelper.euc_200.format(self.message_prev)

        # Swap: replay the earlier message, and keep x as the user side of this turn.
        message = self.message_prev
        self.message_prev = x
        self.chat_state = self.chat_state_prev
        return self.__call__(message, prefix=response)

    def free_chat(self, message):
        """Generate a reply with the dialog model and extend the token history.

        Args:
            message: the user's text.

        Returns:
            The decoded, non-empty reply.
        """
        tokenizer = ChatHelper.dialog_tokenizer
        eos_id = tokenizer.eos_token_id

        message_ids = tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
        self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
        input_ids = self.history['input_ids'].clone()

        # Retry until the reply is non-empty, each time dropping the context up to
        # and including its first EOS token.
        # NOTE(review): the loop has no retry limit, and the trimmed context can
        # be empty for one iteration before it is reset to message_ids.
        while True:
            bot_output_ids = ChatHelper.dialog_model.generate(
                input_ids, max_length=1000, do_sample=True, top_p=0.9, temperature=0.8,
                num_beams=2, pad_token_id=eos_id)
            response = tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
            if response.strip() != '':
                break

            context = input_ids[0].tolist()
            if eos_id in context:
                input_ids = input_ids[:, (context.index(eos_id) + 1):]
            else:
                input_ids = message_ids

        if self.run_on_own_server:
            print(input_ids)

        # keep only the newly generated tokens (everything after the prompt)
        self.history['input_ids'] = torch.cat(
            [self.history['input_ids'], bot_output_ids[0:1, input_ids.shape[-1]:]], dim=-1)
        if self.run_on_own_server == 1:
            print((message, response), '\n', self.history['input_ids'])

        return response


if __name__ == '__main__':
    def chat(message, bot):
        """Gradio callback: ``bot`` is the per-session state, created on first use."""
        bot = bot or TherapyChatBot(args)
        bot(message)
        return bot.history['text'], bot

    title = 'PsyPlus Empathetic Chatbot'
    description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
    greeting = [('', ChatHelper.greeting_template[args.initial_chat_state])]
    chatbot = gr.Chatbot(value=greeting)
    iface = gr.Interface(
        chat,
        ['text', 'state'],
        [chatbot, 'state'],
        allow_flagging='never',
        title=title,
        description=description,
    )
    if args.run_on_own_server == 0:
        iface.launch(debug=True)
    else:
        iface.launch(debug=True, share=True)