File size: 2,949 Bytes
121a1b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import openai
openai.api_key_path = './openai_api_key.txt'
import streamlit as st
from streamlit_chat import message


# Completion resource from the legacy (openai<1.0) SDK; `ask` calls its
# `create` method to query the selected model.
completion = openai.Completion()


# Prompt markers. Each of Joy's turns starts with `start_sequence`; it is also
# placed after the user's text so the model continues speaking as Joy.
start_sequence = "\nJoy:"
# Each user turn starts with `restart_sequence`.
restart_sequence = "\n\nYou:"

# Instruction that seeds the conversation log sent to the model.
# NOTE(review): the text contains typos ("compasionate", "advices",
# "once a while", "wishes"). It is kept verbatim because any change alters the
# prompt (possibly the one a fine-tuned model was trained on) — confirm before
# editing.
start_prompt = (
    '[Instruction] Act as a friendly, compasionate, insightful, and empathetic '
    'AI therapist named Joy. Joy listens, asks for details and offers detailed '
    'advices once a while.  End the conversation when you wishes to.'
)
# Joy's opening line: the first generated message and the end of the seed log.
start_message = 'I am Joy, your AI therapist. How are you feeling today?'

# to do: 
# let the user choose between models (curie, davinci, curie-finetuned, davinci-finetuned)
# let the user choose between different temperatures, frequency_penalty, presence_penalty

# save the user's input and the model's output to the database
# analyze the user's input and the model's output 
# sentiment/mood analysis / topic analysis of the user's input 

# embed the user's input and look for therapy catalogue that is similar to the user's input
# push the therapy catalogue to the user

  
def ask(
    question: str,
    chat_log: "str | list[str]",
    *,
    model_name=None,
    temperature=None,
) -> "tuple[str, str]":
    """Send one user turn to the completion model and return Joy's reply.

    Args:
        question: The user's message for this turn.
        chat_log: Conversation so far, either as one string or as a sequence
            of string segments (the app keeps it as a list in
            ``st.session_state['chat_log']``). Sequences are concatenated in
            order before being sent.
        model_name: Model to query. Defaults to the module-level ``model``
            picked in the "Model" selectbox.
        temperature: Sampling temperature. Defaults to the module-level
            ``temp`` picked with the "Creativity" slider.

    Returns:
        ``(answer, log)``: ``answer`` is the model's reply with surrounding
        whitespace stripped; ``log`` is this turn (user line plus Joy's reply)
        in exactly the form it was shown to the model, ready to be appended to
        the chat log.

    Raises:
        Any exception raised by ``completion.create`` propagates to the caller.
    """
    if model_name is None:
        model_name = model
    if temperature is None:
        temperature = temp

    # Formatting a list directly would embed its repr (brackets, quotes,
    # escaped newlines) into the prompt, so join the segments first.
    if not isinstance(chat_log, str):
        chat_log = ''.join(chat_log)

    # The new turn as presented to the model; it ends with Joy's marker so the
    # completion is Joy's reply.
    turn = f'{restart_sequence} {question}{start_sequence}'

    response = completion.create(
        prompt=chat_log + turn,
        model=model_name,
        stop=["You:", 'Joy:'],  # stop before the model writes another turn
        temperature=temperature,  # the higher the more creative
        frequency_penalty=0.3,  # discourages word repetition; larger -> stronger
        presence_penalty=0.6,  # discourages topic repetition; larger -> stronger
        top_p=1,
        best_of=1,
        max_tokens=170,
    )

    answer = response.choices[0].text.strip()
    # Store the turn in the same format the model saw, so later prompts replay
    # the history consistently.
    return answer, turn + answer


# TODO: add a button for starting a new conversation.

st.title("Chat with Joy - the AI therapist!")
temp = st.slider("Creativity", 0.0, 1.0, 0.7, 0.1)
model = st.selectbox("Model", ["text-davinci-003", "text-curie-001", "curie:ft-personal-2023-02-03-17-06-53"])

# Streamlit re-executes this script on every interaction, so conversation
# state lives in session_state:
#   'generated' - Joy's messages, starting with the greeting
#   'past'      - the user's messages
#   'chat_log'  - prompt text segments sent to the model
if 'generated' not in st.session_state:
    st.session_state['generated'] = [start_message]

if 'past' not in st.session_state:
    st.session_state['past'] = []

if 'chat_log' not in st.session_state:
    st.session_state['chat_log'] = [start_prompt + start_sequence + start_message]

# A form reports a submission only on the rerun triggered by submitting it.
# A bare text_input keeps its value, so any later rerun (e.g. moving the
# Creativity slider) would re-send the same message and duplicate history.
# clear_on_submit empties the box after each send.
with st.form('chat_form', clear_on_submit=True):
    user_input = st.text_input("You:", key='input')
    submitted = st.form_submit_button("Send")

if submitted and user_input:
    try:
        # ask() expects the log as a single string, not the list of segments.
        output, chat_log = ask(user_input, ''.join(st.session_state['chat_log']))
    except openai.error.OpenAIError as err:
        # Leave the history untouched so it stays consistent.
        st.error(f"Joy could not respond right now ({type(err).__name__}). Please try again.")
    else:
        st.session_state['chat_log'].append(chat_log)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
    # The conversation is deliberately not printed: it is sensitive user
    # content and would end up in server logs.

# Render newest first. For each index, show the user's message (if any),
# then Joy's message at that index. Top to bottom, this yields Joy's latest
# reply, the user message that prompted it, the previous reply, and so on
# down to the greeting.
for i in reversed(range(len(st.session_state['generated']))):
    if i < len(st.session_state['past']):
        message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
    message(st.session_state['generated'][i], key=str(i))

# TODO: save the user's input and the model's output to the database and analyze the user's input and the model's output