"""Helpers for loading uploaded files and building/caching llama_index vector indices."""
import hashlib
import os
import logging

from llama_index import download_loader
from llama_index import (
    Document,
    LLMPredictor,
    PromptHelper,
    QuestionAnswerPrompt,
    RefinePrompt,
)
import colorama
import PyPDF2
from tqdm import tqdm

from modules.presets import *
from modules.utils import *
from modules.config import local_embedding


def get_index_name(file_src):
    """Return an MD5 hex digest identifying the contents of a set of files.

    The digest is used by ``construct_index`` as the cache key for a saved
    index. Files are hashed in order of their base name, so the key does not
    depend on the order in which the files were supplied.

    Args:
        file_src: Iterable of objects exposing a ``.name`` attribute that holds
            a filesystem path (presumably upload objects from the UI — confirm
            with callers).

    Returns:
        Hexadecimal MD5 digest of the concatenated file contents.
    """
    file_paths = sorted((x.name for x in file_src), key=os.path.basename)
    md5_hash = hashlib.md5()
    for file_path in file_paths:
        with open(file_path, "rb") as f:
            # Stream in 8 KiB chunks so large files are never fully in memory.
            while chunk := f.read(8192):
                md5_hash.update(chunk)
    return md5_hash.hexdigest()


def block_split(text, block_size=1000):
    """Split ``text`` into consecutive ``Document`` objects of ``block_size`` chars.

    Args:
        text: The string to split. An empty string yields an empty list.
        block_size: Maximum number of characters per block (default 1000).

    Returns:
        A list of ``Document`` objects; only the last may be shorter than
        ``block_size``.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    return [
        Document(text[start:start + block_size])
        for start in range(0, len(text), block_size)
    ]


def _read_pdf(filepath):
    """Return the text of a PDF, preferring the project's layout-aware parser.

    Falls back to PyPDF2 page-by-page extraction when ``modules.pdf_func`` is
    unavailable or its parser raises.
    """
    try:
        from modules.pdf_func import parse_pdf
        from modules.config import advance_docs

        two_column = advance_docs["pdf"].get("two_column", False)
        return parse_pdf(filepath, two_column).text
    except Exception:
        # Best-effort: the advanced parser is optional, so fall back rather
        # than failing the whole file.
        with open(filepath, "rb") as pdf_file:
            reader = PyPDF2.PdfReader(pdf_file)
            return "".join(page.extract_text() for page in tqdm(reader.pages))


def _read_with_loader(loader_name, filepath):
    """Load ``filepath`` with the named llama_index loader; return the first document's text."""
    reader_cls = download_loader(loader_name)
    return reader_cls().load_data(file=filepath)[0].text


def get_documents(file_src):
    """Load every file in ``file_src`` into a list of llama_index ``Document``s.

    Dispatch is by (case-insensitive) file extension:
    ``.pdf``, ``.docx``, ``.epub`` and ``.xlsx`` get dedicated readers; anything
    else is read as UTF-8 text. Each non-Excel file becomes one ``Document``
    after ``add_space`` post-processing. For ``.xlsx``, each string returned by
    ``excel_to_string`` becomes its own ``Document`` (without ``add_space``).

    Files that fail to load are logged and skipped.

    Args:
        file_src: Iterable of objects exposing a ``.name`` filesystem path.

    Returns:
        The list of loaded ``Document`` objects.
    """
    documents = []
    logging.debug("Loading documents...")
    logging.debug(f"file_src: {file_src}")
    for file in file_src:
        filepath = file.name
        filename = os.path.basename(filepath)
        file_type = os.path.splitext(filepath)[1].lower()
        logging.info(f"loading file: {filename}")
        try:
            if file_type == ".pdf":
                logging.debug("Loading PDF...")
                text_raw = _read_pdf(filepath)
            elif file_type == ".docx":
                logging.debug("Loading Word...")
                text_raw = _read_with_loader("DocxReader", filepath)
            elif file_type == ".epub":
                logging.debug("Loading EPUB...")
                text_raw = _read_with_loader("EpubReader", filepath)
            elif file_type == ".xlsx":
                logging.debug("Loading Excel...")
                documents.extend(Document(elem) for elem in excel_to_string(filepath))
                continue
            else:
                logging.debug("Loading text file...")
                with open(filepath, "r", encoding="utf-8") as f:
                    text_raw = f.read()
        except Exception:
            # Skip this file: falling through would use an undefined text_raw
            # (first file) or silently re-add the previous file's text.
            logging.error(f"Error loading file: {filename}", exc_info=True)
            continue
        documents.append(Document(add_space(text_raw)))
    logging.debug("Documents loaded.")
    return documents


def construct_index(
    api_key,
    file_src,
    max_input_size=4096,
    num_outputs=5,
    max_chunk_overlap=20,
    chunk_size_limit=600,
    embedding_limit=None,
    separator=" ",
):
    """Build (or load from cache) a ``GPTSimpleVectorIndex`` over ``file_src``.

    The index is cached at ``./index/<md5>.json``, where the MD5 comes from
    ``get_index_name``. If that file exists, it is loaded and returned as-is;
    the chunking arguments are not re-checked against the cached index.

    Args:
        api_key: OpenAI API key. If falsy, a placeholder key is put in the
            environment instead.
        file_src: Iterable of objects exposing a ``.name`` filesystem path.
        max_input_size: Passed to ``PromptHelper``.
        num_outputs: Passed to ``PromptHelper`` as ``num_output``.
        max_chunk_overlap: Passed to ``PromptHelper``.
        chunk_size_limit: Passed to ``PromptHelper`` and ``ServiceContext``;
            0 means no limit (``None``).
        embedding_limit: Passed to ``PromptHelper``; 0 means ``None``.
        separator: Passed to ``PromptHelper``; ``""`` is treated as ``" "``.

    Returns:
        The loaded or newly built index, or ``None`` if building failed.
    """
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
    from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding

    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    else:
        # A dependency insists on an API key being set even when it is not
        # used, so provide a placeholder.
        os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
    # Sentinel normalisation: 0 / "" mean "use the library default".
    chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
    embedding_limit = None if embedding_limit == 0 else embedding_limit
    separator = " " if separator == "" else separator

    prompt_helper = PromptHelper(
        max_input_size=max_input_size,
        num_output=num_outputs,
        max_chunk_overlap=max_chunk_overlap,
        embedding_limit=embedding_limit,
        # Previously hard-coded to 600, which ignored the caller's value.
        chunk_size_limit=chunk_size_limit,
        separator=separator,
    )
    index_name = get_index_name(file_src)
    index_path = f"./index/{index_name}.json"
    if os.path.exists(index_path):
        logging.info("找到了缓存的索引文件,加载中……")
        return GPTSimpleVectorIndex.load_from_disk(index_path)

    try:
        documents = get_documents(file_src)
        if local_embedding:
            embed_model = LangchainEmbedding(
                HuggingFaceEmbeddings(
                    model_name="sentence-transformers/distiluse-base-multilingual-cased-v2"
                )
            )
        else:
            embed_model = OpenAIEmbedding()
        logging.info("构建索引中……")
        with retrieve_proxy():
            service_context = ServiceContext.from_defaults(
                prompt_helper=prompt_helper,
                chunk_size_limit=chunk_size_limit,
                embed_model=embed_model,
            )
            index = GPTSimpleVectorIndex.from_documents(
                documents, service_context=service_context
            )
        logging.debug("索引构建完成!")
        os.makedirs("./index", exist_ok=True)
        index.save_to_disk(index_path)
        logging.debug("索引已保存至本地!")
        return index
    except Exception as e:
        # The old call passed `e` with no %s placeholder, which made the
        # logging module emit a formatting error instead of this message.
        logging.exception("索引构建失败!%s", e)
        return None


def add_space(text):
    """Return ``text`` with a space inserted after each listed punctuation mark.

    Each key of ``punctuations`` is replaced by its mapped value (the mark
    followed by a space). This is presumably meant to help the whitespace
    ``separator`` used when chunking split the text — confirm with callers.
    """
    punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
    for cn_punc, en_punc in punctuations.items():
        text = text.replace(cn_punc, en_punc)
    return text