vinhnx90 commited on
Commit
24bec11
1 Parent(s): 375bd04

Extract chat profile to class

Browse files
Files changed (2) hide show
  1. app.py +8 -12
  2. chat_profile.py +26 -0
app.py CHANGED
@@ -2,8 +2,9 @@ import os
2
  import streamlit as st
3
 
4
  from token_stream_handler import StreamHandler
 
 
5
  from langchain.chains import ConversationalRetrievalChain
6
- from langchain.schema import ChatMessage
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
9
  from langchain_community.vectorstores.chroma import Chroma
@@ -67,22 +68,17 @@ def main():
67
 
68
  assistant_message = "Hello, you can upload a document and chat with me to ask questions related to its content."
69
  st.session_state["messages"] = [
70
- ChatMessage(role="assistant", content=assistant_message)
71
  ]
72
 
73
- st.chat_message("assistant").write(assistant_message)
74
 
75
  if prompt := st.chat_input(
76
  placeholder="Chat with your document",
77
  disabled=(not st.session_state.api_key),
78
  ):
79
- st.session_state.messages.append(
80
- ChatMessage(
81
- role="user",
82
- content=prompt,
83
- )
84
- )
85
- st.chat_message("user").write(prompt)
86
 
87
  handle_question(prompt)
88
 
@@ -108,7 +104,7 @@ def handle_question(question):
108
  for msg in st.session_state.messages:
109
  st.chat_message(msg.role).write(msg.content)
110
 
111
- with st.chat_message("assistant"):
112
  stream_handler = StreamHandler(st.empty())
113
  llm = ChatOpenAI(
114
  openai_api_key=st.session_state.api_key,
@@ -117,7 +113,7 @@ def handle_question(question):
117
  )
118
  response = llm.invoke(st.session_state.messages)
119
  st.session_state.messages.append(
120
- ChatMessage(role="assistant", content=response.content)
121
  )
122
 
123
 
 
2
  import streamlit as st
3
 
4
  from token_stream_handler import StreamHandler
5
+ from chat_profile import User, Assistant, ChatProfileRoleEnum
6
+
7
  from langchain.chains import ConversationalRetrievalChain
 
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
10
  from langchain_community.vectorstores.chroma import Chroma
 
68
 
69
  assistant_message = "Hello, you can upload a document and chat with me to ask questions related to its content."
70
  st.session_state["messages"] = [
71
+ Assistant(message=assistant_message).build_message()
72
  ]
73
 
74
+ st.chat_message(ChatProfileRoleEnum.Assistant).write(assistant_message)
75
 
76
  if prompt := st.chat_input(
77
  placeholder="Chat with your document",
78
  disabled=(not st.session_state.api_key),
79
  ):
80
+ st.session_state.messages.append(User(message=prompt).build_message())
81
+ st.chat_message(ChatProfileRoleEnum.User).write(prompt)
 
 
 
 
 
82
 
83
  handle_question(prompt)
84
 
 
104
  for msg in st.session_state.messages:
105
  st.chat_message(msg.role).write(msg.content)
106
 
107
+ with st.chat_message(ChatProfileRoleEnum.Assistant):
108
  stream_handler = StreamHandler(st.empty())
109
  llm = ChatOpenAI(
110
  openai_api_key=st.session_state.api_key,
 
113
  )
114
  response = llm.invoke(st.session_state.messages)
115
  st.session_state.messages.append(
116
+ Assistant(message=response.content).build_message()
117
  )
118
 
119
 
chat_profile.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.schema import ChatMessage
2
+ from enum import Enum
3
+
4
+
5
+ class ChatProfileRoleEnum(str, Enum):
6
+ User = "user"
7
+ Assistant = "assistant"
8
+
9
+
10
+ class ChatProfile:
11
+ def __init__(self, role: str, message: str):
12
+ self.role = role
13
+ self.message = message
14
+
15
+ def build_message(self) -> ChatMessage:
16
+ return ChatMessage(role=self.role, content=self.message)
17
+
18
+
19
+ class Assistant(ChatProfile):
20
+ def __init__(self, message: str):
21
+ super().__init__(ChatProfileRoleEnum.Assistant, message)
22
+
23
+
24
+ class User(ChatProfile):
25
+ def __init__(self, message: str):
26
+ super().__init__(ChatProfileRoleEnum.User, message)