# embeddings.py — EmbeddingsManager: create and query txtai embeddings
# indexes persisted under a base directory (one subdirectory per index).
# Standard-library imports.
import json
import logging
import os
from typing import List

# Third-party: txtai supplies the embeddings index implementation.
from txtai.embeddings import Embeddings

# Module-level logger, configured once at INFO verbosity.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EmbeddingsManager:
    """Create and query txtai embeddings indexes stored on disk.

    Each index lives in its own subdirectory of ``base_path`` and consists of
    a saved txtai embeddings index plus a ``document_list.json`` holding the
    original texts (txtai search results only carry document ids and scores,
    so the raw documents are kept alongside for retrieval).

    NOTE(review): a single shared ``Embeddings`` instance is reused —
    ``create_index`` indexes into it and ``query_index`` loads saved state
    over it — so one manager instance is not safe for concurrent use across
    different indexes.
    """

    def __init__(self, base_path: str = "./indexes", model_path: str = "avsolatorio/GIST-all-MiniLM-L6-v2"):
        """
        Initializes the EmbeddingsManager.

        Args:
            base_path (str): Base directory to store indices.
            model_path (str): Path or identifier for the embeddings model.
        """
        self.base_path = base_path
        os.makedirs(self.base_path, exist_ok=True)
        self.model_path = model_path
        # Loading the model here (eagerly) means construction may be slow,
        # but every subsequent index/query call reuses the same model.
        self.embeddings = Embeddings({"path": self.model_path})
        logger.info("Embeddings model loaded from '%s'. Base path set to '%s'.", self.model_path, self.base_path)

    def create_index(self, index_id: str, documents: List[str]) -> None:
        """
        Creates a new embeddings index with the provided documents.

        Args:
            index_id (str): Unique identifier for the index.
            documents (List[str]): List of documents to be indexed.

        Raises:
            ValueError: If the index already exists.
            Exception: For any other errors during indexing or saving.
        """
        index_path = os.path.join(self.base_path, index_id)
        if os.path.exists(index_path):
            logger.error("Index with index_id '%s' already exists at '%s'.", index_id, index_path)
            raise ValueError(f"Index with index_id '{index_id}' already exists.")
        try:
            # txtai expects (id, text, tags) tuples; positional ids map back
            # into the saved document list at query time.
            document_tuples = [(i, text, None) for i, text in enumerate(documents)]
            self.embeddings.index(document_tuples)
            logger.info("Documents indexed for index_id '%s'.", index_id)
            # Create index directory
            os.makedirs(index_path, exist_ok=True)
            # Save embeddings
            embeddings_path = os.path.join(index_path, "embeddings")
            self.embeddings.save(embeddings_path)
            logger.info("Embeddings saved to '%s'.", embeddings_path)
            # Save document list so query_index can resolve ids back to text.
            document_list_path = os.path.join(index_path, "document_list.json")
            with open(document_list_path, "w", encoding='utf-8') as f:
                json.dump(documents, f, ensure_ascii=False, indent=4)
            logger.info("Document list saved to '%s'.", document_list_path)
            logger.info("Index '%s' created and saved successfully.", index_id)
        except Exception as e:
            # logger.exception records the traceback; `from e` preserves the
            # causal chain for callers.
            logger.exception("Failed to create index '%s'", index_id)
            raise Exception(f"Failed to create index '{index_id}': {e}") from e

    def query_index(self, index_id: str, query: str, num_results: int = 5) -> List[str]:
        """
        Queries an existing embeddings index.

        Args:
            index_id (str): Unique identifier for the index to query.
            query (str): The search query.
            num_results (int): Number of top results to return.

        Returns:
            List[str]: List of top matching documents.

        Raises:
            FileNotFoundError: If the index does not exist.
            Exception: For any other errors during querying.
        """
        index_path = os.path.join(self.base_path, index_id)
        if not os.path.exists(index_path):
            logger.error("Index '%s' not found at '%s'.", index_id, index_path)
            raise FileNotFoundError(f"Index '{index_id}' not found.")
        try:
            # Load embeddings from the index (replaces any previously
            # loaded/built index on this shared instance).
            embeddings_path = os.path.join(index_path, "embeddings")
            self.embeddings.load(embeddings_path)
            logger.info("Embeddings loaded from '%s' for index '%s'.", embeddings_path, index_id)
            # Load document list
            document_list_path = os.path.join(index_path, "document_list.json")
            if not os.path.exists(document_list_path):
                logger.error("Document list not found at '%s'.", document_list_path)
                raise FileNotFoundError(f"Document list not found for index '{index_id}'.")
            with open(document_list_path, "r", encoding='utf-8') as f:
                document_list = json.load(f)
            logger.info("Document list loaded from '%s'.", document_list_path)
            # Perform the search; txtai returns (id, score) pairs ordered by
            # relevance, so result[0] indexes into the saved document list.
            results = self.embeddings.search(query, num_results)
            queried_texts = [document_list[idx[0]] for idx in results]
            logger.info("Query executed successfully on index '%s'. Retrieved %d results.", index_id, len(queried_texts))
            return queried_texts
        except FileNotFoundError:
            # Bug fix: let the documented FileNotFoundError (missing document
            # list) propagate instead of being re-wrapped by the broad
            # handler below, which contradicted the docstring.
            raise
        except Exception as e:
            logger.exception("Failed to query index '%s'", index_id)
            raise Exception(f"Failed to query index '{index_id}': {e}") from e