OpenLLM / tools /doc_qa.py
xusenlin's picture
Upload 13 files
20048e3
import os
import shutil
from typing import Optional
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from loguru import logger
from tqdm import tqdm
from .parser import parse_pdf
# Default (Chinese) prompt used by generate_prompt. It contains two
# placeholders, {context} and {question}, which generate_prompt substitutes.
# The text tells the model to answer concisely and professionally using only
# the supplied information, to state that it cannot answer when that
# information is insufficient, not to fabricate content, and to reply in
# Chinese.
PROMPT_TEMPLATE = """已知信息:
{context}
根据上述已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。
问题是:{question}"""
def _get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load one file and split its content into LangChain ``Document`` chunks.

    PDFs are read with ``parse_pdf`` (falling back to PyPDF2 if that fails);
    ``.docx``, ``.pptx`` and ``.epub`` use their dedicated Unstructured
    loaders; ``.md`` and every other extension go through
    ``UnstructuredFileLoader`` in ``"elements"`` mode.

    Args:
        filepath: Path of the file to load.
        chunk_size: Passed to ``RecursiveCharacterTextSplitter`` as
            ``chunk_size``.
        chunk_overlap: Passed to ``RecursiveCharacterTextSplitter`` as
            ``chunk_overlap``.
        two_column: Forwarded to ``parse_pdf`` for PDF files.

    Returns:
        A list of ``Document`` objects. Exceptions raised while loading are
        logged instead of propagated; in that case the result is whatever the
        splitter produces from a single empty document for ``filepath``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    # Compare extensions case-insensitively so e.g. "REPORT.PDF" is still
    # handled by the PDF branch.
    file_type = os.path.splitext(filepath)[1].lower()
    logger.info(f"Loading file: {filepath}")

    # Always keep a *list* of documents: loader.load() already returns a list,
    # and split_documents() expects an iterable of Document objects. The
    # initial value is the fallback used when loading fails.
    texts = [Document(page_content="", metadata={"source": filepath})]
    try:
        if file_type == ".pdf":
            logger.debug("Loading PDF...")
            try:
                pdftext = parse_pdf(filepath, two_column).text
            except Exception:
                # Fall back to plain PyPDF2 extraction. ``except Exception``
                # (not a bare ``except``) lets KeyboardInterrupt/SystemExit
                # propagate.
                from PyPDF2 import PdfReader

                with open(filepath, "rb") as pdf_file:
                    reader = PdfReader(pdf_file)
                    # ``or ""`` guards against pages that yield no text.
                    pdftext = "".join(
                        page.extract_text() or "" for page in tqdm(reader.pages)
                    )
            texts = [Document(page_content=pdftext, metadata={"source": filepath})]
        elif file_type == ".docx":
            from langchain.document_loaders import UnstructuredWordDocumentLoader

            logger.debug("Loading Word...")
            texts = UnstructuredWordDocumentLoader(filepath).load()
        elif file_type == ".pptx":
            from langchain.document_loaders import UnstructuredPowerPointLoader

            logger.debug("Loading PowerPoint...")
            texts = UnstructuredPowerPointLoader(filepath).load()
        elif file_type == ".epub":
            from langchain.document_loaders import UnstructuredEPubLoader

            logger.debug("Loading EPUB...")
            texts = UnstructuredEPubLoader(filepath).load()
        elif file_type == ".md":
            # Markdown elements are returned as loaded, without passing
            # through the text splitter.
            loader = UnstructuredFileLoader(filepath, mode="elements")
            return loader.load()
        else:
            loader = UnstructuredFileLoader(filepath, mode="elements")
            return loader.load_and_split(text_splitter=text_splitter)
    except Exception:
        import traceback

        logger.error(f"Error loading file: {filepath}")
        traceback.print_exc()

    return text_splitter.split_documents(texts)
def get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load and split one file or several files into ``Document`` chunks.

    Args:
        filepath: Either a single path (``str``, ``bytes`` or
            ``os.PathLike``) or an iterable of such paths.
        chunk_size: Forwarded to ``_get_documents``.
        chunk_overlap: Forwarded to ``_get_documents``.
        two_column: Forwarded to ``_get_documents``.

    Returns:
        A flat list with the chunks of every file, in input order.
    """
    logger.debug("Loading documents...")

    # Decide by type instead of os.path.isfile(): a list is not a valid
    # argument for os.path.isfile, and a string that does not name an existing
    # file must not be iterated character by character.
    if isinstance(filepath, (str, bytes, os.PathLike)):
        paths = [filepath]
    else:
        paths = filepath

    documents = []
    for path in paths:
        documents.extend(
            _get_documents(
                path,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                two_column=two_column,
            )
        )

    logger.debug("Documents loaded.")
    return documents
def generate_prompt(related_docs, query: str, prompt_template=PROMPT_TEMPLATE) -> str:
    """Fill the prompt template with the retrieved context and the question.

    Args:
        related_docs: Iterable of sequences whose first item has a
            ``page_content`` attribute (for example ``(document, score)``
            pairs). The contents are joined with newlines to form the context.
        query: The user's question; substituted for ``{question}``.
        prompt_template: Template containing the literal placeholders
            ``{context}`` and ``{question}``.

    Returns:
        The template with both placeholders replaced.
    """
    import re

    context = "\n".join(doc[0].page_content for doc in related_docs)
    substitutions = {"{question}": query, "{context}": context}
    # Substitute both placeholders in a single pass over the template so that
    # placeholder-looking text inside the query or the context is inserted
    # verbatim instead of being expanded by a later replacement. A callable
    # replacement also keeps backslashes in the inserted text untouched.
    return re.sub(
        r"\{question\}|\{context\}",
        lambda match: substitutions[match.group(0)],
        prompt_template,
    )
class DocQAPromptAdapter:
    """Build document-QA prompts from chunks retrieved out of a FAISS store.

    The adapter embeds documents with ``OpenAIEmbeddings``, stores them in a
    FAISS vector store saved on disk, and, when called with a query, returns a
    prompt containing the most similar chunks.
    """

    def __init__(self, chunk_size: Optional[int] = 500, chunk_overlap: Optional[int] = 0, api_key: Optional[str] = "xxx"):
        """Create the adapter.

        Args:
            chunk_size: Chunk size used when splitting documents.
            chunk_overlap: Chunk overlap used when splitting documents.
            api_key: Passed to ``OpenAIEmbeddings`` as ``openai_api_key``.
        """
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_key)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # Populated by create_vector_store() / reset_vector_store().
        self.vector_store = None

    def create_vector_store(self, file_path, vs_path, embeddings=None):
        """Index the given file(s) in a new FAISS store and save it to ``vs_path``.

        Args:
            file_path: Path or paths handed to ``get_documents``.
            vs_path: Location passed to ``save_local``.
            embeddings: Optional embeddings object used instead of the
                adapter's own.
        """
        documents = get_documents(
            file_path,
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        self.vector_store = FAISS.from_documents(documents, embeddings or self.embeddings)
        self.vector_store.save_local(vs_path)

    def reset_vector_store(self, vs_path, embeddings=None):
        """Replace the current store with the one loaded from ``vs_path``.

        Args:
            vs_path: Location passed to ``FAISS.load_local``.
            embeddings: Optional embeddings object used instead of the
                adapter's own.
        """
        self.vector_store = FAISS.load_local(vs_path, embeddings or self.embeddings)

    @staticmethod
    def delete_files(files):
        """Remove every existing path in ``files``; directories are removed recursively.

        Paths that do not exist are skipped.
        """
        for path in files:
            if not os.path.exists(path):
                continue
            if os.path.isfile(path):
                os.remove(path)
            else:
                shutil.rmtree(path)

    def __call__(self, query, vs_path=None, topk=6):
        """Return a prompt for ``query`` built from the ``topk`` most similar chunks.

        Args:
            query: The user's question.
            vs_path: If given and it exists on disk, the vector store is
                reloaded from this location before searching.
            topk: Number of chunks requested from the similarity search.

        Returns:
            The prompt string produced by ``generate_prompt``.

        Raises:
            RuntimeError: If no vector store has been created or loaded.
        """
        if vs_path is not None and os.path.exists(vs_path):
            self.reset_vector_store(vs_path)

        if self.vector_store is None:
            # Fail with an actionable message instead of an AttributeError on
            # None further down.
            raise RuntimeError(
                "No vector store is loaded; call create_vector_store() or "
                "pass an existing vs_path."
            )

        # Queries are always embedded with this adapter's embeddings.
        self.vector_store.embedding_function = self.embeddings.embed_query
        related_docs_with_score = self.vector_store.similarity_search_with_score(query, k=topk)
        return generate_prompt(related_docs_with_score, query)