from operator import itemgetter
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableLambda
from langchain.schema.runnable.config import RunnableConfig
from langchain.memory import ConversationBufferMemory
from resolution_logic import ResolutionLogic
from literal_thread_manager import LiteralThreadManager
from prompt_engineering.prompt_desing import system_prompt, system_prompt_b, system_prompt_questioning
import chainlit as cl
from chainlit.types import ThreadDict
import os
from dotenv import load_dotenv

# Load the environment variables from the .env file
load_dotenv()

jwt_secret_key = os.getenv('CHAINLIT_AUTH_SECRET')
if not jwt_secret_key:
    raise ValueError(
        "You must provide a JWT secret in the environment to use authentication.")

# Get the OPENAI_API_KEY from the environment variables
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError(
        "You must provide an OPENAI_API_KEY in the environment to use the chat model.")

# Re-export the key so downstream libraries (LangChain / OpenAI client) can pick it up
os.environ["OPENAI_API_KEY"] = openai_api_key

# LiteralThreadManager wraps the Literal AI client used to persist users, threads and steps
manager = LiteralThreadManager(api_key=os.getenv("LITERAL_API_KEY"))

def setup_runnable():
    """
    Sets up the runnable pipeline for the chatbot. This pipeline includes a model for generating responses
    and memory management for maintaining conversation context.
    """
    memory = cl.user_session.get("memory")  # type: ConversationBufferMemory
    model = ChatOpenAI(streaming=True, model="gpt-3.5-turbo")
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt_questioning),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

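    # LCEL composition: RunnablePassthrough.assign injects the stored chat history
    # (loaded from ConversationBufferMemory) under the "history" key alongside the
    # incoming "question"; the prompt renders both into chat messages, the model
    # generates a streamed completion, and StrOutputParser reduces it to text.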
    runnable = (
        RunnablePassthrough.assign(
            history=RunnableLambda(
                memory.load_memory_variables) | itemgetter("history")
        )
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)

@cl.password_auth_callback
def auth_callback(username: str, password: str):
    """
    Authenticates a user using the provided username and password. If the user does not exist in the
    LiteralAI database, a new user is created.

    Args:
        username (str): The username provided by the user.
        password (str): The password provided by the user.

    Returns:
        cl.User | None: A User object (with a "user" or "admin" role) if authentication succeeds, None otherwise.
    """
    auth_user = manager.literal_client.api.get_or_create_user(identifier=username)
    if not auth_user:
        return None
    role = "admin" if username == "admin" else "user"
    return cl.User(
        identifier=username, metadata={"role": role, "provider": "credentials"}
    )

def create_and_update_threads(first_res, current_user, partner_user):
    """
    Creates and updates threads for the conversation between the current user and their partner.

    Args:
        first_res (str): The initial response from the user.
        current_user (cl.User): The current user initiating the conversation.
        partner_user (cl.User): The partner user to connect with.
    """
    latest_thread = manager.literal_client.api.get_threads(first=1)
    partner_thread = manager.literal_client.api.create_thread(name=first_res['output'], participant_id=partner_user.id, metadata={
        "partner_id": current_user.id, "partner_thread_id": latest_thread.data[0].id, "user_id": partner_user.id})
    resolver = ResolutionLogic()
    message_to_other_partner = resolver.summarize_conflict_topic(partner_user.identifier, current_user.identifier, first_res['output'])
    manager.literal_client.api.create_step(thread_id=partner_thread.id, type="assistant_message",
                                           output={'content': message_to_other_partner})
    current_thread = manager.literal_client.api.upsert_thread(id=latest_thread.data[0].id,
                                                              participant_id=current_user.id, metadata={"partner_id": partner_user.id, "partner_thread_id": partner_thread.id})
    cl.user_session.set("thread_id", current_thread.id)
    manager.get_other_partner_thread_id(current_thread.id)

@cl.action_callback("2-1 Chat")
async def on_action(action):
    """
    Handles the action callback for initiating a 2-1 chat.

    Args:
        action (cl.Action): The action object containing the user's input.
    """
    await cl.Message(content="Write the email and the chat id:").send()
    # Remove the action button from the UI once it has been handled.
    await action.remove()

@cl.on_chat_start
async def on_chat_start():
    """
    Handles the start of a chat session. Initializes the memory, sets up the runnable pipeline, and prompts the user
    to summarize the type of conflict.
    """
    cl.user_session.set("memory", ConversationBufferMemory(return_messages=True))
    setup_runnable()
    first_res = await cl.AskUserMessage(content="Welcome to the Relationship Coach chatbot. I can help you with your relationship questions. Please first summarize the type of conflict.").send()
    add_person = await cl.AskActionMessage(
        content="Select the conversation type.",
        actions=[
            cl.Action(name="1-1 Chat", value="1-1 Chat", label="πŸ‘€ 1-1"),
            cl.Action(name="2-1 Chat", value="2-1 Chat", label="πŸ‘₯ 2-1"),
        ],
    ).send()

    if add_person and add_person.get("value") == "2-1 Chat":
        res = await cl.AskUserMessage(content="Please write the username of the person to connect with.").send()
        if res:
            # Keep asking for the partner's username until it exists in the database.
            while manager.literal_client.api.get_user(identifier=res["output"]) is None:
                await cl.Message(content=f"Partner {res['output']} does not exist in the database.").send()
                res = await cl.AskUserMessage(content="Please write the username of the person to connect with.").send()
            partner_username = res['output']
            partner_user = manager.literal_client.api.get_user(identifier=partner_username)
            current_user = cl.user_session.get("user")
            current_username = current_user.identifier
            manager.literal_client.api.update_user(id=current_user.id, identifier=current_username, metadata={
                "role": "user", "provider": "credentials", "relationships": {"partner_username": partner_username}})
            await cl.Message(content=f"Connected with {partner_username}!").send()
            await on_message(cl.Message(content=first_res['output']))
            create_and_update_threads(first_res, current_user, partner_user)
        else:
            await cl.Message(
                content="Action timed out!",
            ).send()

@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    """
    Handles the resumption of a chat session. Restores the chat memory and sets up the runnable pipeline.

    Args:
        thread (ThreadDict): The thread dictionary containing the chat history.
    """
    memory = ConversationBufferMemory(return_messages=True)
    root_messages = [m for m in thread["steps"] if m["parentId"] is None]
    for message in root_messages:
        if message["type"] == "user_message":
            memory.chat_memory.add_user_message(message["output"])
        else:
            memory.chat_memory.add_ai_message(message["output"])

    cl.user_session.set("memory", memory)
    cl.user_session.set("thread_id", thread["id"])

    setup_runnable()

    conflict_resolution = ResolutionLogic()
    resolution = conflict_resolution.intervention(thread["id"])

    if resolution:
        await cl.Message(content=resolution).send()

@cl.on_message
async def on_message(message: cl.Message):
    """
    Handles incoming messages during a chat session. Updates the memory and generates a response.

    Args:
        message (cl.Message): The incoming message from the user.
    """
    memory = cl.user_session.get("memory")  # type: ConversationBufferMemory
    runnable = cl.user_session.get("runnable")  # type: Runnable

    response = cl.Message(content="")

    # If this session is linked to a persisted thread, check whether the resolution
    # logic already has an intervention ready; if so, send it instead of querying the model.
    resolution = None
    thread_id = cl.user_session.get("thread_id")
    if thread_id:
        resolution = ResolutionLogic().intervention(thread_id)

    if resolution:
        response = cl.Message(content=resolution)
    else:
        # Otherwise stream the model's answer token by token into the Chainlit message.
        async for chunk in runnable.astream(
            {"question": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await response.stream_token(chunk)

    await response.send()

    memory.chat_memory.add_user_message(message.content)
    memory.chat_memory.add_ai_message(response.content)

# This module is not meant to be executed directly: Chainlit owns the event loop and
# invokes the decorated handlers (on_chat_start, on_chat_resume, on_message, ...) itself.
# Launch the app with the Chainlit CLI, e.g. `chainlit run <this_file>.py`.