# Spaces: Running
import json
import logging
import os
import shutil
from typing import List

from txtai.embeddings import Embeddings
# Set up logging: configure the root logger at import time and create the
# module-level logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EmbeddingsManager:
    """Create and query txtai embeddings indexes stored under a base directory.

    Each index lives in ``<base_path>/<index_id>/`` and consists of the saved
    embeddings (``embeddings``) plus the original documents
    (``document_list.json``). One ``Embeddings`` instance is shared by all
    operations: ``create_index`` re-indexes it and ``query_index`` reloads it
    from disk on every call.
    """

    def __init__(self, base_path: str = "./indexes", model_path: str = "avsolatorio/GIST-all-MiniLM-L6-v2"):
        """
        Initializes the EmbeddingsManager.

        Args:
            base_path (str): Base directory to store indices. Created if missing.
            model_path (str): Path or identifier for the embeddings model.
        """
        self.base_path = base_path
        os.makedirs(self.base_path, exist_ok=True)
        self.model_path = model_path
        self.embeddings = Embeddings({"path": self.model_path})
        logger.info(f"Embeddings model loaded from '{self.model_path}'. Base path set to '{self.base_path}'.")

    def create_index(self, index_id: str, documents: List[str]) -> None:
        """
        Creates a new embeddings index with the provided documents.

        On failure, any partially written index directory is removed so the
        same ``index_id`` can be retried.

        Args:
            index_id (str): Unique identifier for the index.
            documents (List[str]): List of documents to be indexed.

        Raises:
            ValueError: If the index already exists.
            Exception: For any other errors during indexing or saving; the
                original error is attached as ``__cause__``.
        """
        # NOTE(review): index_id is joined into a filesystem path unvalidated;
        # confirm callers never pass separators or '..' from untrusted input.
        index_path = os.path.join(self.base_path, index_id)
        if os.path.exists(index_path):
            logger.error(f"Index with index_id '{index_id}' already exists at '{index_path}'.")
            raise ValueError(f"Index with index_id '{index_id}' already exists.")

        embeddings_path = os.path.join(index_path, "embeddings")
        document_list_path = os.path.join(index_path, "document_list.json")
        try:
            # Prepare documents for txtai indexing. The list position is used
            # as the id so query_index can map results back to the text.
            document_tuples = [(i, text, None) for i, text in enumerate(documents)]
            self.embeddings.index(document_tuples)
            logger.info(f"Documents indexed for index_id '{index_id}'.")

            # Create index directory
            os.makedirs(index_path, exist_ok=True)

            # Save embeddings
            self.embeddings.save(embeddings_path)
            logger.info(f"Embeddings saved to '{embeddings_path}'.")

            # Save document list
            with open(document_list_path, "w", encoding="utf-8") as f:
                json.dump(documents, f, ensure_ascii=False, indent=4)
            logger.info(f"Document list saved to '{document_list_path}'.")

            logger.info(f"Index '{index_id}' created and saved successfully.")
        except Exception as e:
            # The directory did not exist on entry, so anything there now is a
            # partial write from this call. Leaving it would make every retry
            # fail with "already exists" and leave a broken index to query.
            shutil.rmtree(index_path, ignore_errors=True)
            logger.error(f"Failed to create index '{index_id}': {e}")
            raise Exception(f"Failed to create index '{index_id}': {e}") from e

    def query_index(self, index_id: str, query: str, num_results: int = 5) -> List[str]:
        """
        Queries an existing embeddings index.

        Args:
            index_id (str): Unique identifier for the index to query.
            query (str): The search query.
            num_results (int): Number of top results to return.

        Returns:
            List[str]: List of top matching documents.

        Raises:
            FileNotFoundError: If the index directory or one of its files
                does not exist.
            Exception: For any other errors during querying; the original
                error is attached as ``__cause__``.
        """
        index_path = os.path.join(self.base_path, index_id)
        if not os.path.exists(index_path):
            logger.error(f"Index '{index_id}' not found at '{index_path}'.")
            raise FileNotFoundError(f"Index '{index_id}' not found.")

        embeddings_path = os.path.join(index_path, "embeddings")
        document_list_path = os.path.join(index_path, "document_list.json")
        try:
            # Load embeddings from the index (replaces whatever index the
            # shared Embeddings instance currently holds).
            self.embeddings.load(embeddings_path)
            logger.info(f"Embeddings loaded from '{embeddings_path}' for index '{index_id}'.")

            # Load document list
            if not os.path.exists(document_list_path):
                logger.error(f"Document list not found at '{document_list_path}'.")
                raise FileNotFoundError(f"Document list not found for index '{index_id}'.")
            with open(document_list_path, "r", encoding="utf-8") as f:
                document_list = json.load(f)
            logger.info(f"Document list loaded from '{document_list_path}'.")

            # Perform the search; the first element of each hit is the
            # document id, i.e. its position in document_list.
            results = self.embeddings.search(query, num_results)
            queried_texts = [document_list[hit[0]] for hit in results]
            logger.info(f"Query executed successfully on index '{index_id}'. Retrieved {len(queried_texts)} results.")
            return queried_texts
        except Exception as e:
            logger.error(f"Failed to query index '{index_id}': {e}")
            if isinstance(e, FileNotFoundError):
                # Keep the documented exception type instead of wrapping it.
                raise
            raise Exception(f"Failed to query index '{index_id}': {e}") from e