ray commited on
Commit
dfc6dc5
·
1 Parent(s): 28f4c9d

initial commit

Browse files
Files changed (7) hide show
  1. .gitignore +3 -0
  2. app.py +143 -0
  3. chat_template.py +32 -0
  4. chatbot.py +151 -0
  5. custom_io.py +73 -0
  6. qdrant.py +5 -0
  7. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ **/__pycache__
3
+ awesumcare_data
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+ import gradio as gr
4
+ import openai
5
+ import os
6
+ from dotenv import load_dotenv
7
+ import phoenix as px
8
+ import llama_index
9
+ from llama_index import OpenAIEmbedding, Prompt, ServiceContext, VectorStoreIndex, SimpleDirectoryReader
10
+ from llama_index.chat_engine.types import ChatMode
11
+ from llama_index.llms import ChatMessage, MessageRole, OpenAI
12
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
13
+ from llama_index.text_splitter import SentenceSplitter
14
+ from llama_index.extractors import TitleExtractor
15
+ from llama_index.ingestion import IngestionPipeline
16
+ from chat_template import CHAT_TEXT_QA_PROMPT
17
+ from chatbot import Chatbot, ChatbotVersion
18
+ from custom_io import UnstructuredReader, default_file_metadata_func
19
+ from qdrant import client as qdrantClient
20
+
21
+ load_dotenv()
22
+ openai.api_key = os.getenv("OPENAI_API_KEY")
23
+
24
+
25
+ class AwesumCareChatbot(Chatbot):
26
+ DENIED_ANSWER_PROMPT = ""
27
+ SYSTEM_PROMPT = ""
28
+ CHAT_EXAMPLES = [
29
+ "什麼是安心三寶?",
30
+ "點樣立平安紙?"
31
+ ]
32
+
33
+ def _load_doucments(self):
34
+ dir_reader = SimpleDirectoryReader('./awesumcare_data', file_extractor={
35
+ ".pdf": UnstructuredReader(),
36
+ ".docx": UnstructuredReader(),
37
+ ".pptx": UnstructuredReader(),
38
+ },
39
+ recursive=True,
40
+ exclude=["*.png", "*.pptx"],
41
+ file_metadata=default_file_metadata_func)
42
+
43
+ self.documents = dir_reader.load_data()
44
+ super()._load_doucments()
45
+
46
+ def _setup_service_context(self):
47
+ self.service_context = ServiceContext.from_defaults(
48
+ chunk_size=self.chunk_size,
49
+ llm=self.llm,
50
+ embed_model=self.embed_model
51
+ )
52
+ super()._setup_service_context()
53
+
54
+ def _setup_vector_store(self):
55
+ self.vector_store = QdrantVectorStore(
56
+ client=qdrantClient, collection_name=self.vdb_collection_name)
57
+ super()._setup_vector_store()
58
+
59
+ def _setup_index(self):
60
+ if self.vdb_collection_name in [col.name for col in qdrantClient.get_collections().collections] and qdrantClient.get_collection(self.vdb_collection_name).vectors_count > 0:
61
+ self.index = VectorStoreIndex.from_vector_store(
62
+ self.vector_store, service_context=self.service_context)
63
+ print("set up index from vector store")
64
+ return
65
+ pipeline = IngestionPipeline(
66
+ transformations=[
67
+ SentenceSplitter(),
68
+ OpenAIEmbedding(),
69
+ ],
70
+ vector_store=self.vector_store,
71
+ )
72
+ pipeline.run(documents=self.documents)
73
+ self.index = VectorStoreIndex.from_vector_store(
74
+ self.vector_store, service_context=self.service_context)
75
+ super()._setup_index()
76
+
77
+ # def _setup_index(self):
78
+ # self.index = VectorStoreIndex.from_documents(
79
+ # self.documents,
80
+ # service_context=self.service_context
81
+ # )
82
+ # super()._setup_index()
83
+
84
+ def _setup_chat_engine(self):
85
+ # testing #
86
+ from llama_index.agent import OpenAIAgent
87
+ from llama_index.tools.query_engine import QueryEngineTool
88
+
89
+ query_engine = self.index.as_query_engine(
90
+ text_qa_template=CHAT_TEXT_QA_PROMPT)
91
+ query_engine_tool = QueryEngineTool.from_defaults(
92
+ query_engine=query_engine)
93
+ self.chat_engine = OpenAIAgent.from_tools(
94
+ tools=[query_engine_tool],
95
+ llm=self.service_context.llm,
96
+ similarity_top_k=1,
97
+ verbose=True
98
+ )
99
+ print("set up agent as chat engine")
100
+ # testing #
101
+ # self.chat_engine = self.index.as_chat_engine(
102
+ # chat_mode=ChatMode.BEST,
103
+ # similarity_top_k=5,
104
+ # text_qa_template=CHAT_TEXT_QA_PROMPT)
105
+ super()._setup_chat_engine()
106
+
107
+
108
+ # gpt-3.5-turbo-1106, gpt-4-1106-preview
109
+ awesum_chatbot = AwesumCareChatbot(ChatbotVersion.CHATGPT_35.value,
110
+ chunk_size=2048,
111
+ vdb_collection_name="v2")
112
+
113
+
114
+ def vote(data: gr.LikeData):
115
+ if data.liked:
116
+ gr.Info("You up-voted this response: " + data.value)
117
+ else:
118
+ gr.Info("You down-voted this response: " + data.value)
119
+
120
+
121
+ chatbot = gr.Chatbot()
122
+
123
+ with gr.Blocks() as demo:
124
+ gr.Markdown("# Awesum Care demo")
125
+
126
+ with gr.Tab("With awesum care data prepared"):
127
+ gr.ChatInterface(
128
+ awesum_chatbot.stream_chat,
129
+ chatbot=chatbot,
130
+ examples=awesum_chatbot.CHAT_EXAMPLES,
131
+ )
132
+ chatbot.like(vote, None, None)
133
+
134
+ with gr.Tab("With Initial System Prompt (a.k.a. prompt wrapper)"):
135
+ gr.ChatInterface(
136
+ awesum_chatbot.predict_with_prompt_wrapper, examples=awesum_chatbot.CHAT_EXAMPLES)
137
+
138
+ with gr.Tab("Vanilla ChatGPT without modification"):
139
+ gr.ChatInterface(awesum_chatbot.predict_vanilla_chatgpt,
140
+ examples=awesum_chatbot.CHAT_EXAMPLES)
141
+
142
+ demo.queue()
143
+ demo.launch()
chat_template.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llama_index.llms.base import ChatMessage, MessageRole
2
+ from llama_index.prompts.base import ChatPromptTemplate
3
+
4
+ # text qa prompt
5
+ TEXT_QA_SYSTEM_PROMPT = ChatMessage(
6
+ content=(
7
+ "You are '安心三寶', a specialized chatbot for elderly users, trusted for providing "
8
+ "detailed information on legal and medical documents like '平安紙', '持久授權書', and '預設醫療指示'.\n"
9
+ "Always answer queries using the context information provided, focusing on delivering "
10
+ "accurate, comprehensive, and user-friendly responses.\n"
11
+ ),
12
+ role=MessageRole.SYSTEM,
13
+ )
14
+
15
+ TEXT_QA_PROMPT_TMPL_MSGS = [
16
+ TEXT_QA_SYSTEM_PROMPT,
17
+ ChatMessage(
18
+ content=(
19
+ "Context information is below.\n"
20
+ "---------------------\n"
21
+ "{context_str}\n"
22
+ "---------------------\n"
23
+ "Given the context information and not prior knowledge, "
24
+ "answer the query in a warm, approachable manner, ensuring clarity and precision.\n"
25
+ "Query: {query_str}\n"
26
+ "Answer: "
27
+ ),
28
+ role=MessageRole.USER,
29
+ ),
30
+ ]
31
+
32
+ CHAT_TEXT_QA_PROMPT = ChatPromptTemplate(message_templates=TEXT_QA_PROMPT_TMPL_MSGS)
chatbot.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import List
3
+ import os
4
+ import re
5
+ from typing import List
6
+ from dotenv import load_dotenv
7
+ from openai import OpenAI
8
+ import phoenix as px
9
+ import llama_index
10
+ from llama_index import OpenAIEmbedding
11
+ from llama_index.llms import ChatMessage, MessageRole, OpenAI
12
+
13
+ load_dotenv()
14
+
15
+
16
+ class Chatbot:
17
+ SYSTEM_PROMPT = ""
18
+ DENIED_ANSWER_PROMPT = ""
19
+ CHAT_EXAMPLES = []
20
+
21
+ def __init__(self, model_name, chunk_size, vdb_collection_name="test_store"):
22
+ self.model_name = model_name
23
+ self.llm = OpenAI(model=self.model_name)
24
+ self.embed_model = OpenAIEmbedding()
25
+ self.chunk_size = chunk_size
26
+
27
+ self.documents = None
28
+ self.index = None
29
+ self.chat_engine = None
30
+ self.service_context = None
31
+ self.vector_store = None
32
+ self.vdb_collection_name = vdb_collection_name
33
+
34
+ self._setup_chatbot()
35
+
36
+ def _setup_chatbot(self):
37
+ self._setup_observer()
38
+ self._setup_service_context()
39
+ self._setup_vector_store()
40
+ self._load_doucments()
41
+ self._setup_index()
42
+ self._setup_chat_engine()
43
+
44
+ def _setup_observer(self):
45
+ px.launch_app()
46
+ llama_index.set_global_handler("arize_phoenix")
47
+
48
+ def _load_doucments(self):
49
+ pass
50
+ print(f"Loaded {len(self.documents)} docs")
51
+
52
+ def _setup_service_context(self):
53
+ pass
54
+ print("Setup service context...")
55
+
56
+ def _setup_vector_store(self):
57
+ pass
58
+ print("Setup vector store...")
59
+
60
+ def _setup_index(self):
61
+ if self.documents is None:
62
+ raise ValueError("No documents loaded")
63
+ pass
64
+ print("Built index...")
65
+
66
+ def _setup_chat_engine(self):
67
+ if self.index is None:
68
+ raise ValueError("No index built")
69
+ pass
70
+ print("Setup chat engine...")
71
+
72
+ def stream_chat(self, message, history):
73
+ print(history)
74
+ print(self.convert_to_chat_messages(history))
75
+ response = self.chat_engine.stream_chat(
76
+ message, chat_history=self.convert_to_chat_messages(history)
77
+ )
78
+ # Stream tokens as they are generated
79
+ partial_message = ""
80
+ for token in response.response_gen:
81
+ partial_message += token
82
+ yield partial_message
83
+
84
+ urls = [source.node.metadata.get(
85
+ "file_name") for source in response.source_nodes if source.score >= 0.78 and source.node.metadata.get("file_name")]
86
+ if urls:
87
+ urls = list(set(urls))
88
+ url_section = "\n \n\n---\n\n參考: \n" + \
89
+ "\n".join(f"- {url}" for url in urls)
90
+ partial_message += url_section
91
+ yield partial_message
92
+
93
+ def convert_to_chat_messages(self, history: List[List[str]]) -> List[ChatMessage]:
94
+ chat_messages = [ChatMessage(
95
+ role=MessageRole.SYSTEM, content=self.SYSTEM_PROMPT)]
96
+ for conversation in history[-3:]:
97
+ for index, message in enumerate(conversation):
98
+ role = MessageRole.USER if index % 2 == 0 else MessageRole.ASSISTANT
99
+ clean_message = re.sub(
100
+ r"\n \n\n---\n\n參考: \n.*$", "", message, flags=re.DOTALL)
101
+ chat_messages.append(ChatMessage(
102
+ role=role, content=clean_message.strip()))
103
+ return chat_messages
104
+
105
+ def predict_with_rag(self, message, history):
106
+ return self.stream_chat(message, history)
107
+
108
+ # barebone chatgpt methods, shared across all chatbot instance
109
+ def _invoke_chatgpt(self, history, message, is_include_system_prompt=False):
110
+ openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
111
+ history_openai_format = []
112
+ if is_include_system_prompt:
113
+ history_openai_format.append(
114
+ {"role": "system", "content": self.SYSTEM_PROMPT})
115
+ for human, assistant in history:
116
+ history_openai_format.append({"role": "user", "content": human})
117
+ history_openai_format.append(
118
+ {"role": "assistant", "content": assistant})
119
+ history_openai_format.append({"role": "user", "content": message})
120
+
121
+ import openai
122
+ print(openai.__version__)
123
+ stream = openai_client.chat.completions.create(
124
+ model=self.model_name,
125
+ messages=history_openai_format,
126
+ temperature=1.0,
127
+ stream=True)
128
+ for part in stream:
129
+ yield part.choices[0].delta.content or ""
130
+ # partial_message = ""
131
+ # for chunk in response:
132
+ # if len(chunk["choices"][0]["delta"]) != 0:
133
+ # partial_message = partial_message + \
134
+ # chunk["choices"][0]["delta"]["content"]
135
+ # yield partial_message
136
+
137
+ # For 'With Prompt Wrapper' - Add system prompt, no Pinecone
138
+ def predict_with_prompt_wrapper(self, message, history):
139
+ yield from self._invoke_chatgpt(history, message, is_include_system_prompt=True)
140
+
141
+ # For 'Vanilla ChatGPT' - No system prompt
142
+ def predict_vanilla_chatgpt(self, message, history):
143
+ yield from self._invoke_chatgpt(history, message)
144
+
145
+
146
+ # make a enum of chatbot type and string
147
+
148
+
149
+ class ChatbotVersion(str, Enum):
150
+ CHATGPT_35 = "gpt-3.5-turbo-1106"
151
+ CHATGPT_4 = "gpt-4-1106-preview"
custom_io.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unstructured file reader.
2
+
3
+ A parser for unstructured text files using Unstructured.io.
4
+ Supports .txt, .docx, .pptx, .jpg, .png, .eml, .html, and .pdf documents.
5
+
6
+ """
7
+ from datetime import datetime
8
+ import mimetypes
9
+ import os
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ from llama_index.readers.base import BaseReader
14
+ from llama_index.readers.schema.base import Document
15
+
16
+
17
+ class UnstructuredReader(BaseReader):
18
+ """General unstructured text reader for a variety of files."""
19
+
20
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
21
+ """Init params."""
22
+ super().__init__(*args, **kwargs)
23
+
24
+ # Prerequisite for Unstructured.io to work
25
+ import nltk
26
+
27
+ nltk.download("punkt")
28
+ nltk.download("averaged_perceptron_tagger")
29
+
30
+ def load_data(
31
+ self,
32
+ file: Path,
33
+ extra_info: Optional[Dict] = None,
34
+ split_documents: Optional[bool] = True,
35
+ ) -> List[Document]:
36
+ """Parse file."""
37
+ from unstructured.partition.auto import partition
38
+
39
+ elements = partition(str(file))
40
+ text_chunks = [" ".join(str(el).split()) for el in elements]
41
+
42
+ if split_documents:
43
+ return [
44
+ Document(text=chunk, extra_info=extra_info or {})
45
+ for chunk in text_chunks
46
+ ]
47
+ else:
48
+ return [
49
+ Document(text="\n\n".join(text_chunks), extra_info=extra_info or {})
50
+ ]
51
+
52
+
53
+ def default_file_metadata_func(file_path: str) -> Dict:
54
+ """Get some handy metadate from filesystem.
55
+
56
+ Args:
57
+ file_path: str: file path in str
58
+ """
59
+ return {
60
+ "file_path": file_path,
61
+ "file_name": os.path.basename(file_path),
62
+ "file_type": mimetypes.guess_type(file_path)[0],
63
+ "file_size": os.path.getsize(file_path),
64
+ "creation_date": datetime.fromtimestamp(
65
+ Path(file_path).stat().st_ctime
66
+ ).strftime("%Y-%m-%d"),
67
+ "last_modified_date": datetime.fromtimestamp(
68
+ Path(file_path).stat().st_mtime
69
+ ).strftime("%Y-%m-%d"),
70
+ "last_accessed_date": datetime.fromtimestamp(
71
+ Path(file_path).stat().st_atime
72
+ ).strftime("%Y-%m-%d"),
73
+ }
qdrant.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ import qdrant_client
3
+
4
+
5
+ client = qdrant_client.QdrantClient(path="/tmp/total_qdrant/")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy
2
+ openai
3
+ llama_index
4
+ arize-phoenix[experimental]
5
+ pypdf
6
+ gradio
7
+ # unstructure io