|
import openai |
|
openai.api_key_path = './openai_api_key.txt' |
|
import streamlit as st |
|
from streamlit_chat import message |
|
|
|
|
|
# Shared handle for the legacy Completions endpoint; ask() calls .create on it.
completion = openai.Completion()

# Instruction text that opens every transcript.
# NOTE(review): this wording (typos included) is sent verbatim to the model.
# If the fine-tuned model option was trained on this exact text, editing it
# may change that model's behaviour -- confirm before "fixing" the spelling.
start_prompt = (
    '[Instruction] Act as a friendly, compasionate, insightful, and '
    'empathetic AI therapist named Joy. Joy listens, asks for details and '
    'offers detailed advices once a while. End the conversation when you '
    'wishes to.'
)

# Greeting shown as Joy's first message and appended to the initial transcript.
start_message = 'I am Joy, your AI therapist. How are you feeling today?'

# Speaker tags: Joy's cue follows each user line; the user's cue opens a new turn.
start_sequence = '\nJoy:'
restart_sequence = '\n\nYou:'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ask(
    question: str,
    chat_log: "str | list[str]",
    *,
    model_name: "str | None" = None,
    temperature: "float | None" = None,
) -> "tuple[str, str]":
    """Send one user turn to the Completions API and return Joy's reply.

    The prompt is the running transcript followed by a ``You: <question>``
    line and a ``Joy:`` cue. Generation stops at the next ``You:``/``Joy:``
    tag, so only Joy's reply comes back.

    Args:
        question: The user's message for this turn.
        chat_log: Transcript so far, either as one string or as a sequence of
            string fragments (the Streamlit session stores a list). Fragments
            are concatenated in order.
        model_name: Model to query. Defaults to the module-level ``model``.
        temperature: Sampling temperature. Defaults to the module-level
            ``temp``.

    Returns:
        ``(answer, log)``. ``answer`` is Joy's reply with surrounding
        whitespace stripped. ``log`` is this turn formatted the same way it
        was sent, ready to append to the transcript.
    """
    if not isinstance(chat_log, str):
        # Interpolating a list directly would embed its repr (brackets,
        # quotes, escaped newlines) into the prompt.
        chat_log = ''.join(chat_log)
    if model_name is None:
        model_name = model
    if temperature is None:
        temperature = temp

    turn = f'{restart_sequence} {question}{start_sequence}'
    response = completion.create(
        prompt=chat_log + turn,
        model=model_name,
        stop=['You:', 'Joy:'],
        temperature=temperature,
        frequency_penalty=0.3,
        presence_penalty=0.6,
        top_p=1,
        best_of=1,
        max_tokens=170,
    )

    answer = response.choices[0].text.strip()
    # Reuse the exact turn text that was sent (same spacing) so later prompts
    # see a consistently formatted transcript.
    log = f'{turn} {answer}'
    return answer, log
|
|
|
|
|
|
|
|
|
st.title("Chat with Joy - the AI therapist!")

temp = st.slider("Creativity", 0.0, 1.0, 0.7, 0.1)

model = st.selectbox("Model", ["text-davinci-003", "text-curie-001", "curie:ft-personal-2023-02-03-17-06-53"])

# Streamlit reruns this whole script on every interaction, so conversation
# state lives in st.session_state:
#   'generated' - Joy's messages, starting with the greeting;
#   'past'      - the user's messages (past[i] is answered by generated[i + 1]);
#   'chat_log'  - transcript fragments that, joined, form the model prompt.
if 'generated' not in st.session_state:
    st.session_state['generated'] = [start_message]

if 'past' not in st.session_state:
    st.session_state['past'] = []

if 'chat_log' not in st.session_state:
    st.session_state['chat_log'] = [start_prompt + start_sequence + start_message]


def _take_input():
    """Move a submitted message out of the text box so it is handled once."""
    st.session_state['pending_input'] = st.session_state['input']
    st.session_state['input'] = ''


# A text_input keeps its value across reruns. Reading it directly would
# re-send the last message whenever another widget (slider, model picker)
# triggers a rerun. The on_change callback hands each submission over once
# and clears the box.
st.text_input("You:", key='input', on_change=_take_input)
user_input = st.session_state.pop('pending_input', '')

if user_input:
    try:
        output, chat_log = ask(user_input, ''.join(st.session_state['chat_log']))
    except openai.error.OpenAIError as err:
        # Report API failures (rate limit, auth, ...) in the UI; state is left
        # untouched so the user can simply retry.
        st.error(f"Joy couldn't respond: {err}")
    else:
        st.session_state['chat_log'].append(chat_log)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)

# Newest first: each of Joy's replies is rendered above the user message that
# prompted it, with the greeting at the bottom.
for i in reversed(range(len(st.session_state['generated']))):
    if i < len(st.session_state['past']):
        message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
    message(st.session_state['generated'][i], key=str(i))
|
|
|
|
|
|
|
|
|
|