# Rohil Bansal, commit 2ed2129
import numpy as np
import pandas as pd
from typing import List, Dict
import faiss
import logging
logger = logging.getLogger(__name__)
class FAISSManager:
    """In-memory course search backed by a flat FAISS L2 index.

    ``self.index`` holds the embedding vectors and ``self.metadata`` holds
    one ``{'title', 'description', 'url'}`` record per vector, in insertion
    order, so the row id returned by a search is also the position of the
    matching record in ``self.metadata``.
    """

    def __init__(self, dimension: int = 384):
        """Create an empty ``faiss.IndexFlatL2`` for vectors of ``dimension``.

        Args:
            dimension: Length of every embedding stored in the index.

        Raises:
            RuntimeError: If the index cannot be created. The original
                exception is attached as ``__cause__``.
        """
        try:
            self.dimension = dimension
            self.index = faiss.IndexFlatL2(dimension)
            self.metadata = []
            logger.info("FAISS index initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing FAISS: {e}")
            raise RuntimeError(f"Failed to initialize FAISS: {e}") from e

    def upsert_courses(self, df: pd.DataFrame) -> None:
        """Append the courses in ``df`` to the index and the metadata list.

        Args:
            df: Frame with an ``embeddings`` column (one vector per row, as
                an array or any sequence of numbers) and ``title``,
                ``description`` and ``url`` columns. An empty frame is a
                no-op.

        Raises:
            RuntimeError: If the frame is malformed (missing column, wrong
                embedding length) or the index rejects the vectors. Nothing
                is added to the index or to ``self.metadata`` when the
                frame is malformed.
        """
        try:
            if df.empty:
                # np.vstack([]) raises; an empty batch simply adds nothing.
                logger.info("No courses to add to FAISS index")
                return

            # np.asarray also accepts plain lists, not only ndarrays.
            vectors = np.ascontiguousarray(
                np.vstack([
                    np.asarray(emb, dtype='float32')
                    for emb in df['embeddings'].values
                ])
            )
            if vectors.shape[1] != self.dimension:
                raise ValueError(
                    f"Embedding dimension {vectors.shape[1]} does not match "
                    f"index dimension {self.dimension}"
                )

            # Build the records BEFORE touching the index: if a column is
            # missing we fail here, and the index and metadata list cannot
            # drift out of step with each other.
            records = df[['title', 'description', 'url']].to_dict('records')

            self.index.add(vectors)
            self.metadata.extend(records)
            logger.info(f"Added {len(vectors)} vectors to FAISS index")
        except Exception as e:
            logger.error(f"Error adding vectors to FAISS: {e}")
            raise RuntimeError(f"Failed to add vectors to FAISS: {e}") from e

    def search_courses(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict]:
        """Return the stored courses closest to ``query_embedding``.

        Args:
            query_embedding: Query vector; it is flattened to a single row
                and must contain ``self.dimension`` values.
            top_k: Maximum number of results. Values below 1 give an empty
                list; values above the number of indexed vectors are
                clamped to that number.

        Returns:
            Copies of the matching metadata records, in the order the index
            returned them, each with an extra ``'score'`` key holding the
            distance reported by the index as a plain ``float``.

        Raises:
            RuntimeError: If the query has the wrong dimension or the
                search itself fails.
        """
        try:
            # Never ask for more neighbours than exist; this also covers the
            # empty index and non-positive top_k without calling FAISS.
            k = min(int(top_k), int(self.index.ntotal))
            if k <= 0:
                return []

            query = np.ascontiguousarray(
                query_embedding, dtype='float32'
            ).reshape(1, -1)
            if query.shape[1] != self.dimension:
                raise ValueError(
                    f"Query dimension {query.shape[1]} does not match "
                    f"index dimension {self.dimension}"
                )

            distances, indices = self.index.search(query, k)

            results = []
            for distance, idx in zip(distances[0], indices[0]):
                if idx == -1:
                    # -1 marks an unfilled result slot.
                    continue
                # Copy so adding 'score' does not mutate the stored record.
                result = self.metadata[int(idx)].copy()
                # Convert to float for JSON serialization
                result['score'] = float(distance)
                results.append(result)
            return results
        except Exception as e:
            logger.error(f"Error searching FAISS index: {e}")
            raise RuntimeError(f"Failed to search FAISS index: {e}") from e