Spaces:
Sleeping
Sleeping
import numpy as np | |
import pandas as pd | |
from typing import List, Dict | |
import faiss | |
import logging | |
logger = logging.getLogger(__name__)


class FAISSManager:
    """Exact (flat) L2 FAISS index of course embeddings plus parallel metadata.

    Row ``i`` of ``self.index`` corresponds to ``self.metadata[i]``. Every
    mutation must keep the two in step; otherwise search results point at the
    wrong course, or past the end of ``self.metadata``.
    """

    def __init__(self, dimension: int = 384):
        """Create an empty ``faiss.IndexFlatL2`` for vectors of ``dimension`` floats.

        Args:
            dimension: Length of every embedding that will be added or queried.

        Raises:
            RuntimeError: If FAISS fails to construct the index.
        """
        try:
            self.dimension = dimension
            self.index = faiss.IndexFlatL2(dimension)
            self.metadata: List[Dict] = []
            logger.info("FAISS index initialized successfully")
        except Exception as e:
            logger.exception("Error initializing FAISS: %s", e)
            raise RuntimeError(f"Failed to initialize FAISS: {e}") from e

    def upsert_courses(self, df: pd.DataFrame) -> None:
        """Append course embeddings and their metadata to the index.

        Despite the name, rows are always appended; existing entries are never
        matched or replaced.

        Args:
            df: Frame with an ``embeddings`` column (one vector of
                ``self.dimension`` numbers per row, as an array or list) and
                ``title``, ``description`` and ``url`` columns. An empty frame
                is a no-op.

        Raises:
            RuntimeError: If a required column is missing, a vector has the
                wrong length, or FAISS rejects the data. In that case neither
                the index nor the metadata is modified.
        """
        try:
            # Build both payloads before touching shared state, so a bad frame
            # cannot leave the index and the metadata list out of sync.
            records = df[['title', 'description', 'url']].to_dict('records')
            embeddings = df['embeddings']
            if not records:
                # np.vstack cannot stack an empty list; there is nothing to add.
                logger.info("Added %d vectors to FAISS index", 0)
                return
            vectors = np.vstack([
                np.asarray(emb, dtype='float32') for emb in embeddings
            ])
            if vectors.shape[1] != self.dimension:
                raise ValueError(
                    f"expected embeddings of dimension {self.dimension}, "
                    f"got {vectors.shape[1]}"
                )
            self.index.add(vectors)
            self.metadata.extend(records)
            logger.info("Added %d vectors to FAISS index", len(vectors))
        except Exception as e:
            logger.exception("Error adding vectors to FAISS: %s", e)
            raise RuntimeError(f"Failed to add vectors to FAISS: {e}") from e

    def search_courses(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict]:
        """Return up to ``top_k`` stored courses closest to ``query_embedding``.

        Args:
            query_embedding: One vector of ``self.dimension`` numbers (any
                array-like). It is cast to float32 and treated as a single
                query row.
            top_k: Maximum number of results. Values <= 0 yield ``[]``.

        Returns:
            Copies of the matching metadata dicts, nearest first. Each copy has
            an added ``score`` key holding the squared L2 distance reported by
            ``IndexFlatL2`` (lower means more similar) as a plain ``float``.

        Raises:
            RuntimeError: If the query cannot be converted or FAISS fails.
        """
        try:
            if top_k <= 0:
                return []
            # FAISS expects a C-contiguous float32 matrix of shape (n_queries, d).
            query = np.ascontiguousarray(query_embedding, dtype='float32').reshape(1, -1)
            # Never ask for more neighbours than exist; FAISS would pad with -1.
            k = min(top_k, self.index.ntotal)
            if k == 0:
                return []
            distances, indices = self.index.search(query, k)
            results = []
            for dist, idx in zip(distances[0], indices[0]):
                if idx < 0:  # FAISS uses -1 for "no result" slots
                    continue
                result = self.metadata[idx].copy()
                result['score'] = float(dist)  # plain float for JSON serialization
                results.append(result)
            return results
        except Exception as e:
            logger.exception("Error searching FAISS index: %s", e)
            raise RuntimeError(f"Failed to search FAISS index: {e}") from e