obsidian-qa-bot / app.py
anpigon's picture
Update document loader and add platform-specific model configuration
84af0cb
raw
history blame
7.55 kB
import os
import gradio as gr
from langchain_community.document_loaders import ObsidianLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.prompts import PromptTemplate
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.runnables import ConfigurableField
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_community.llms import HuggingFaceHub
from langchain_google_genai import GoogleGenerativeAI
import platform
directories = ["./docs/obsidian-help", "./docs/obsidian-developer"]
# 1. 문서 로더를 사용하여 모든 .md 파일을 로드합니다.
md_documents = []
for directory in directories:
try:
loader = ObsidianLoader(directory, encoding="utf-8")
md_documents.extend(loader.load())
except Exception:
pass
# 2. 청크 분할기를 생성합니다.
# 청크 크기는 2000, 청크간 겹치는 부분은 200 문자로 설정합니다.
md_splitter = RecursiveCharacterTextSplitter.from_language(
language=Language.MARKDOWN,
chunk_size=2000,
chunk_overlap=200,
)
md_docs = md_splitter.split_documents(md_documents)
# 3. 임베딩 모델을 사용하여 문서의 임베딩을 계산합니다.
# 허깅페이스 임베딩 모델 인스턴스를 생성합니다. 모델명으로 "BAAI/bge-m3 "을 사용합니다.
if platform.system() == "Darwin":
model_kwargs = {"device": "mps"}
else:
model_kwargs = {"device": "cpu"}
model_name = "BAAI/bge-m3"
encode_kwargs = {"normalize_embeddings": False}
embeddings = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
)
# CacheBackedEmbeddings를 사용하여 임베딩 계산 결과를 캐시합니다.
store = LocalFileStore("./.cache/")
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
embeddings,
store,
namespace=embeddings.model_name,
)
# 4. FAISS 벡터 데이터베이스 인덱스를 생성하고 저장합니다.
FAISS_DB_INDEX = "db_index"
if os.path.exists(FAISS_DB_INDEX):
# 저장된 데이터베이스 인덱스가 이미 존재하는 경우, 해당 인덱스를 로드합니다.
db = FAISS.load_local(
FAISS_DB_INDEX, # 로드할 FAISS 인덱스의 디렉토리 이름
cached_embeddings, # 임베딩 정보를 제공
allow_dangerous_deserialization=True, # 역직렬화를 허용하는 옵션
)
else:
# combined_documents 문서들과 cached_embeddings 임베딩을 사용하여
# FAISS 데이터베이스 인스턴스를 생성합니다.
db = FAISS.from_documents(md_docs, cached_embeddings)
# 생성된 데이터베이스 인스턴스를 지정한 폴더에 로컬로 저장합니다.
db.save_local(folder_path=FAISS_DB_INDEX)
# 5. Retrieval를 생성합니다.
faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
# 문서 컬렉션을 사용하여 BM25 검색 모델 인스턴스를 생성합니다.
bm25_retriever = BM25Retriever.from_documents(md_docs) # 초기화에 사용할 문서 컬렉션
bm25_retriever.k = 10 # 검색 시 최대 10개의 결과를 반환하도록 합니다.
# EnsembleRetriever 인스턴스를 생성합니다.
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, faiss_retriever], # 사용할 검색 모델의 리스트
weights=[0.6, 0.4], # 각 검색 모델의 결과에 적용할 가중치
search_type="mmr", # 검색 결과의 다양성을 증진시키는 MMR 방식을 사용
)
# 6. CohereRerank 모델을 사용하여 재정렬을 수행합니다.
compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=ensemble_retriever,
)
# 7. Prompt를 생성합니다.
prompt = PromptTemplate.from_template(
"""당신은 20년 경력의 옵시디언 노트앱 및 플러그인 개발 전문가로, 옵시디언 노트앱 사용법, 플러그인 및 테마 개발에 대한 깊은 지식을 가지고 있습니다. 당신의 주된 임무는 제공된 문서를 바탕으로 질문에 최대한 정확하고 상세하게 답변하는 것입니다.
문서에는 옵시디언 노트앱의 기본 사용법, 고급 기능, 플러그인 개발 방법, 테마 개발 가이드 등 옵시디언 노트앱을 깊이 있게 사용하고 확장하는 데 필요한 정보가 포함되어 있습니다.
귀하의 답변은 다음 지침에 따라야 합니다:
1. 모든 답변은 명확하고 이해하기 쉬운 한국어로 제공되어야 합니다.
2. 답변은 문서의 내용을 기반으로 해야 하며, 가능한 한 구체적인 정보를 포함해야 합니다.
3. 문서 내에서 직접적인 답변을 찾을 수 없는 경우, "문서에는 해당 질문에 대한 구체적인 답변이 없습니다."라고 명시해 주세요.
4. 가능한 경우, 답변과 관련된 문서의 구체적인 부분(예: 섹션 이름, 페이지 번호 등)을 출처로서 명시해 주세요.
5. 질문에 대한 답변이 문서에 부분적으로만 포함되어 있는 경우, 가능한 한 많은 정보를 종합하여 답변해 주세요. 또한, 추가적인 연구나 참고자료가 필요할 수 있음을 언급해 주세요.
#참고문서:
\"\"\"
{context}
\"\"\"
#질문:
{question}
#답변:
출처:
- source1
- source2
- ...
"""
)
# 7. chain를 생성합니다.
llm = ChatGroq(
model_name="llama3-70b-8192",
temperature=0,
).configurable_alternatives(
ConfigurableField(id="llm"),
default_key="llama3",
gemini=GoogleGenerativeAI(
model="gemini-pro",
temperature=0,
),
)
def format_docs(docs):
formatted_docs = []
for doc in docs:
formatted_doc = f"Page Content:\n{doc.page_content}\n"
if doc.metadata.get("source"):
formatted_doc += f"Source: {doc.metadata['source']}\n"
formatted_docs.append(formatted_doc)
return "\n---\n".join(formatted_docs)
rag_chain = (
{"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
# # 8. chain를 실행합니다.
def predict(message, history=None):
answer = rag_chain.invoke(message)
return answer
gr.ChatInterface(
predict,
title="옵시디언 노트앱 및 플러그인 개발에 대해서 물어보세요!",
description="안녕하세요!\n저는 옵시디언 노트앱과 플러그인 개발에 대한 인공지능 QA봇입니다. 옵시디언 노트앱의 사용법, 고급 기능, 플러그인 및 테마 개발에 대해 깊은 지식을 가지고 있어요. 문서 작업, 정보 정리 또는 개발에 관한 도움이 필요하시면 언제든지 질문해주세요!",
).launch()