File size: 9,253 Bytes
b8b10fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
from operator import itemgetter
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableLambda
from langchain.schema.runnable.config import RunnableConfig
from langchain.memory import ConversationBufferMemory
from resolution_logic import ResolutionLogic
from literal_thread_manager import LiteralThreadManager
from prompt_engineering.prompt_desing import system_prompt, system_prompt_b, system_prompt_questioning
import chainlit as cl
from chainlit.types import ThreadDict
import os
from dotenv import load_dotenv
# Load environment variables from a local .env file into os.environ.
load_dotenv()

# Chainlit's password authentication needs a JWT secret (CHAINLIT_AUTH_SECRET);
# refuse to start without it.
jwt_secret_key = os.getenv('CHAINLIT_AUTH_SECRET')
if not jwt_secret_key:
    raise ValueError(
        "You must provide a JWT secret in the environment to use authentication.")

# ChatOpenAI picks the API key up from the OPENAI_API_KEY environment variable.
# Fail fast with a clear message: assigning None into os.environ would otherwise
# raise an opaque TypeError here.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError(
        "You must provide OPENAI_API_KEY in the environment to use the chatbot.")
os.environ["OPENAI_API_KEY"] = openai_api_key

# Shared LiteralAI thread manager used by the handlers below.
manager = LiteralThreadManager(api_key=os.getenv("LITERAL_API_KEY"))
def setup_runnable():
    """
    Build the chat pipeline and store it in the user session under "runnable".

    The pipeline injects the message history held by the session's
    ConversationBufferMemory, renders the questioning system prompt followed by
    that history and the user's question, streams a gpt-3.5-turbo completion,
    and parses the result to a plain string.
    """
    session_memory = cl.user_session.get("memory")  # type: ConversationBufferMemory
    llm = ChatOpenAI(streaming=True, model="gpt-3.5-turbo")
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt_questioning),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    # load_memory_variables returns a dict; keep only its "history" entry.
    history_loader = RunnableLambda(session_memory.load_memory_variables) | itemgetter("history")
    with_history = RunnablePassthrough.assign(history=history_loader)
    pipeline = with_history | chat_prompt | llm | StrOutputParser()

    cl.user_session.set("runnable", pipeline)
@cl.password_auth_callback
def auth_callback(username: str, password: str):
    """
    Resolve a Chainlit login to a ``cl.User`` backed by a LiteralAI user record.

    The LiteralAI user identified by ``username`` is fetched, or created if it
    does not exist yet. The username ``"admin"`` gets the ``"admin"`` role and
    every other username gets the ``"user"`` role.

    Args:
        username (str): The username provided by the user.
        password (str): The password provided by the user. NOTE: it is
            currently not checked at all (see the security note below).

    Returns:
        cl.User | None: The logged-in user, or None (login rejected) when
        LiteralAI returns no user record.
    """
    # SECURITY(review): ``password`` is never verified, so anyone can log in as
    # any user -- including "admin" -- with an arbitrary password. Add a real
    # credential check (e.g. a stored password hash compared with
    # secrets.compare_digest) before exposing this app publicly.
    auth_user = manager.literal_client.api.get_or_create_user(identifier=username)
    if not auth_user:
        return None

    role = "admin" if username == "admin" else "user"
    return cl.User(
        identifier=username,
        metadata={"role": role, "provider": "credentials"},
    )
def create_and_update_threads(first_res, current_user, partner_user):
    """
    Create a LiteralAI thread for the partner and link it with the current thread.

    Steps:
    1. Take the most recent LiteralAI thread as the current user's thread.
    2. Create a thread for ``partner_user``, named after the conflict summary,
       whose metadata points back at the current user and their thread.
    3. Post the text returned by ``ResolutionLogic.summarize_conflict_topic``
       into the partner's thread as an assistant message.
    4. Upsert the current thread so its metadata points at the partner and the
       partner's thread, and store its id in the session under ``"thread_id"``.

    Args:
        first_res (dict): The reply returned by ``cl.AskUserMessage``; its
            ``"output"`` key holds the user's conflict summary.
        current_user (cl.User): The user initiating the conversation.
        partner_user: The partner's user record as returned by LiteralAI's
            ``get_user`` (its ``id`` and ``identifier`` are used).

    Raises:
        RuntimeError: If LiteralAI returns no thread to link with.
    """
    conflict_summary = first_res['output']

    # NOTE(review): this assumes the most recently created LiteralAI thread is
    # the current user's thread. With concurrent sessions another user's thread
    # could be newer; prefer the session's own thread id if one is available.
    latest_threads = manager.literal_client.api.get_threads(first=1)
    if not latest_threads.data:
        raise RuntimeError("No LiteralAI thread found to link with the partner thread.")
    current_thread_id = latest_threads.data[0].id

    partner_thread = manager.literal_client.api.create_thread(
        name=conflict_summary,
        participant_id=partner_user.id,
        metadata={
            "partner_id": current_user.id,
            "partner_thread_id": current_thread_id,
            "user_id": partner_user.id,
        },
    )

    resolver = ResolutionLogic()
    message_to_other_partner = resolver.summarize_conflict_topic(
        partner_user.identifier, current_user.identifier, conflict_summary)
    manager.literal_client.api.create_step(
        thread_id=partner_thread.id,
        type="assistant_message",
        output={'content': message_to_other_partner},
    )

    current_thread = manager.literal_client.api.upsert_thread(
        id=current_thread_id,
        participant_id=current_user.id,
        metadata={"partner_id": partner_user.id, "partner_thread_id": partner_thread.id},
    )
    cl.user_session.set("thread_id", current_thread.id)

    # NOTE(review): the return value is discarded; the call is kept in case it
    # has side effects inside LiteralThreadManager -- confirm whether it is needed.
    manager.get_other_partner_thread_id(current_thread.id)
@cl.action_callback("2-1 Chat")
async def on_action(action):
    """
    Handle a click on the "2-1 Chat" action button.

    Asks the user to write the email and the chat id, then removes the action
    button from the UI.

    Args:
        action (cl.Action): The clicked action.
    """
    await cl.Message(content="Write the email and the chat id:").send()
    # cl.Action is a dataclass, not a dict: its value is available as an
    # attribute if ever needed. Calling ``action.get(...)`` would raise
    # AttributeError and prevent the button from being removed.
    await action.remove()
@cl.on_chat_start
async def on_chat_start():
    """
    Start a chat session.

    Initializes an empty ConversationBufferMemory and the runnable pipeline,
    asks the user to summarize the conflict, then asks for the conversation
    type:

    * "1-1 Chat": the conflict summary is answered directly.
    * "2-1 Chat": the user is asked for the partner's username until an
      existing LiteralAI user is named. The partner is then recorded in the
      current user's metadata, the summary is answered, and linked threads are
      created for both partners.

    If any prompt goes unanswered (Chainlit returns None), "Action timed out!"
    is sent and the flow stops.
    """
    cl.user_session.set("memory", ConversationBufferMemory(return_messages=True))
    setup_runnable()

    first_res = await cl.AskUserMessage(content="Welcome to the Relationship Coach chatbot. I can help you with your relationship questions. Please first summarize the type of conflict.").send()
    if not first_res:
        await cl.Message(content="Action timed out!").send()
        return

    add_person = await cl.AskActionMessage(
        content="Select the conversation type.",
        actions=[
            cl.Action(name="1-1 Chat", value="1-1 Chat", label="👤 1-1"),
            cl.Action(name="2-1 Chat", value="2-1 Chat", label="👥 2-1"),
        ],
    ).send()
    if not add_person:
        await cl.Message(content="Action timed out!").send()
        return

    if add_person.get("value") != "2-1 Chat":
        # 1-1 chat: no partner to connect, just answer the conflict summary.
        await on_message(cl.Message(content=first_res['output']))
        return

    partner = await _ask_for_existing_partner()
    if partner is None:
        await cl.Message(content="Action timed out!").send()
        return
    partner_username, partner_user = partner

    current_user = cl.user_session.get("user")
    # Merge into the user's existing metadata instead of overwriting it, so
    # that e.g. an admin's role is not reset to "user".
    metadata = {
        "role": "user",
        "provider": "credentials",
        **(current_user.metadata or {}),
        "relationships": {"partner_username": partner_username},
    }
    manager.literal_client.api.update_user(
        id=current_user.id, identifier=current_user.identifier, metadata=metadata)

    await cl.Message(content=f"Connected with {partner_username}!").send()
    await on_message(cl.Message(content=first_res['output']))
    create_and_update_threads(first_res, current_user, partner_user)


async def _ask_for_existing_partner():
    """
    Ask for the partner's username until it names an existing LiteralAI user.

    Returns:
        tuple | None: ``(username, partner_user)`` where ``partner_user`` is the
        record returned by LiteralAI's ``get_user``, or None if a prompt
        goes unanswered.
    """
    question = "Please write the username of the person to connect with."
    res = await cl.AskUserMessage(content=question).send()
    while res:
        username = res['output']
        partner_user = manager.literal_client.api.get_user(identifier=username)
        if partner_user is not None:
            return username, partner_user
        await cl.Message(content=f"Partner {username} does not exist in db.").send()
        res = await cl.AskUserMessage(content=question).send()
    return None
@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    """
    Resume a persisted chat session.

    Replays the thread's root-level steps into a fresh ConversationBufferMemory.
    Steps of type "user_message" become user turns; every other root step
    becomes an AI turn. It then restores the memory, the thread id and the
    runnable pipeline in the user session. Finally, it posts the text returned
    by ``ResolutionLogic.intervention`` if that text is non-empty.

    Args:
        thread (ThreadDict): The persisted thread, including its steps.
    """
    memory = ConversationBufferMemory(return_messages=True)
    history = memory.chat_memory
    for step in thread["steps"]:
        if step["parentId"] is not None:
            continue  # only root-level steps are replayed into memory
        if step["type"] == "user_message":
            add_turn = history.add_user_message
        else:
            add_turn = history.add_ai_message
        add_turn(step["output"])

    cl.user_session.set("memory", memory)
    cl.user_session.set("thread_id", thread["id"])
    setup_runnable()

    resolution = ResolutionLogic().intervention(thread["id"])
    if resolution:
        await cl.Message(content=resolution).send()
@cl.on_message
async def on_message(message: cl.Message):
    """
    Answer an incoming chat message and record the exchange in memory.

    If the session has a thread id and ``ResolutionLogic.intervention`` returns
    non-empty text for it, that text is sent as the reply. Otherwise the
    runnable pipeline's answer is streamed token by token. In both cases the
    user's message and the reply are appended to the conversation memory.

    Args:
        message (cl.Message): The incoming message from the user.
    """
    memory = cl.user_session.get("memory")  # type: ConversationBufferMemory
    runnable = cl.user_session.get("runnable")  # type: Runnable
    reply = cl.Message(content="")
    resolver = ResolutionLogic()

    thread_id = cl.user_session.get("thread_id")
    intervention = resolver.intervention(thread_id) if thread_id else None

    if intervention:
        reply = cl.Message(content=intervention)
    else:
        token_stream = runnable.astream(
            {"question": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        )
        async for token in token_stream:
            await reply.stream_token(token)
    await reply.send()

    history = memory.chat_memory
    history.add_user_message(message.content)
    history.add_ai_message(reply.content)
def main():
    """
    Launch the Chainlit server for this app when the file is run with ``python``.

    Chainlit calls the ``on_chat_start`` hook itself for every new session, so
    it is not called here. Calling the async hook directly would only create a
    coroutine that is never awaited. This is equivalent to ``chainlit run <file>``.
    """
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)


if __name__ == "__main__":
    main()
|