import os
import re
import shutil
from typing import Optional

from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from loguru import logger
from tqdm import tqdm

from .parser import parse_pdf
# Retrieval-QA prompt (written in Chinese). It gives the model the retrieved
# "known information", then asks it to:
#   - answer the question concisely and professionally from that information;
#   - reply that the question cannot be answered, or that not enough relevant
#     information was provided, when no answer can be derived;
#   - not add fabricated content;
#   - answer in Chinese.
# The {context} and {question} placeholders are filled by generate_prompt.
PROMPT_TEMPLATE = """已知信息:
{context}
根据上述已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。
问题是:{question}"""
def _read_pdf_text(filepath, two_column=False):
    """Return the full text of a PDF, falling back to PyPDF2 if ``parse_pdf`` fails."""
    try:
        return parse_pdf(filepath, two_column).text
    except Exception:
        logger.warning(f"parse_pdf failed for {filepath}; falling back to PyPDF2")
        from PyPDF2 import PdfReader

        with open(filepath, "rb") as pdf_file:
            reader = PdfReader(pdf_file)
            # `or ""` guards against a page whose extract_text() yields None
            # (e.g. pages without a text layer, depending on the PyPDF2 version).
            return "".join(page.extract_text() or "" for page in tqdm(reader.pages))


def _get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load a single file and split it into chunked LangChain ``Document`` objects.

    The loader is chosen from the file extension (case-insensitive):
    ``.pdf`` uses ``parse_pdf`` with a PyPDF2 fallback; ``.docx``, ``.pptx`` and
    ``.epub`` use the matching Unstructured loaders; ``.md`` and every other
    extension use ``UnstructuredFileLoader`` in ``"elements"`` mode.

    Args:
        filepath: Path of the file to load.
        chunk_size: Maximum chunk size for ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks.
        two_column: Forwarded to ``parse_pdf`` for PDF files.

    Returns:
        A list of ``Document`` chunks. ``.md`` files are returned as loaded,
        without splitting. If loading fails, the error is logged with its
        traceback and an empty list is returned.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    # Lower-cased so that e.g. "report.PDF" is routed like "report.pdf".
    file_type = os.path.splitext(filepath)[1].lower()
    logger.info(f"Loading file: {filepath}")
    try:
        if file_type == ".pdf":
            logger.debug("Loading PDF...")
            documents = [
                Document(
                    page_content=_read_pdf_text(filepath, two_column),
                    metadata={"source": filepath},
                )
            ]
        elif file_type == ".docx":
            from langchain.document_loaders import UnstructuredWordDocumentLoader

            logger.debug("Loading Word...")
            documents = UnstructuredWordDocumentLoader(filepath).load()
        elif file_type == ".pptx":
            from langchain.document_loaders import UnstructuredPowerPointLoader

            logger.debug("Loading PowerPoint...")
            documents = UnstructuredPowerPointLoader(filepath).load()
        elif file_type == ".epub":
            from langchain.document_loaders import UnstructuredEPubLoader

            logger.debug("Loading EPUB...")
            documents = UnstructuredEPubLoader(filepath).load()
        elif file_type == ".md":
            # NOTE(review): markdown elements are returned unsplit, unlike the
            # other types — confirm this is intentional.
            return UnstructuredFileLoader(filepath, mode="elements").load()
        else:
            loader = UnstructuredFileLoader(filepath, mode="elements")
            return loader.load_and_split(text_splitter=text_splitter)
    except Exception:
        # Best-effort: one unreadable file should not abort a multi-file load.
        logger.exception(f"Error loading file: {filepath}")
        return []
    # `documents` is already a flat list of Document objects, which is what
    # split_documents expects (loader.load() returns a list).
    return text_splitter.split_documents(documents)
def get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load and chunk one file, every file under a directory, or several files.

    Args:
        filepath: Either a single path (``str``, ``bytes`` or ``os.PathLike``)
            to a file or directory, or an iterable of file paths. Directories
            are walked recursively and their files are processed in sorted
            order.
        chunk_size: Maximum chunk size forwarded to ``_get_documents``.
        chunk_overlap: Chunk overlap forwarded to ``_get_documents``.
        two_column: Forwarded to ``_get_documents`` (used for PDFs).

    Returns:
        The concatenated chunk lists returned by ``_get_documents`` for each file.
    """
    logger.debug("Loading documents...")
    # A single path must be recognised before the generic-iterable branch:
    # iterating a string would yield its characters, not paths.
    if isinstance(filepath, (str, bytes, os.PathLike)):
        if os.path.isdir(filepath):
            paths = sorted(
                os.path.join(root, name)
                for root, _dirs, names in os.walk(filepath)
                for name in names
            )
        else:
            # Missing files are still passed through; _get_documents logs the error.
            paths = [filepath]
    else:
        paths = filepath

    documents = []
    for path in paths:
        documents.extend(
            _get_documents(
                path,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                two_column=two_column,
            )
        )
    logger.debug("Documents loaded.")
    return documents
def generate_prompt(related_docs, query: str, prompt_template=PROMPT_TEMPLATE) -> str:
    """Fill a prompt template with retrieved context and the user's question.

    Args:
        related_docs: Sequence of retrieval results whose first element is a
            document with ``page_content`` (e.g. the ``(document, score)`` pairs
            from ``similarity_search_with_score``). Their contents are joined
            with newlines to form the context.
        query: The user question, substituted for ``{question}``.
        prompt_template: Template containing ``{context}`` and/or
            ``{question}`` placeholders. Any other braces are left untouched.

    Returns:
        The filled-in prompt.
    """
    context = "\n".join(item[0].page_content for item in related_docs)
    values = {"context": context, "question": query}
    # Single pass: text that was just substituted is never re-scanned, so a
    # literal "{context}" typed in the query is not expanded. The callable
    # replacement also keeps backslashes in the inserted text literal.
    return re.sub(
        r"\{(context|question)\}",
        lambda match: values[match.group(1)],
        prompt_template,
    )
class DocQAPromptAdapter:
    """Build retrieval-augmented QA prompts backed by a FAISS vector store.

    Documents are chunked with ``get_documents``, embedded (by default with
    ``OpenAIEmbeddings``) and indexed with FAISS. Calling the instance
    retrieves the ``topk`` most similar chunks for a query and formats them
    with ``generate_prompt``.
    """

    def __init__(
        self,
        chunk_size: Optional[int] = 500,
        chunk_overlap: Optional[int] = 0,
        api_key: Optional[str] = "xxx",
    ):
        """
        Args:
            chunk_size: Chunk size passed to ``get_documents``.
            chunk_overlap: Chunk overlap passed to ``get_documents``.
            api_key: OpenAI API key for the default embeddings.
                NOTE(review): ``"xxx"`` looks like a placeholder — callers
                presumably must supply a real key or custom embeddings.
        """
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_key)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.vector_store = None
        # Embeddings the current vector_store was built or loaded with; queries
        # must be embedded with the same model as the index.
        self._store_embeddings = None

    def create_vector_store(self, file_path, vs_path, embeddings=None):
        """Index the documents at ``file_path`` and save the index to ``vs_path``.

        Args:
            file_path: Passed to ``get_documents``.
            vs_path: Location passed to ``FAISS.save_local``.
            embeddings: Optional embeddings to use instead of ``self.embeddings``.

        Raises:
            ValueError: If no documents could be loaded from ``file_path``.
        """
        documents = get_documents(
            file_path, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        if not documents:
            # FAISS cannot build an index from zero vectors; fail with a clear message.
            raise ValueError(f"No documents could be loaded from {file_path!r}")
        store_embeddings = embeddings or self.embeddings
        self.vector_store = FAISS.from_documents(documents, store_embeddings)
        self._store_embeddings = store_embeddings
        self.vector_store.save_local(vs_path)

    def reset_vector_store(self, vs_path, embeddings=None):
        """Load a saved FAISS index from ``vs_path``, replacing the current one.

        Args:
            vs_path: Location passed to ``FAISS.load_local``.
            embeddings: Optional embeddings to use instead of ``self.embeddings``.
        """
        store_embeddings = embeddings or self.embeddings
        self.vector_store = FAISS.load_local(vs_path, store_embeddings)
        self._store_embeddings = store_embeddings

    @staticmethod
    def delete_files(files):
        """Delete each existing path in ``files``; directories are removed recursively.

        Paths that do not exist are skipped.
        """
        for file in files:
            if os.path.exists(file):
                if os.path.isfile(file):
                    os.remove(file)
                else:
                    shutil.rmtree(file)

    def __call__(self, query, vs_path=None, topk=6):
        """Retrieve the ``topk`` chunks most similar to ``query`` and build a prompt.

        Args:
            query: The user question.
            vs_path: If given and it exists, the index is (re)loaded from this
                location with ``self.embeddings`` before searching.
            topk: Number of chunks to retrieve.

        Returns:
            The prompt produced by ``generate_prompt``.

        Raises:
            RuntimeError: If no vector store has been created or loaded.
        """
        if vs_path is not None:
            if os.path.exists(vs_path):
                self.reset_vector_store(vs_path)
            else:
                logger.warning(
                    f"Vector store path does not exist: {vs_path}; "
                    "using the currently loaded store"
                )
        if self.vector_store is None:
            raise RuntimeError(
                "No vector store loaded: call create_vector_store() or "
                "reset_vector_store() first, or pass an existing vs_path"
            )
        # Query with the same embeddings the index was built/loaded with.
        embeddings = getattr(self, "_store_embeddings", None) or self.embeddings
        # NOTE(review): overrides the store's embedding function with a plain
        # callable; presumably for compatibility with the installed LangChain
        # version — confirm.
        self.vector_store.embedding_function = embeddings.embed_query
        related_docs_with_score = self.vector_store.similarity_search_with_score(
            query, k=topk
        )
        return generate_prompt(related_docs_with_score, query)