Artem733733 committed
Commit 5aa92ab · verified · 1 Parent(s): 80223ff

Upload 2 files

Files changed (2)
  1. app.py +381 -0
  2. requirements.txt +16 -0
app.py ADDED
@@ -0,0 +1,381 @@
+ import os
+ import gradio as gr
+
+ import torch
+ from uuid import uuid4
+ from huggingface_hub.file_download import http_get
+ from langchain_community.document_loaders import (
+     CSVLoader,
+     EverNoteLoader,
+     PDFMinerLoader,
+     TextLoader,
+     UnstructuredEPubLoader,
+     UnstructuredHTMLLoader,
+     UnstructuredMarkdownLoader,
+     UnstructuredODTLoader,
+     UnstructuredPowerPointLoader,
+     UnstructuredWordDocumentLoader,
+ )
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.docstore.document import Document
+ from sentence_transformers import SentenceTransformer
+ from sentence_transformers.util import cos_sim
+ from llama_cpp import Llama
+
+
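+ # CPU retrieval-QA demo: uploaded documents are split into chunks, embedded
+ # with a SentenceTransformer, and the top-k chunks by cosine similarity are
+ # prepended to the user's question before it is sent to a GGUF model served
+ # by llama-cpp-python.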
+ SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
+
+ LOADER_MAPPING = {
+     ".csv": (CSVLoader, {}),
+     ".doc": (UnstructuredWordDocumentLoader, {}),
+     ".docx": (UnstructuredWordDocumentLoader, {}),
+     ".enex": (EverNoteLoader, {}),
+     ".epub": (UnstructuredEPubLoader, {}),
+     ".html": (UnstructuredHTMLLoader, {}),
+     ".md": (UnstructuredMarkdownLoader, {}),
+     ".odt": (UnstructuredODTLoader, {}),
+     ".pdf": (PDFMinerLoader, {}),
+     ".ppt": (UnstructuredPowerPointLoader, {}),
+     ".pptx": (UnstructuredPowerPointLoader, {}),
+     ".txt": (TextLoader, {"encoding": "utf8"}),
+ }
+
+
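+ # File extension -> (LangChain loader class, loader kwargs). load_single_document()
+ # dispatches on the extension, so supporting a new format is a one-line change here.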
+ def load_model(
+     directory: str = ".",
+     model_name: str = "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf",
+     model_url: str = "https://huggingface.co/second-state/Mistral-Nemo-Instruct-2407-GGUF/resolve/main/Mistral-Nemo-Instruct-2407-Q4_K_M.gguf",
+ ):
+     final_model_path = os.path.join(directory, model_name)
+
+     # Download the GGUF weights once and cache them next to the app.
+     if not os.path.exists(final_model_path):
+         print("Downloading model...")
+         with open(final_model_path, "wb") as f:
+             http_get(model_url, f)
+         os.chmod(final_model_path, 0o777)
+         print("Files downloaded!")
+
+     model = Llama(
+         model_path=final_model_path,
+         n_ctx=2000,
+     )
+
+     print("Model loaded!")
+     return model
+
+
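+ # Note: http_get is an internal huggingface_hub helper; the public
+ # hf_hub_download() would be the more conventional way to fetch the weights.
+ # n_ctx=2000 caps the combined prompt plus completion length in tokens.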
+ EMBEDDER = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
+ # Alternatives from the embeddings leaderboard:
+ # Alibaba-NLP/gte-multilingual-base
+ # intfloat/e5-mistral-7b-instruct: best quality for Russian
+ # deepvk/USER-bge-m3: slightly behind on quality, but ~10x smaller and faster
+ # BAAI/bge-multilingual-gemma2
+ # EMBEDDER = SentenceTransformer("intfloat/multilingual-e5-large-instruct")
+ # EMBEDDER = SentenceTransformer("deepvk/USER-bge-m3")
+ MODEL = load_model()
+
+
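+ # Both models are loaded once at import time and shared across requests;
+ # demo.queue(..., default_concurrency_limit=1) below keeps access to them serial.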
+ def get_uuid():
+     return str(uuid4())
+
+
+ def load_single_document(file_path: str) -> Document:
+     ext = "." + file_path.rsplit(".", 1)[-1]
+     assert ext in LOADER_MAPPING, f"Unsupported file extension: {ext}"
+     loader_class, loader_args = LOADER_MAPPING[ext]
+     loader = loader_class(file_path, **loader_args)
+     return loader.load()[0]
+
+
+ def get_message_tokens(model, role, content):
+     content = f"{role}\n{content}\n</s>"
+     content = content.encode("utf-8")
+     return model.tokenize(content, special=True)
+
+
+ def get_system_tokens(model):
+     system_message = {"role": "system", "content": SYSTEM_PROMPT}
+     return get_message_tokens(model, **system_message)
+
+
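+ # Each turn is serialized as "{role}\n{content}\n</s>" and tokenized with
+ # special tokens enabled, so the system turn becomes "system\n<SYSTEM_PROMPT>\n</s>",
+ # followed later by alternating "user\n...\n</s>" and "bot\n...\n</s>" turns.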
+ def process_text(text):
+     lines = text.split("\n")
+     lines = [line for line in lines if len(line.strip()) > 2]
+     text = "\n".join(lines).strip()
+     if len(text) < 10:
+         return None
+     return text
+
+
+ def upload_files(files, file_paths):
+     file_paths = [f.name for f in files]
+     return file_paths
+
+
+ def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning):
+     documents = [load_single_document(path) for path in file_paths]
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     documents = text_splitter.split_documents(documents)
+     print("Documents after split:", len(documents))
+     fixed_documents = []
+     for doc in documents:
+         doc.page_content = process_text(doc.page_content)
+         if not doc.page_content:
+             continue
+         fixed_documents.append(doc)
+     print("Documents after processing:", len(fixed_documents))
+
+     texts = [doc.page_content for doc in fixed_documents]
+     embeddings = EMBEDDER.encode(texts, convert_to_tensor=True)
+     db = {"docs": texts, "embeddings": embeddings}
+     print("Embeddings calculated!")
+
+     file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
+     return db, file_warning
+
+
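+ # The "index" is just the chunk texts plus a (num_chunks, dim) embedding tensor
+ # held in a gr.State, e.g. db = {"docs": ["chunk 1", ...], "embeddings": Tensor(N, 768)}.
+ # Retrieval below is brute-force cosine similarity, which is fine at this scale
+ # and avoids pulling in a vector store (chromadb is pinned but unused here).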
+ def retrieve(history, db, retrieved_docs, k_documents):
+     retrieved_docs = ""
+     if db:
+         last_user_message = history[-1][0]
+         query_embedding = EMBEDDER.encode(last_user_message, convert_to_tensor=True)
+         scores = cos_sim(query_embedding, db["embeddings"])[0]
+         # Clamp k so a high slider value cannot exceed the number of chunks.
+         k = min(int(k_documents), len(db["docs"]))
+         top_k_idx = torch.topk(scores, k=k)[1]
+         top_k_documents = [db["docs"][idx] for idx in top_k_idx]
+         retrieved_docs = "\n\n".join(top_k_documents)
+     return retrieved_docs
+
+
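+ # cos_sim returns a [1, num_chunks] similarity matrix; [0] flattens it to a
+ # length-N vector, and torch.topk(...)[1] gives the indices of the k best chunks.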
+ def user(message, history, system_prompt):
+     new_history = history + [[message, None]]
+     return "", new_history
+
+
+ def bot(
+     history,
+     system_prompt,
+     conversation_id,
+     retrieved_docs,
+     top_p,
+     top_k,
+     temp
+ ):
+     model = MODEL
+     if not history:
+         return
+
+     tokens = get_system_tokens(model)[:]
+
+     for user_message, bot_message in history[:-1]:
+         message_tokens = get_message_tokens(model=model, role="user", content=user_message)
+         tokens.extend(message_tokens)
+         if bot_message:
+             message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
+             tokens.extend(message_tokens)
+
+     last_user_message = history[-1][0]
+     if retrieved_docs:
+         last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
+     message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
+     tokens.extend(message_tokens)
+
+     role_tokens = model.tokenize("bot\n".encode("utf-8"), special=True)
+     tokens.extend(role_tokens)
+     generator = model.generate(
+         tokens,
+         top_k=top_k,
+         top_p=top_p,
+         temp=temp
+     )
+
+     # Accumulate raw bytes and re-decode on every step so multi-byte UTF-8
+     # characters that span token boundaries are not silently dropped.
+     token_bytes = b""
+     for token in generator:
+         if token == model.token_eos():
+             break
+         token_bytes += model.detokenize([token])
+         history[-1][1] = token_bytes.decode("utf-8", "ignore")
+         yield history
+
+
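+ # --- Gradio UI. Upload, retrieval and generation are chained with .success()
+ # --- so each step only runs after the previous one finished without error.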
+ with gr.Blocks(
+     theme=gr.themes.Soft()
+ ) as demo:
+     db = gr.State(None)
+     conversation_id = gr.State(get_uuid)
+     # favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
+     gr.Markdown(
+         # f"""<h1><center>{favicon}Saiga 13B llama.cpp: retrieval QA</center></h1>
+         """<h1><center>Вопросно-ответная система по Вашим документам. Работает на CPU.\n
+         На демо-стенде реализован простейший алгоритм поиска информации, при внедрении в IT-контуре компании, качество поиска выше в разы.\n
+         Для внедрения быстрой версии (на GPU ответ быстрее в 20-100 раз) в Информационном контуре Вашей организации, пишите на e-mail: info@digital-human.ru</center></h1>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             file_output = gr.File(file_count="multiple", label="Загрузка файлов")
+             file_paths = gr.State([])
+             file_warning = gr.Markdown("Фрагменты ещё не загружены!")
+
+         with gr.Column(min_width=200, scale=3):
+             with gr.Tab(label="Параметры нарезки"):
+                 chunk_size = gr.Slider(
+                     minimum=50,
+                     maximum=2000,
+                     value=250,
+                     step=50,
+                     interactive=True,
+                     label="Размер фрагментов",
+                 )
+                 chunk_overlap = gr.Slider(
+                     minimum=0,
+                     maximum=500,
+                     value=30,
+                     step=10,
+                     interactive=True,
+                     label="Пересечение"
+                 )
+
+     with gr.Row():
+         k_documents = gr.Slider(
+             minimum=1,
+             maximum=10,
+             value=2,
+             step=1,
+             interactive=True,
+             label="Кол-во фрагментов для контекста"
+         )
+     with gr.Row():
+         retrieved_docs = gr.Textbox(
+             lines=6,
+             label="Извлеченные фрагменты",
+             placeholder="Появятся после задавания вопросов",
+             interactive=False
+         )
+     with gr.Row():
+         with gr.Column(scale=5):
+             system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False)
+             # .style() was removed in Gradio 4; height is a constructor argument now.
+             chatbot = gr.Chatbot(label="Диалог", height=400)
+         with gr.Column(min_width=80, scale=1):
+             with gr.Tab(label="Параметры генерации"):
+                 top_p = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.9,
+                     step=0.05,
+                     interactive=True,
+                     label="Top-p",
+                 )
+                 top_k = gr.Slider(
+                     minimum=10,
+                     maximum=100,
+                     value=30,
+                     step=5,
+                     interactive=True,
+                     label="Top-k",
+                 )
+                 temp = gr.Slider(
+                     minimum=0.0,
+                     maximum=2.0,
+                     value=0.1,
+                     step=0.1,
+                     interactive=True,
+                     label="Temp"
+                 )
+
+     with gr.Row():
+         with gr.Column():
+             msg = gr.Textbox(
+                 label="Отправить сообщение",
+                 placeholder="Отправить сообщение",
+                 show_label=False,
+                 container=False,
+             )
+         with gr.Column():
+             with gr.Row():
+                 submit = gr.Button("Отправить")
+                 stop = gr.Button("Остановить")
+                 clear = gr.Button("Очистить")
+
+     # Upload files
+     upload_event = file_output.change(
+         fn=upload_files,
+         inputs=[file_output, file_paths],
+         outputs=[file_paths],
+         queue=True,
+     ).success(
+         fn=build_index,
+         inputs=[file_paths, db, chunk_size, chunk_overlap, file_warning],
+         outputs=[db, file_warning],
+         queue=True
+     )
+
+     # Pressing Enter
+     submit_event = msg.submit(
+         fn=user,
+         inputs=[msg, chatbot, system_prompt],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).success(
+         fn=retrieve,
+         inputs=[chatbot, db, retrieved_docs, k_documents],
+         outputs=[retrieved_docs],
+         queue=True,
+     ).success(
+         fn=bot,
+         inputs=[
+             chatbot,
+             system_prompt,
+             conversation_id,
+             retrieved_docs,
+             top_p,
+             top_k,
+             temp
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+
+     # Pressing the button
+     submit_click_event = submit.click(
+         fn=user,
+         inputs=[msg, chatbot, system_prompt],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).success(
+         fn=retrieve,
+         inputs=[chatbot, db, retrieved_docs, k_documents],
+         outputs=[retrieved_docs],
+         queue=True,
+     ).success(
+         fn=bot,
+         inputs=[
+             chatbot,
+             system_prompt,
+             conversation_id,
+             retrieved_docs,
+             top_p,
+             top_k,
+             temp
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+
+     # Stop generation
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, submit_click_event],
+         queue=False,
+     )
+
+     # Clear history
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ # concurrency_count was removed in Gradio 4; default_concurrency_limit replaces it.
+ demo.queue(max_size=128, default_concurrency_limit=1)
+ demo.launch(show_error=True)
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ llama-cpp-python
+ langchain==0.2.4
+ langchain-community==0.2.4
+ chromadb==0.5.0
+ huggingface-hub==0.19.4
+ gradio==4.36.1
+ tenacity==8.3.0
+ torch==2.1.0
+ sentence-transformers
+ pdfminer.six==20221105
+ unstructured==0.6.10
+ tabulate
+ azure-core