# Retrieval utilities built on LangChain embeddings and vector stores.
import json
import os
from typing import Dict, Iterable, List, Union

from langchain.document_loaders import (PyPDFLoader, TextLoader,
                                        UnstructuredFileLoader)
from langchain.embeddings import ModelScopeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS, VectorStore
class Retrieval:
    """Similarity-search retrieval over an embedding-backed vector store.

    Builds a vector store from raw texts or ``Document`` objects via
    :meth:`construct`, then returns the ``top_k`` most similar entries
    for a query via :meth:`retrieve`.
    """

    def __init__(self,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = None):
        """
        Args:
            embedding: Embedding model; defaults to the ModelScope GTE
                Chinese base sentence embedding.
            vs_cls: Vector-store class; defaults to FAISS.
            top_k: Number of results returned by ``retrieve``.
            vs_params: Extra keyword arguments forwarded to the vector-store
                factory (``from_texts`` / ``from_documents``).
        """
        self.embedding = embedding or ModelScopeEmbeddings(
            model_id='damo/nlp_gte_sentence-embedding_chinese-base')
        self.top_k = top_k
        self.vs_cls = vs_cls or FAISS
        # None sentinel instead of a mutable `{}` default, which would be
        # shared across every instance of the class.
        self.vs_params = {} if vs_params is None else vs_params
        # Populated by construct(); retrieve() must not be called before then.
        self.vs = None

    def construct(self, docs):
        """Build the vector store from a non-empty list of str or Document.

        Raises:
            ValueError: if ``docs`` is empty (a plain ``assert`` would be
                stripped under ``python -O``).
        """
        if not docs:
            raise ValueError('docs must not be empty')
        if isinstance(docs[0], str):
            self.vs = self.vs_cls.from_texts(docs, self.embedding,
                                             **self.vs_params)
        elif isinstance(docs[0], Document):
            self.vs = self.vs_cls.from_documents(docs, self.embedding,
                                                 **self.vs_params)

    def retrieve(self, query: str) -> List[str]:
        """Return the page contents of the ``top_k`` most similar documents."""
        res = self.vs.similarity_search(query, k=self.top_k)
        # Restore original page order for paginated sources (e.g. PDFs).
        # Guard on `res` first: the original indexed res[0] and raised
        # IndexError whenever the search came back empty.
        if res and 'page' in res[0].metadata:
            res.sort(key=lambda doc: doc.metadata['page'])
        return [r.page_content for r in res]
class ToolRetrieval(Retrieval):
    """Retrieval over tool descriptions serialized as JSON strings.

    Each stored text is expected to be a JSON object containing a ``name``
    field; ``retrieve`` maps tool names to their raw JSON descriptions.
    """

    def __init__(self,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = None):
        """Same parameters as :class:`Retrieval`."""
        # Normalize the None sentinel here (rather than relying on the
        # parent) so a mutable `{}` default is never shared across instances.
        super().__init__(embedding, vs_cls, top_k,
                         {} if vs_params is None else vs_params)

    def retrieve(self, query: str) -> Dict[str, str]:
        """Return a mapping of tool name -> raw JSON content for top_k hits."""
        res = self.vs.similarity_search(query, k=self.top_k)
        final_res = {}
        for r in res:
            content = r.page_content
            # The stored text is a JSON object; its 'name' field keys the result.
            name = json.loads(content)['name']
            final_res[name] = content
        return final_res
class KnowledgeRetrieval(Retrieval):
    """Retrieval over a knowledge base built from local txt/md/pdf files."""

    def __init__(self,
                 docs,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = None):
        """
        Args:
            docs: Non-empty list of str or Document used to build the store.
            embedding, vs_cls, top_k, vs_params: see :class:`Retrieval`.
        """
        # None sentinel avoids the shared mutable `{}` default.
        super().__init__(embedding, vs_cls, top_k,
                         {} if vs_params is None else vs_params)
        self.construct(docs)

    # BUGFIX: the decorator was missing — without it, calling
    # KnowledgeRetrieval.from_file(path) would bind `path` to `cls`.
    @classmethod
    def from_file(cls,
                  file_path: Union[str, list],
                  embedding: Embeddings = None,
                  vs_cls: VectorStore = None,
                  top_k: int = 5,
                  vs_params: Dict = None):
        """Build a KnowledgeRetrieval from a file, a directory, or a path list.

        Supported extensions: ``.txt`` and ``.pdf`` (split with
        CharacterTextSplitter) and ``.md`` (loaded as unstructured
        elements). Unsupported files are skipped with a printed notice.

        Returns:
            A KnowledgeRetrieval instance, or None when no supported
            document could be loaded.

        Raises:
            ValueError: if ``file_path`` is neither an existing file,
                an existing directory, nor a list of paths.
        """
        textsplitter = CharacterTextSplitter()
        all_files = []
        if isinstance(file_path, str) and os.path.isfile(file_path):
            all_files.append(file_path)
        elif isinstance(file_path, list):
            all_files = file_path
        # Check the type explicitly so a non-str, non-list argument reaches
        # the ValueError below instead of a TypeError from os.path.isdir.
        elif isinstance(file_path, str) and os.path.isdir(file_path):
            for root, _, files in os.walk(file_path):
                for name in files:
                    all_files.append(os.path.join(root, name))
        else:
            raise ValueError('file_path must be a file or a directory')
        docs = []
        for f in all_files:
            lower = f.lower()
            if lower.endswith('.txt'):
                loader = TextLoader(f, autodetect_encoding=True)
                docs += loader.load_and_split(textsplitter)
            elif lower.endswith('.md'):
                loader = UnstructuredFileLoader(f, mode='elements')
                docs += loader.load()
            elif lower.endswith('.pdf'):
                loader = PyPDFLoader(f)
                docs += loader.load_and_split(textsplitter)
            else:
                # Best-effort: skip unknown types rather than failing the batch.
                print(f'not support file type: {f}, will be support soon')
        if len(docs) == 0:
            return None
        return cls(docs, embedding, vs_cls, top_k, vs_params)