import tempfile
import gradio as gr
from __init__ import *
from llama_cpp import Llama
from chromadb.config import Settings
from typing import List, Optional, Tuple
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from huggingface_hub.file_download import http_get
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
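
# The star import from __init__ above is expected to provide `os` plus the
# project-level constants used throughout this file: MODELS_DIR,
# DICT_REPO_AND_MODELS, EMBEDDER_NAME, LOADER_MAPPING, SYSTEM_PROMPT,
# ROLE_TOKENS, LINEBREAK_TOKEN, BOT_TOKEN, MAX_NEW_TOKENS, FAVICON_PATH and
# BLOCK_CSS (names inferred from their use below; definitions are not shown here).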
class LocalChatGPT:
def __init__(self):
self.llama_model: Optional[Llama] = None
self.embeddings: HuggingFaceEmbeddings = self.initialize_app()
def initialize_app(self) -> HuggingFaceEmbeddings:
"""
Загружаем все модели из списка.
:return:
"""
os.makedirs(MODELS_DIR, exist_ok=True)
model_url, model_name = list(DICT_REPO_AND_MODELS.items())[0]
final_model_path = os.path.join(MODELS_DIR, model_name)
os.makedirs("/".join(final_model_path.split("/")[:-1]), exist_ok=True)
if not os.path.exists(final_model_path):
with open(final_model_path, "wb") as f:
http_get(model_url, f)
        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,  # context window size, in tokens
            n_parts=1,
        )
return HuggingFaceEmbeddings(model_name=EMBEDDER_NAME, cache_folder=MODELS_DIR)
    def load_model(self, model_name: str) -> str:
        """
        Download the selected model if it is not cached yet and reload llama.cpp with it.
        :param model_name: file name of the model, as listed in DICT_REPO_AND_MODELS.
        :return: the model name (echoed back to the Gradio dropdown).
        """
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path), exist_ok=True)
        if not os.path.exists(final_model_path):
            # Resolve the download URL first, so that a failed lookup does not
            # leave behind an empty file that would later pass the cache check.
            if model_url := [i for i in DICT_REPO_AND_MODELS if DICT_REPO_AND_MODELS[i] == model_name]:
                with open(final_model_path, "wb") as f:
                    http_get(model_url[0], f)
self.llama_model = Llama(
model_path=final_model_path,
n_ctx=2000,
n_parts=1,
)
return model_name
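
    # LOADER_MAPPING (from __init__) is assumed to map a file extension to a
    # (loader_class, loader_kwargs) pair, along the lines of:
    #   LOADER_MAPPING = {".txt": (TextLoader, {"encoding": "utf8"}), ...}
    # load_single_document below relies on exactly that shape.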
@staticmethod
def load_single_document(file_path: str) -> Document:
"""
Загружаем один документ.
:param file_path:
:return:
"""
ext: str = "." + file_path.rsplit(".", 1)[-1]
        assert ext in LOADER_MAPPING, f"Unsupported file extension: {ext}"
loader_class, loader_args = LOADER_MAPPING[ext]
loader = loader_class(file_path, **loader_args)
return loader.load()[0]
@staticmethod
def get_message_tokens(model: Llama, role: str, content: str) -> list:
"""
:param model:
:param role:
:param content:
:return:
"""
message_tokens: list = model.tokenize(content.encode("utf-8"))
message_tokens.insert(1, ROLE_TOKENS[role])
message_tokens.insert(2, LINEBREAK_TOKEN)
message_tokens.append(model.token_eos())
return message_tokens
def get_system_tokens(self, model: Llama) -> list:
"""
:param model:
:return:
"""
system_message: dict = {"role": "system", "content": SYSTEM_PROMPT}
return self.get_message_tokens(model, **system_message)
@staticmethod
def upload_files(files: List[tempfile.TemporaryFile]) -> List[str]:
"""
:param files:
:return:
"""
return [f.name for f in files]
@staticmethod
def process_text(text: str) -> Optional[str]:
"""
:param text:
:return:
"""
lines: list = text.split("\n")
lines = [line for line in lines if len(line.strip()) > 2]
text = "\n".join(lines).strip()
return None if len(text) < 10 else text
@staticmethod
def update_text_db(
db: Optional[Chroma],
fixed_documents: List[Document],
ids: List[str]
    ) -> Optional[Tuple[Chroma, str]]:
        """
        If the uploaded files match the files already stored in the index,
        replace the stored chunks in place.
        :return: the updated store and a status message, or None if a rebuild is needed.
        """
        if db:
            data: dict = db.get()
            files_db = {dict_data['source'].split('/')[-1] for dict_data in data["metadatas"]}
            files_load = {doc.metadata["source"].split('/')[-1] for doc in fixed_documents}
            if files_load == files_db:
                # Same set of source files: replace the stored chunks in place.
                # db.delete([item for item in data['ids'] if item not in ids])
                # db.update_documents(ids, fixed_documents)
                db.delete(data['ids'])
                db.add_texts(
                    texts=[doc.page_content for doc in fixed_documents],
                    metadatas=[doc.metadata for doc in fixed_documents],
                    ids=ids
                )
                file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
                return db, file_warning
        # A different file set (or no existing DB): the caller rebuilds the index.
        return None
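
    # Indexing strategy: when the same file set is re-uploaded, update_text_db
    # swaps the stored chunks in place; otherwise build_index below creates a
    # fresh Chroma collection from the newly split chunks.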
def build_index(
self,
file_paths: List[str],
db: Optional[Chroma],
chunk_size: int,
chunk_overlap: int
):
"""
:param file_paths:
:param db:
:param chunk_size:
:param chunk_overlap:
:return:
"""
documents: List[Document] = [self.load_single_document(path) for path in file_paths]
text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
documents = text_splitter.split_documents(documents)
fixed_documents: List[Document] = []
for doc in documents:
doc.page_content = self.process_text(doc.page_content)
if not doc.page_content:
continue
fixed_documents.append(doc)
        # One unique id per chunk, derived from the chunk's own source file name
        # and its global position, so the id count always matches the chunk count.
        ids: List[str] = [
            f"{doc.metadata['source'].split('/')[-1].replace('.txt', '')}{i}"
            for i, doc in enumerate(fixed_documents, start=1)
        ]
        if db:
            if updated := self.update_text_db(db, fixed_documents, ids):
                return updated
db = Chroma.from_documents(
documents=fixed_documents,
embedding=self.embeddings,
ids=ids,
client_settings=Settings(
anonymized_telemetry=False,
persist_directory="db"
)
)
file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
return db, file_warning
@staticmethod
    def user(message, history):
        """Append the user's message to the history and clear the input box."""
        new_history = history + [[message, None]]
        return "", new_history
@staticmethod
def regenerate_response(history):
"""
:param history:
:return:
"""
return "", history
@staticmethod
def retrieve(history, db: Optional[Chroma], retrieved_docs):
"""
:param history:
:param db:
:param retrieved_docs:
:return:
"""
if db:
last_user_message = history[-1][0]
try:
docs = db.similarity_search(last_user_message, k=4)
# retriever = db.as_retriever(search_kwargs={"k": k_documents})
# docs = retriever.get_relevant_documents(last_user_message)
            except RuntimeError:
                # Fall back to a single result if the wider search fails.
                docs = db.similarity_search(last_user_message, k=1)
# retriever = db.as_retriever(search_kwargs={"k": 1})
# docs = retriever.get_relevant_documents(last_user_message)
source_docs = set()
for doc in docs:
for content in doc.metadata.values():
source_docs.add(content.split("/")[-1])
retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
retrieved_docs = f"Документ - {''.join(list(source_docs))}.\n\n{retrieved_docs}"
return retrieved_docs
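
    # Prompt layout assembled in `bot` below, assuming the usual role-token
    # scheme behind ROLE_TOKENS / LINEBREAK_TOKEN / BOT_TOKEN:
    #   <system tokens> \n <user turn> ... <BOS> BOT_TOKEN \n
    # after which the model generates the assistant's reply. `bot` is a
    # generator: Gradio streams each yielded history to the Chatbot component,
    # so the answer appears token by token in the UI.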
def bot(self, history, retrieved_docs):
"""
:param history:
:param retrieved_docs:
:return:
"""
if not history:
return
tokens = self.get_system_tokens(self.llama_model)[:]
tokens.append(LINEBREAK_TOKEN)
        for user_message, bot_message in history[:-1]:
            tokens.extend(self.get_message_tokens(model=self.llama_model, role="user", content=user_message))
            # Replay the assistant's earlier replies as well, so the model sees
            # the whole dialogue (assumes ROLE_TOKENS also has a "bot" entry).
            if bot_message:
                tokens.extend(self.get_message_tokens(model=self.llama_model, role="bot", content=bot_message))
last_user_message = history[-1][0]
if retrieved_docs:
last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: " \
f"{last_user_message}"
message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=last_user_message)
tokens.extend(message_tokens)
role_tokens = [self.llama_model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
tokens.extend(role_tokens)
generator = self.llama_model.generate(
tokens,
top_k=30,
top_p=0.9,
temp=0.1
)
partial_text = ""
for i, token in enumerate(generator):
if token == self.llama_model.token_eos() or (MAX_NEW_TOKENS is not None and i >= MAX_NEW_TOKENS):
break
partial_text += self.llama_model.detokenize([token]).decode("utf-8", "ignore")
history[-1][1] = partial_text
yield history
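
    # UI wiring below: each submit path chains user -> retrieve -> bot as
    # successive Gradio events; the "stop" button cancels the in-flight chain,
    # and "regenerate" re-runs retrieve/bot on the unchanged history.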
def run(self):
"""
:return:
"""
with gr.Blocks(theme=gr.themes.Soft(), css=BLOCK_CSS) as demo:
            db = gr.State(None)  # session state holding the Optional[Chroma] store
favicon = f'<img src="{FAVICON_PATH}" width="48px" style="display: inline">'
gr.Markdown(
f"""<h1><center>{favicon} Я, Макар - текстовый ассистент на основе GPT</center></h1>"""
)
with gr.Row(elem_id="model_selector_row"):
models: list = list(DICT_REPO_AND_MODELS.values())
model_selector = gr.Dropdown(
choices=models,
value=models[0] if models else "",
interactive=True,
show_label=False,
container=False,
)
with gr.Row():
with gr.Column(scale=5):
chatbot = gr.Chatbot(label="Диалог", height=400)
with gr.Column(min_width=200, scale=4):
retrieved_docs = gr.Textbox(
label="Извлеченные фрагменты",
placeholder="Появятся после задавания вопросов",
interactive=False
)
with gr.Row():
with gr.Column(scale=20):
msg = gr.Textbox(
label="Отправить сообщение",
show_label=False,
placeholder="Отправить сообщение",
container=False
)
with gr.Column(scale=3, min_width=100):
submit = gr.Button("📤 Отправить", variant="primary")
with gr.Row():
# gr.Button(value="👍 Понравилось")
# gr.Button(value="👎 Не понравилось")
stop = gr.Button(value="⛔ Остановить")
regenerate = gr.Button(value="🔄 Повторить")
clear = gr.Button(value="🗑️ Очистить")
# # Upload files
# file_output.upload(
# fn=self.upload_files,
# inputs=[file_output],
# outputs=[file_paths],
# queue=True,
# ).success(
# fn=self.build_index,
# inputs=[file_paths, db, chunk_size, chunk_overlap],
# outputs=[db, file_warning],
# queue=True
# )
model_selector.change(
fn=self.load_model,
inputs=[model_selector],
outputs=[model_selector]
)
# Pressing Enter
submit_event = msg.submit(
fn=self.user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
queue=False,
).success(
fn=self.retrieve,
inputs=[chatbot, db, retrieved_docs],
outputs=[retrieved_docs],
queue=True,
).success(
fn=self.bot,
inputs=[chatbot, retrieved_docs],
outputs=chatbot,
queue=True,
)
# Pressing the button
submit_click_event = submit.click(
fn=self.user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
queue=False,
).success(
fn=self.retrieve,
inputs=[chatbot, db, retrieved_docs],
outputs=[retrieved_docs],
queue=True,
).success(
fn=self.bot,
inputs=[chatbot, retrieved_docs],
outputs=chatbot,
queue=True,
)
# Stop generation
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
# Regenerate
regenerate.click(
fn=self.regenerate_response,
inputs=[chatbot],
outputs=[msg, chatbot],
queue=False,
).success(
fn=self.retrieve,
inputs=[chatbot, db, retrieved_docs],
outputs=[retrieved_docs],
queue=True,
).success(
fn=self.bot,
inputs=[chatbot, retrieved_docs],
outputs=chatbot,
queue=True,
)
# Clear history
clear.click(lambda: None, None, chatbot, queue=False)
demo.queue(max_size=128, default_concurrency_limit=10, api_open=False)
demo.launch(server_name="0.0.0.0", max_threads=200)
if __name__ == "__main__":
local_chat_gpt = LocalChatGPT()
local_chat_gpt.run()