import os
import tempfile

import gradio as gr

# Project-level configuration (MODELS_DIR, DICT_REPO_AND_MODELS, EMBEDDER_NAME,
# SYSTEM_PROMPT, ROLE_TOKENS, LINEBREAK_TOKEN, BOT_TOKEN, MAX_NEW_TOKENS,
# LOADER_MAPPING, FAVICON_PATH, BLOCK_CSS) is expected to come from __init__.py.
from __init__ import *

from llama_cpp import Llama
from chromadb.config import Settings
from typing import List, Optional, Tuple
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from huggingface_hub.file_download import http_get
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter


class LocalChatGPT:
    def __init__(self):
        self.llama_model: Optional[Llama] = None
        self.embeddings: HuggingFaceEmbeddings = self.initialize_app()

    def initialize_app(self) -> HuggingFaceEmbeddings:
        """
        Download the default model if it is not cached yet, load it and create the embedder.
        :return: the HuggingFace embedding model used to build the vector index.
        """
        os.makedirs(MODELS_DIR, exist_ok=True)
        model_url, model_name = list(DICT_REPO_AND_MODELS.items())[0]
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs("/".join(final_model_path.split("/")[:-1]), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                http_get(model_url, f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )

        return HuggingFaceEmbeddings(model_name=EMBEDDER_NAME, cache_folder=MODELS_DIR)

    def load_model(self, model_name):
        """
        Download the selected model if it is not cached locally and load it into llama.cpp.
        :param model_name: file name of the model to load.
        :return: the name of the loaded model (echoed back to the dropdown).
        """
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs("/".join(final_model_path.split("/")[:-1]), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                # Look up the download URL whose value matches the selected model name.
                if model_url := [i for i in DICT_REPO_AND_MODELS if DICT_REPO_AND_MODELS[i] == model_name]:
                    http_get(model_url[0], f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )
        return model_name

    @staticmethod
    def load_single_document(file_path: str) -> Document:
        """
        Load a single document with the loader registered for its file extension.
        :param file_path: path to the document.
        :return: the loaded Document.
        """
        ext: str = "." + file_path.rsplit(".", 1)[-1]
        assert ext in LOADER_MAPPING
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()[0]

    @staticmethod
    def get_message_tokens(model: Llama, role: str, content: str) -> list:
        """
        Tokenize a single chat message: the role token and a line break are inserted
        right after the BOS token, and the message is closed with the EOS token.
        :param model: loaded llama.cpp model.
        :param role: role name used as a key into ROLE_TOKENS (e.g. "system" or "user").
        :param content: message text.
        :return: list of token ids for the message.
        """
        message_tokens: list = model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, ROLE_TOKENS[role])
        message_tokens.insert(2, LINEBREAK_TOKEN)
        message_tokens.append(model.token_eos())
        return message_tokens

    def get_system_tokens(self, model: Llama) -> list:
        """
        Build the token sequence for the system prompt.
        :param model: loaded llama.cpp model.
        :return: list of token ids for the system message.
        """
        system_message: dict = {"role": "system", "content": SYSTEM_PROMPT}
        return self.get_message_tokens(model, **system_message)

    @staticmethod
    def upload_files(files: List[tempfile.TemporaryFile]) -> List[str]:
        """
        Collect the paths of the uploaded temporary files.
        :param files: files received from the Gradio upload component.
        :return: list of file paths.
        """
        return [f.name for f in files]

    @staticmethod
    def process_text(text: str) -> Optional[str]:
        """
        Clean a text chunk: drop near-empty lines and discard chunks that are too short.
        :param text: raw chunk text.
        :return: cleaned text, or None if fewer than 10 characters remain.
        """
        lines: list = text.split("\n")
        lines = [line for line in lines if len(line.strip()) > 2]
        text = "\n".join(lines).strip()
        return None if len(text) < 10 else text

    @staticmethod
    def update_text_db(
            db: Optional[Chroma],
            fixed_documents: List[Document],
            ids: List[str]
    ) -> Optional[Tuple[Chroma, str]]:
        """
        If the same set of files is uploaded again, replace the stored fragments in place.
        :param db: existing Chroma database, if any.
        :param fixed_documents: cleaned document chunks to store.
        :param ids: ids for the new chunks.
        :return: the updated database and a status message, or None if a new index must be built.
        """
        if db:
            data: dict = db.get()
            files_db = {dict_data['source'].split('/')[-1] for dict_data in data["metadatas"]}
            files_load = {dict_data.metadata["source"].split('/')[-1] for dict_data in fixed_documents}
            if files_load == files_db:
                # Same files re-uploaded: drop the old fragments and insert the new ones.
                db.delete(data['ids'])
                db.add_texts(
                    texts=[doc.page_content for doc in fixed_documents],
                    metadatas=[doc.metadata for doc in fixed_documents],
                    ids=ids
                )
                file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
                return db, file_warning
        return None

    def build_index(
            self,
            file_paths: List[str],
            db: Optional[Chroma],
            chunk_size: int,
            chunk_overlap: int
    ):
        """
        Split the uploaded documents into chunks, embed them and build (or refresh) the Chroma index.
        :param file_paths: paths of the uploaded files.
        :param db: existing Chroma database, if any.
        :param chunk_size: chunk size in characters.
        :param chunk_overlap: overlap between neighbouring chunks in characters.
        :return: the Chroma database and a status message.
        """
        documents: List[Document] = [self.load_single_document(path) for path in file_paths]
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        documents = text_splitter.split_documents(documents)
        fixed_documents: List[Document] = []
        for doc in documents:
            doc.page_content = self.process_text(doc.page_content)
            if not doc.page_content:
                continue
            fixed_documents.append(doc)

        # One unique id per chunk, derived from the chunk's own source file;
        # the number of ids must match the number of chunks.
        ids: List[str] = [
            f"{doc.metadata['source'].split('/')[-1].replace('.txt', '')}{i}"
            for i, doc in enumerate(fixed_documents, start=1)
        ]

        # If the same files were re-uploaded, update the existing collection instead of rebuilding it.
        updated = self.update_text_db(db, fixed_documents, ids)
        if updated is not None:
            return updated

        db = Chroma.from_documents(
            documents=fixed_documents,
            embedding=self.embeddings,
            ids=ids,
            client_settings=Settings(
                anonymized_telemetry=False,
                persist_directory="db"
            )
        )
        file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
        return db, file_warning

    @staticmethod
    def user(message, history):
        """
        Append the new user message to the chat history and clear the input box.
        :param message: text entered by the user.
        :param history: current chat history.
        :return: an empty string for the textbox and the updated history.
        """
        new_history = history + [[message, None]]
        return "", new_history

    @staticmethod
    def regenerate_response(history):
        """
        Reset the input box and pass the history on, so the chained handlers
        can regenerate the answer to the last user message.
        :param history: current chat history.
        :return: an empty string for the textbox and the unchanged history.
        """
        return "", history

    @staticmethod
    def retrieve(history, db: Optional[Chroma], retrieved_docs):
        """
        Retrieve the document fragments most similar to the last user message.
        :param history: current chat history.
        :param db: Chroma database with indexed documents, if any.
        :param retrieved_docs: previously retrieved fragments (returned unchanged if there is no database).
        :return: text of the retrieved fragments prefixed with their source file names.
        """
        if db:
            last_user_message = history[-1][0]
            try:
                docs = db.similarity_search(last_user_message, k=4)
            except RuntimeError:
                # Fall back to a single result if the wider search fails.
                docs = db.similarity_search(last_user_message, k=1)

            source_docs = set()
            for doc in docs:
                for content in doc.metadata.values():
                    source_docs.add(content.split("/")[-1])
            retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
            retrieved_docs = f"Документ - {', '.join(source_docs)}.\n\n{retrieved_docs}"
        return retrieved_docs

    def bot(self, history, retrieved_docs):
        """
        Generate and stream the assistant's answer to the last user message.
        :param history: current chat history.
        :param retrieved_docs: retrieved document fragments used as context.
        :return: yields the history with the partially generated answer.
        """
        if not history:
            return
        # Start the prompt with the system message.
        tokens = self.get_system_tokens(self.llama_model)[:]
        tokens.append(LINEBREAK_TOKEN)

        # Add the user side of the previous turns to the prompt.
        for user_message, bot_message in history[:-1]:
            message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=user_message)
            tokens.extend(message_tokens)

        # Prepend the retrieved context to the last user message, if any.
        last_user_message = history[-1][0]
        if retrieved_docs:
            last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: " \
                                f"{last_user_message}"
        message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=last_user_message)
        tokens.extend(message_tokens)

        # Open the assistant's turn and stream tokens until EOS or the token limit.
        role_tokens = [self.llama_model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
        tokens.extend(role_tokens)
        generator = self.llama_model.generate(
            tokens,
            top_k=30,
            top_p=0.9,
            temp=0.1
        )

        partial_text = ""
        for i, token in enumerate(generator):
            if token == self.llama_model.token_eos() or (MAX_NEW_TOKENS is not None and i >= MAX_NEW_TOKENS):
                break
            partial_text += self.llama_model.detokenize([token]).decode("utf-8", "ignore")
            history[-1][1] = partial_text
            yield history

    def run(self):
        """
        Build the Gradio interface and launch the web application.
        :return: None
        """
        with gr.Blocks(theme=gr.themes.Soft(), css=BLOCK_CSS) as demo:
            db: Optional[Chroma] = gr.State(None)
            favicon = f'<img src="{FAVICON_PATH}" width="48px" style="display: inline">'
            gr.Markdown(
                f"""<h1><center>{favicon} Я, Макар - текстовый ассистент на основе GPT</center></h1>"""
            )

            with gr.Row(elem_id="model_selector_row"):
                models: list = list(DICT_REPO_AND_MODELS.values())
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if models else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )

            with gr.Row():
                with gr.Column(scale=5):
                    chatbot = gr.Chatbot(label="Диалог", height=400)
                with gr.Column(min_width=200, scale=4):
                    retrieved_docs = gr.Textbox(
                        label="Извлеченные фрагменты",
                        placeholder="Появятся после задавания вопросов",
                        interactive=False
                    )

            with gr.Row():
                with gr.Column(scale=20):
                    msg = gr.Textbox(
                        label="Отправить сообщение",
                        show_label=False,
                        placeholder="Отправить сообщение",
                        container=False
                    )
                with gr.Column(scale=3, min_width=100):
                    submit = gr.Button("📤 Отправить", variant="primary")

            with gr.Row():
                stop = gr.Button(value="⛔ Остановить")
                regenerate = gr.Button(value="🔄 Повторить")
                clear = gr.Button(value="🗑️ Очистить")

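            # The upload_files/build_index methods above are never wired into the UI, so a
            # document-upload block presumably belongs here. The sketch below is a hypothetical
            # reconstruction, not code from the original source: component names, slider ranges
            # and chunk defaults are assumptions, and the exact value delivered by gr.File
            # depends on the installed Gradio version.
            with gr.Row():
                with gr.Column(scale=5):
                    file_output = gr.File(file_count="multiple", label="Загрузка файлов")
                with gr.Column(scale=4):
                    chunk_size = gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Размер фрагмента")
                    chunk_overlap = gr.Slider(minimum=0, maximum=256, value=30, step=10, label="Перекрытие фрагментов")
            file_warning = gr.Markdown("Фрагменты ещё не загружены!")
            file_paths = gr.State([])

            file_output.upload(
                fn=self.upload_files,
                inputs=[file_output],
                outputs=[file_paths],
                queue=True,
            ).success(
                fn=self.build_index,
                inputs=[file_paths, db, chunk_size, chunk_overlap],
                outputs=[db, file_warning],
                queue=True,
            )
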
            model_selector.change(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[model_selector]
            )

            submit_event = msg.submit(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            submit_click_event = submit.click(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            stop.click(
                fn=None,
                inputs=None,
                outputs=None,
                cancels=[submit_event, submit_click_event],
                queue=False,
            )

            regenerate.click(
                fn=self.regenerate_response,
                inputs=[chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            clear.click(lambda: None, None, chatbot, queue=False)

        demo.queue(max_size=128, default_concurrency_limit=10, api_open=False)
        demo.launch(server_name="0.0.0.0", max_threads=200)


if __name__ == "__main__":
    local_chat_gpt = LocalChatGPT()
    local_chat_gpt.run()