ChatGLM / chains /modules /vectorstores.py
y001j's picture
Upload 305 files
af9251e
raw
history blame contribute delete
No virus
4.77 kB
from langchain.vectorstores import FAISS
from typing import Any, Callable, List, Optional, Tuple, Dict
from langchain.docstore.document import Document
from langchain.docstore.base import Docstore
from langchain.vectorstores.utils import maximal_marginal_relevance
from langchain.embeddings.base import Embeddings
import uuid
from langchain.docstore.in_memory import InMemoryDocstore
import numpy as np
def dependable_faiss_import() -> Any:
    """Import and return the ``faiss`` module.

    Returns:
        The imported ``faiss`` module.

    Raises:
        ValueError: If ``faiss`` cannot be imported. The underlying
            ``ImportError`` is chained as ``__cause__`` so the real reason
            (missing package, broken native library, ...) stays visible.
    """
    try:
        import faiss
    except ImportError as err:
        # ValueError (not ImportError) is kept for compatibility with callers
        # that already catch it; the original error is chained explicitly.
        raise ValueError(
            "Could not import faiss python package. "
            "Please install it with `pip install faiss` "
            "or `pip install faiss-cpu` (depending on Python version)."
        ) from err
    return faiss
class FAISSVS(FAISS):
    """FAISS vector store whose MMR searches return ``(Document, score)`` pairs.

    * ``max_marginal_relevance_search`` and
      ``max_marginal_relevance_search_by_vector`` return each selected
      ``Document`` together with the score the FAISS index reported for it.
    * ``__from`` builds an inner-product flat index (``faiss.IndexFlatIP``).
    """

    def __init__(
        self,
        embedding_function: Callable[..., Any],
        index: Any,
        docstore: Docstore,
        index_to_docstore_id: Dict[int, str],
    ):
        """Initialise the store; all arguments are passed through to ``FAISS``.

        Args:
            embedding_function: Callable that turns a query string into an
                embedding vector (called by ``max_marginal_relevance_search``).
            index: FAISS index holding the document vectors.
            docstore: Docstore that resolves docstore ids to ``Document``s.
            index_to_docstore_id: Mapping from FAISS row position to docstore id.
        """
        super().__init__(embedding_function, index, docstore, index_to_docstore_id)

    def max_marginal_relevance_search_by_vector(
        self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs selected using maximal marginal relevance, with scores.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch from the index and pass to
                the MMR algorithm. Defaults to 20.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Up to ``k`` ``(Document, score)`` pairs in MMR selection order,
            where ``score`` is the value ``self.index.search`` returned for
            that row. Fewer pairs are returned when the index yields fewer
            than ``k`` valid candidates.

        Raises:
            ValueError: If a selected row's docstore id does not resolve to a
                ``Document``.
        """
        query = np.array([embedding], dtype=np.float32)
        scores, indices = self.index.search(query, fetch_k)
        # -1 happens when not enough docs are returned. Drop those rows while
        # keeping each row id paired with its own score, so positions chosen
        # by MMR map back to the right row wherever the padding appears.
        candidates = [
            (int(idx), score)
            for idx, score in zip(indices[0], scores[0])
            if idx != -1
        ]
        if not candidates:
            return []
        embeddings = [self.index.reconstruct(idx) for idx, _ in candidates]
        # Never request more picks than there are candidates: an MMR
        # implementation may pad its result with -1 in that case, and -1
        # would silently select the last candidate again.
        mmr_selected = maximal_marginal_relevance(
            query, embeddings, k=min(k, len(candidates))
        )
        docs = []
        for pos in mmr_selected:
            if not 0 <= pos < len(candidates):
                # Defensive: skip any sentinel or out-of-range position.
                continue
            idx, score = candidates[pos]
            _id = self.index_to_docstore_id[idx]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            docs.append((doc, score))
        return docs

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs selected using maximal marginal relevance, with scores.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            query: Text to look up documents similar to; it is embedded with
                ``self.embedding_function``.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
                Defaults to 20.
            **kwargs: Forwarded to ``max_marginal_relevance_search_by_vector``.

        Returns:
            List of ``(Document, score)`` pairs selected by maximal marginal
            relevance.
        """
        embedding = self.embedding_function(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding, k, fetch_k, **kwargs
        )

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> FAISS:
        """Build a store over ``texts`` backed by an inner-product flat index.

        Args:
            texts: Raw texts; one ``Document`` is created per text.
            embeddings: Precomputed vectors, one per text, all the same length.
            embedding: Embeddings model whose ``embed_query`` becomes the
                store's query embedding function.
            metadatas: Optional per-text metadata dicts, aligned with ``texts``;
                an empty dict is used for every document when omitted.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A new ``cls`` instance; each document gets a fresh uuid4 docstore id.

        Raises:
            ValueError: If faiss cannot be imported, or ``embeddings`` is empty
                (the index dimension cannot be inferred).
        """
        # NOTE(review): Python name-mangles this method to ``_FAISSVS__from``.
        # If the base class's ``from_texts`` calls its own private ``__from``
        # (mangled to ``_FAISS__from``), this override is not reached through
        # that path -- verify against the installed langchain version before
        # relying on the inner-product index.
        faiss = dependable_faiss_import()
        if len(embeddings) == 0:
            raise ValueError(
                "Cannot build a FAISS index from an empty list of embeddings."
            )
        vectors = np.array(embeddings, dtype=np.float32)
        index = faiss.IndexFlatIP(vectors.shape[1])
        index.add(vectors)
        # Experimental alternative kept for reference (an IVF index, which
        # must be trained before vectors are added):
        # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
        # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
        # index.train(np.array(embeddings, dtype=np.float32))
        # index.add(np.array(embeddings, dtype=np.float32))
        documents = [
            Document(page_content=text, metadata=metadatas[i] if metadatas else {})
            for i, text in enumerate(texts)
        ]
        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
        docstore = InMemoryDocstore(
            {index_to_id[i]: doc for i, doc in enumerate(documents)}
        )
        return cls(embedding.embed_query, index, docstore, index_to_id)