Initialize code
- README.md +28 -7
- app.py +264 -0
- knowledge/__init__.py +0 -0
- knowledge/faiss_handler.py +203 -0
- knowledge/img_handler.py +44 -0
- llms/__init__.py +0 -0
- llms/chatbot.py +37 -0
- llms/embeddings.py +30 -0
- llms/tools.py +45 -0
- requirements.txt +0 -0
- utils.py +12 -0
README.md
CHANGED
@@ -1,12 +1,33 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk: gradio
-sdk_version: 3.
+title: Faiss Chat
+emoji: 🐠
+colorFrom: indigo
+colorTo: purple
+sdk: gradio
+sdk_version: 3.32.0
 app_file: app.py
 pinned: false
+license: mit
 ---
 
-
+# FAISS Chat: Chat with FAISS database
+
+A WebUI version of Langchain-Chat. Two features are currently supported:
+* Package local PDF and TXT files into a zip archive and upload it to build a FAISS vector database.
+* Upload an existing local FAISS vector database directly.
+
+## Changelog
+
+* 2023-06-06:
+  * Support reading chart data from images (currently JPG and PNG).
+
+* 2023-06-04:
+  * Support more file formats (currently PDF, TXT, MD, and TEX).
+  * Support more embedding models (currently [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), and [distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco)).
+  * Improve the local knowledge-base file structure.
+
+## Live demo
+[Huggingface Space](https://huggingface.co/spaces/shaocongma/faiss_chat)
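Taken together, the modules added below implement the following flow (a minimal sketch, not part of this commit; the file name documents.zip is illustrative, and an OpenAI API key is assumed to be configured):

from langchain.document_loaders import PyPDFLoader
from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip
from utils import make_archive

# Feature 1: build a FAISS knowledge base from a zip of local documents
db, name, meta = create_faiss_index_from_zip("documents.zip",
                                             embeddings="text-embedding-ada-002",
                                             pdf_loader=PyPDFLoader)
make_archive(name, name + ".zip")  # package the project folder, as app.py does

# Feature 2: load a packaged knowledge base and query it
db = load_faiss_index_from_zip(name + ".zip")
print(db.similarity_search("your question", k=3))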
app.py
ADDED
@@ -0,0 +1,264 @@
import json
import os
import time
import uuid
from datetime import datetime

import gradio as gr
import openai
from huggingface_hub import HfApi
from langchain.document_loaders import PyPDFLoader, \
    UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader

from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip
from knowledge.img_handler import process_image, add_markup
from llms.chatbot import OpenAIChatBot
from llms.embeddings import EMBEDDINGS_MAPPING
from utils import make_archive

UPLOAD_REPO_ID = os.getenv("UPLOAD_REPO_ID")
HF_TOKEN = os.getenv("HF_TOKEN")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")  # was `==`, a comparison that silently discarded the value
hf_api = HfApi(token=HF_TOKEN)

ALL_PDF_LOADERS = [PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader]
ALL_EMBEDDINGS = EMBEDDINGS_MAPPING.keys()
PDF_LOADER_MAPPING = {loader.__name__: loader for loader in ALL_PDF_LOADERS}


#######################################################################################################################
# Host multiple vector databases for use
#######################################################################################################################
# todo: add this feature in the future


INSTRUCTIONS = '''# FAISS Chat: Chat with your local database!

***Update 2023-06-06:***
1. Support reading chart data from images (currently JPG and PNG).
2. A test of this module is provided in the "Summarize a chart (Demo)" tab.

***Update 2023-06-04:***
1. Support more embedding models (currently [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), and [distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco)).
2. Support more file formats (PDF, TXT, TEX, and MD).
3. All generated databases are now available in [this dataset](https://huggingface.co/datasets/shaocongma/shared-faiss-vdb)! If you don't want your files uploaded, you can turn this off in the advanced settings.
'''


def load_zip_as_db(file_from_gradio,
                   pdf_loader,
                   embedding_model,
                   chunk_size=300,
                   chunk_overlap=20,
                   upload_to_cloud=True):
    if chunk_size <= chunk_overlap:
        # return four values to match the four outputs wired to `create_db.click`
        return "chunk_size must be larger than chunk_overlap. Creation failed.", None, None, None
    if file_from_gradio is None:
        return "The file is empty. Creation failed.", None, None, None
    pdf_loader = PDF_LOADER_MAPPING[pdf_loader]
    zip_file_path = file_from_gradio.name
    project_name = uuid.uuid4().hex
    db, project_name, db_meta = create_faiss_index_from_zip(zip_file_path, embeddings=embedding_model,
                                                            pdf_loader=pdf_loader, chunk_size=chunk_size,
                                                            chunk_overlap=chunk_overlap, project_name=project_name)
    index_name = project_name + ".zip"
    make_archive(project_name, index_name)
    date = datetime.today().strftime('%Y-%m-%d')
    if upload_to_cloud:
        hf_api.upload_file(path_or_fileobj=index_name,
                           path_in_repo=f"{date}/faiss_{index_name}",  # index_name already ends with .zip
                           repo_id=UPLOAD_REPO_ID,
                           repo_type="dataset")
    return "Successfully created the knowledge base. You can start chatting now!", index_name, db, db_meta


def load_local_db(file_from_gradio):
    if file_from_gradio is None:
        return "The file is empty. Loading failed.", None
    zip_file_path = file_from_gradio.name
    db = load_faiss_index_from_zip(zip_file_path)

    return "Successfully loaded the knowledge base. You can start chatting now!", db


def extract_image(image_path):
    from PIL import Image
    print("Image Path:", image_path)
    im = Image.open(image_path)
    table = process_image(im)
    print(f"Success in processing the image. Table: {table}")
    return table, add_markup(table)


def describe(image):
    table = add_markup(process_image(image))
    _INSTRUCTION = 'Read the table below to answer the following questions.'
    question = "Please refer to the above table, and write a summary of no less than 200 words based on it in Chinese, ensuring that your response is detailed and precise. "
    # `table` is already marked up above; the original applied add_markup twice here
    prompt_0shot = _INSTRUCTION + "\n" + table + "\n" + "Q: " + question + "\n" + "A:"

    messages = [{"role": "user", "content": prompt_0shot}]  # the prompt is user input, not an assistant turn
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.7,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    ret = response.choices[0].message['content']
    return ret


with gr.Blocks() as demo:
    local_db = gr.State(None)

    def get_augmented_message(message, local_db, query_count, preprocessing, meta):
        print(f"Receiving message: {message}")

        print("Detecting if the user needs to read an image from the local database...")
        # read the file list recorded in db_meta.json
        files = meta["files"]
        source_path = meta["source_path"]
        img_files = []
        for file in files:
            if os.path.splitext(file)[1] in [".png", ".jpg"]:
                img_files.append(file)

        # scan the user's input to see if it contains an image's name
        do_extract_image = False
        target_file = None
        for file in img_files:
            img = os.path.splitext(file)[0]
            if img in message:
                do_extract_image = True
                target_file = file
                break

        # extract the image into a table
        image_info = ""
        if do_extract_image:
            print("The user needs to read an image from the local database. Extracting image ... ")
            target_file = os.path.join(source_path, target_file)
            _, image_info = extract_image(target_file)
        if len(image_info) > 0:
            image_content = {"content": image_info, "source": os.path.basename(target_file)}
        else:
            image_content = None

        print("Querying references from the local database...")
        contents = []
        try:
            if query_count > 0:
                docs = local_db.similarity_search(message, k=query_count)
                for i in range(query_count):
                    # pre-process each chunk
                    content = docs[i].page_content.replace('\n', ' ')
                    contents.append(content)
        except Exception:
            print("Failed to query the local database. ")
        # generate the augmented message
        print("Success in querying references: {}".format(contents))
        if image_content is not None:
            augmented_message = f"{image_content}\n\n---\n\n" + "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
        else:
            augmented_message = "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
        return augmented_message + "\n\n" + f"'user_input': {message}"

    def respond(message, local_db, chat_history, meta, query_count=5, test_mode=False, response_delay=5, preprocessing=False):
        gpt_chatbot = OpenAIChatBot()
        print("Chat History: ", chat_history)
        print("Local DB: ", local_db is None)
        for chat in chat_history:
            gpt_chatbot.load_chat(chat)
        if local_db is None or query_count == 0:
            bot_message = gpt_chatbot(message)
            print(bot_message)
            print(message)
            chat_history.append((message, bot_message))
            return "", chat_history
        else:
            augmented_message = get_augmented_message(message, local_db, query_count, preprocessing, meta)
            bot_message = gpt_chatbot(augmented_message, original_message=message)
            print(message)
            print(augmented_message)
            print(bot_message)
            if test_mode:
                chat_history.append((augmented_message, bot_message))
            else:
                chat_history.append((message, bot_message))
            time.sleep(response_delay)  # sleep a few seconds to avoid hitting the rate limit
            return "", chat_history

    with gr.Row():
        with gr.Column():
            gr.Markdown(INSTRUCTIONS)

    with gr.Row():
        with gr.Tab("Create a knowledge base from local PDF files"):
            zip_file = gr.File(file_types=[".zip"], label="Local PDF files (.zip)")
            create_db = gr.Button("Create knowledge base", variant="primary")
            with gr.Accordion("Advanced settings", open=False):
                embedding_selector = gr.Dropdown(ALL_EMBEDDINGS,
                                                 value="distilbert-dot-tas_b-b256-msmarco",
                                                 label="Embedding Models")
                pdf_loader_selector = gr.Dropdown([loader.__name__ for loader in ALL_PDF_LOADERS],
                                                  value=PyPDFLoader.__name__, label="PDF Loader")
                chunk_size_slider = gr.Slider(minimum=50, maximum=2000, step=50, value=500,
                                              label="Chunk size (tokens)")
                chunk_overlap_slider = gr.Slider(minimum=0, maximum=500, step=1, value=50,
                                                 label="Chunk overlap (tokens)")
                save_to_cloud_checkbox = gr.Checkbox(value=False, label="Upload the database to the cloud")

            file_dp_output = gr.File(file_types=[".zip"], label="(Output) Knowledge-base file (.zip)")
        with gr.Tab("Load a local knowledge-base file"):
            file_local = gr.File(file_types=[".zip"], label="Local knowledge-base file (.zip)")
            load_db = gr.Button("Load an existing knowledge base", variant="primary")

        with gr.Tab("Summarize a chart (Demo)"):
            gr.Markdown(r"Code adapted from: https://huggingface.co/spaces/fl399/deplot_plus_llm")
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            extract = gr.Button("Summarize", variant="primary")

            output_text = gr.Textbox(lines=8, label="Output")

        with gr.Column():
            status = gr.Textbox(label="Status of the running program")
            chatbot = gr.Chatbot()

            msg = gr.Textbox()
            submit = gr.Button("Submit", variant="primary")
            with gr.Accordion("Advanced settings", open=False):
                json_output = gr.JSON()
                with gr.Row():
                    query_count_slider = gr.Slider(minimum=0, maximum=10, step=1, value=3,
                                                   label="Query counts")
                    test_mode_checkbox = gr.Checkbox(label="Test mode")

    msg.submit(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
    submit.click(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])

    create_db.click(load_zip_as_db, [zip_file, pdf_loader_selector, embedding_selector, chunk_size_slider, chunk_overlap_slider, save_to_cloud_checkbox],
                    [status, file_dp_output, local_db, json_output])
    load_db.click(load_local_db, [file_local], [status, local_db])

    extract.click(describe, [input_image], [output_text])

demo.launch(show_api=False)
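For reference, get_augmented_message above assembles a prompt of roughly the following shape before handing it to OpenAIChatBot (an illustrative sketch; the retrieved chunks are the stringified dictionaries produced by get_chunks in knowledge/faiss_handler.py, and all contents here are made up):

{'content': 'Header: Year | Tonnes ...', 'source': 'chart.png'}

---

{'content': 'first retrieved chunk ...', 'chunk': 0, 'source': 'paper.pdf'}

---

{'content': 'second retrieved chunk ...', 'chunk': 1, 'source': 'paper.pdf'}

-----

'user_input': What does chart.png show?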
knowledge/__init__.py
ADDED
File without changes
knowledge/faiss_handler.py
ADDED
@@ -0,0 +1,203 @@
import json
from langchain.vectorstores import FAISS
import os
from tqdm.auto import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader, TextLoader
from llms.embeddings import EMBEDDINGS_MAPPING
import tiktoken
import zipfile
import pickle

tokenizer_name = tiktoken.encoding_for_model('gpt-4')
tokenizer = tiktoken.get_encoding(tokenizer_name.name)
EMBED_MODEL = "text-embedding-ada-002"
EMBED_DIM = 1536
METRIC = 'cosine'

#######################################################################################################################
# Files handler
#######################################################################################################################
def check_existence(path):
    return os.path.isfile(path) or os.path.isdir(path)


def list_files(directory, ext=".pdf"):
    # List all files in the directory
    files_in_directory = os.listdir(directory)
    # Filter the list to only include files with the given extension
    files_list = [file for file in files_in_directory if file.endswith(ext)]
    return files_list


def list_pdf_files(directory):
    # List all files in the directory
    files_in_directory = os.listdir(directory)
    # Filter the list to only include PDF files
    pdf_files = [file for file in files_in_directory if file.endswith(".pdf")]
    return pdf_files


def tiktoken_len(text):
    # evaluate how many tokens the given text contains
    tokens = tokenizer.encode(text, disallowed_special=())
    return len(tokens)


def get_chunks(docs, chunk_size=500, chunk_overlap=20, length_function=tiktoken_len):
    # docs should be the output of `loader.load()`
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap,
                                                   length_function=length_function,
                                                   separators=["\n\n", "\n", " ", ""])
    chunks = []
    for idx, page in enumerate(tqdm(docs)):
        source = page.metadata.get('source')
        content = page.page_content
        if len(content) > 50:
            texts = text_splitter.split_text(content)
            chunks.extend([str({'content': texts[i], 'chunk': i, 'source': os.path.basename(source)}) for i in
                           range(len(texts))])
    return chunks


#######################################################################################################################
# Create FAISS object
#######################################################################################################################

# ["text-embedding-ada-002", "distilbert-dot-tas_b-b256-msmarco"]

def create_faiss_index_from_zip(path_to_zip_file, embeddings=None, pdf_loader=None,
                                chunk_size=500, chunk_overlap=20,
                                project_name="Very_Cool_Project_Name"):
    # initialize the file structure
    # structure: project_name
    #               - source_data
    #               - embeddings
    #               - faiss_index
    if isinstance(embeddings, str):
        import copy
        embeddings_str = copy.deepcopy(embeddings)
    else:
        embeddings_str = "other-embedding-model"

    if embeddings is None:
        embeddings = EMBEDDINGS_MAPPING["text-embedding-ada-002"]
    elif isinstance(embeddings, str):
        embeddings = EMBEDDINGS_MAPPING[embeddings]
    # otherwise, use the embeddings object passed in as-is
    # (the original fell back to text-embedding-ada-002 here, silently discarding it)

    # STEP 1:
    # Create a folder f"{project_name}" in the current directory.
    current_directory = os.getcwd()
    if not os.path.exists(project_name):
        os.makedirs(project_name)
        project_path = os.path.join(current_directory, project_name)
        source_data = os.path.join(project_path, "source_data")
        embeddings_data = os.path.join(project_path, "embeddings")
        index_data = os.path.join(project_path, "faiss_index")
        os.makedirs(source_data)      # ./project/source_data
        os.makedirs(embeddings_data)  # ./project/embeddings
        os.makedirs(index_data)       # ./project/faiss_index
    else:
        raise ValueError(f"The project {project_name} already exists.")
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        # extract everything to "source_data"
        zip_ref.extractall(source_data)

    db_meta = {"project_name": project_name,
               "pdf_loader": pdf_loader.__name__, "chunk_size": chunk_size,
               "chunk_overlap": chunk_overlap,
               "embedding_model": embeddings_str,
               "files": os.listdir(source_data),
               "source_path": source_data}
    with open(os.path.join(project_path, "db_meta.json"), "w", encoding="utf-8") as f:
        # save db_meta.json to the folder
        json.dump(db_meta, f)

    all_docs = []
    for ext in [".txt", ".tex", ".md", ".pdf"]:
        if ext in [".txt", ".tex", ".md"]:
            loader = DirectoryLoader(source_data, glob=f"**/*{ext}", loader_cls=TextLoader,
                                     loader_kwargs={'autodetect_encoding': True})
        elif ext in [".pdf"]:
            loader = DirectoryLoader(source_data, glob=f"**/*{ext}", loader_cls=pdf_loader)
        else:
            continue
        docs = loader.load()
        all_docs = all_docs + docs

    # split the files into chunks, evaluate their embeddings, and save all results into `embeddings`
    chunks = get_chunks(all_docs, chunk_size, chunk_overlap)
    text_embeddings = embeddings.embed_documents(chunks)
    text_embedding_pairs = list(zip(chunks, text_embeddings))
    embeddings_save_to = os.path.join(embeddings_data, 'text_embedding_pairs.pickle')
    with open(embeddings_save_to, 'wb') as handle:
        pickle.dump(text_embedding_pairs, handle, protocol=pickle.HIGHEST_PROTOCOL)
    db = FAISS.from_embeddings(text_embedding_pairs, embeddings)

    db.save_local(index_data)
    print(db_meta)
    print("Success!")
    return db, project_name, db_meta


def find_file(file_name, directory):
    for root, dirs, files in os.walk(directory):
        if file_name in files:
            return os.path.join(root, file_name)
    return None  # If the file was not found


def find_file_dir(file_name, directory):
    for root, dirs, files in os.walk(directory):
        if file_name in files:
            return root  # return the directory instead of the full path
    return None  # If the file was not found


def load_faiss_index_from_zip(path_to_zip_file):
    # Extract the zip file and read db_meta.json
    path_to_extract = os.getcwd()
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(path_to_extract)

    db_meta_json = find_file("db_meta.json", path_to_extract)
    if db_meta_json is not None:
        with open(db_meta_json, "r", encoding="utf-8") as f:
            db_meta_dict = json.load(f)
    else:
        raise ValueError("Cannot find `db_meta.json` in the .zip file. ")

    try:
        embeddings = EMBEDDINGS_MAPPING[db_meta_dict["embedding_model"]]
    except KeyError:
        from langchain.embeddings.openai import OpenAIEmbeddings
        embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

    # locate index.faiss
    index_path = find_file_dir("index.faiss", path_to_extract)
    if index_path is not None:
        db = FAISS.load_local(index_path, embeddings)
        return db
    else:
        raise ValueError("Failed to find `index.faiss` in the .zip file.")


if __name__ == "__main__":
    from langchain.document_loaders import PyPDFLoader
    from langchain.embeddings import HuggingFaceEmbeddings

    model_name = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco"
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs)
    create_faiss_index_from_zip(path_to_zip_file="document.zip", pdf_loader=PyPDFLoader, embeddings=embeddings)
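For reference, the db_meta.json written by create_faiss_index_from_zip looks roughly like this (all values illustrative):

{"project_name": "0f8c2a...", "pdf_loader": "PyPDFLoader", "chunk_size": 500,
 "chunk_overlap": 20, "embedding_model": "text-embedding-ada-002",
 "files": ["paper.pdf", "notes.md"], "source_path": "/home/user/app/0f8c2a.../source_data"}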
knowledge/img_handler.py
ADDED
@@ -0,0 +1,44 @@
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from PIL import Image

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16)
if device == "cuda":
    model_deplot = model_deplot.to(0)
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")


def add_markup(table):
    try:
        parts = [p.strip() for p in table.splitlines(keepends=False)]
        if parts[0].startswith('TITLE'):
            result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
            rows = parts[1:]
        else:
            result = ''
            rows = parts
        prefixes = ['Header: '] + [f'Row {i+1}: ' for i in range(len(rows) - 1)]
        return result + '\n'.join(prefix + row for prefix, row in zip(prefixes, rows))
    except Exception:
        # just use the raw table if parsing fails
        return table


def process_image(image):
    inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:",
                              return_tensors="pt").to(torch.bfloat16)
    if device == "cuda":
        inputs = inputs.to(0)
    predictions = model_deplot.generate(**inputs, max_new_tokens=512)
    table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
    return table


if __name__ == "__main__":
    im = Image.open(r"meat-image.png")
    process_image(im)
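As a quick illustration of add_markup (the table string is made up; process_image produces this newline-separated format by replacing DePlot's <0x0A> tokens):

print(add_markup("TITLE | Meat production\nYear | Tonnes\n1961 | 71M\n2018 | 340M"))
# Title: Meat production
# Header: Year | Tonnes
# Row 1: 1961 | 71M
# Row 2: 2018 | 340M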
llms/__init__.py
ADDED
File without changes
llms/chatbot.py
ADDED
@@ -0,0 +1,37 @@
import openai
import copy


class OpenAIChatBot:
    def __init__(self, model="gpt-3.5-turbo"):
        self.system = "You are a Q&A bot, a highly intelligent system that answers user questions based on the information provided by the user's local database. " \
                      "The user's question will include some reference information above it. " \
                      "You need to answer the question based on the provided references and tell the user the source of each reference. " \
                      "If you cannot find the answer in the provided references, you still need to answer the question, but you must also notify the user that your response is not based on the provided references."
        self.model = model
        self.message = [{"role": "system", "content": self.system}]
        self.raw_message = [{"role": "system", "content": self.system}]

    def load_message(self, message, role, original_message=None):
        if original_message is None:
            original_message = message
        msg = {"role": role, "content": message}
        self.message.append(msg)
        msg = {"role": role, "content": original_message}
        self.raw_message.append(msg)

    def load_chat(self, chat):
        # replay a prior (user, assistant) turn into both histories
        # (the original appended the user turn only to `message` and the assistant
        #  turn only to `raw_message`, so the model never saw its own earlier replies)
        self.load_message(chat[0], "user")
        self.load_message(chat[1], "assistant")

    def __call__(self, message, original_message=None):
        self.load_message(message, "user", original_message)
        augmented_message = copy.deepcopy(self.message)
        completion = openai.ChatCompletion.create(
            model=self.model,
            messages=augmented_message
        )
        assistant_message = completion.choices[0].message
        return assistant_message["content"]
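A minimal usage sketch (assuming openai.api_key is set; all texts are made up). Passing original_message keeps the user's short question in raw_message while the model sees the reference-augmented prompt:

bot = OpenAIChatBot()
bot.load_chat(("Hi!", "Hello! How can I help?"))  # replay one earlier (user, assistant) turn
reply = bot("chunk 1 ...\n\n---\n\nchunk 2 ...\n\n-----\n\n'user_input': What is FAISS?",
            original_message="What is FAISS?")
print(reply)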
llms/embeddings.py
ADDED
@@ -0,0 +1,30 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings

model_name = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
hf_embeddings_1 = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs)

openai_embedding = OpenAIEmbeddings(model="text-embedding-ada-002")


model_name = "GanymedeNil/text2vec-large-chinese"
hf_embeddings_2 = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs)


EMBEDDINGS_MAPPING = {"text-embedding-ada-002": openai_embedding,
                      "distilbert-dot-tas_b-b256-msmarco": hf_embeddings_1,
                      "text2vec-large-chinese": hf_embeddings_2}


def main():
    pass


if __name__ == "__main__":
    main()
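A minimal sketch of how the mapping is used (embed_query and embed_documents are the standard LangChain embedding methods shared by both classes):

emb = EMBEDDINGS_MAPPING["distilbert-dot-tas_b-b256-msmarco"]
vec = emb.embed_query("What is a vector database?")
print(len(vec))  # embedding dimensionality (768 for this DistilBERT-based model)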
llms/tools.py
ADDED
@@ -0,0 +1,45 @@
import openai


class BaseTool:
    def __init__(self, model="gpt-3.5-turbo"):
        self.system = ""
        self.model = model
        self.message = [{"role": "system", "content": self.system}]

    def __call__(self, message):
        user_message = {"role": "user", "content": message}
        messages = self.message + [user_message]
        completion = openai.ChatCompletion.create(
            model=self.model,
            messages=messages
        )
        assistant_message = completion.choices[0].message
        return assistant_message["content"].replace("\n", " ")


class PreprocessingBot(BaseTool):
    def __init__(self, model="gpt-3.5-turbo"):
        super().__init__(model)
        self.system = r"""You are an AI assistant for raw-data pre-processing. The user will input multiple raw references which may include unicode characters or ASCII codes such as '\u001e'. Your task is to make them more readable by doing the following:
- Change all unicode characters or ASCII codes such as '\u001e' to LaTeX format and put them in a formula environment $...$ or $$...$$.
- Re-write formulas or mathematical notation in LaTeX format in a formula environment $...$ or $$...$$.
- Remove meaningless content.
- Respond in the following format: {pdf-name-1: main contents from pdf-name-1, pdf-name-2: main contents from pdf-name-2, ...}.
"""
        self.message = [{"role": "system", "content": self.system}]


class ToolBot(BaseTool):
    def __init__(self, model="gpt-3.5-turbo"):
        super().__init__(model)
        self.system = r"""You need to pretend to be a Python function. You receive a string that is the user's question to a QA bot. You need to analyze the user's goal and decide whether the QA bot needs to use the search engine to generate its response.
Respond 1 if you think the QA bot needs to use the search engine for the user's input, and respond 0 if it doesn't.
"""
        self.message = [{"role": "system", "content": self.system}]


if __name__ == "__main__":
    bot = ToolBot()
    rsp = bot("Hello!")
    print(rsp)
requirements.txt
ADDED
Binary file (4.19 kB)
utils.py
ADDED
@@ -0,0 +1,12 @@
import os
import shutil


def make_archive(source, destination):
    # e.g. make_archive("project_folder", "project_folder.zip")
    base = os.path.basename(destination)
    name = base.split('.')[0]
    archive_format = base.split('.')[1]  # renamed from `format`, which shadowed the builtin
    archive_from = os.path.dirname(source)
    archive_to = os.path.basename(source.strip(os.sep))
    shutil.make_archive(name, archive_format, archive_from, archive_to)
    shutil.move('%s.%s' % (name, archive_format), destination)
    return destination
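Usage sketch, mirroring the call in app.py (names illustrative):

make_archive("my_project", "my_project.zip")  # zips ./my_project into ./my_project.zip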