# chroma_storage.py
# ChromaDB-backed storage layer for the job-seeker / job-posting matching system.
from typing import List, Dict, Optional
import chromadb
from chromadb.config import Settings
import numpy as np
from embeddings import EmbeddingManager, MatchResult
from encoder import create_encoders, FIELD_MAPPING
class ChromaMatchingSystem:
    """ChromaDB-backed storage and retrieval for job-seeker matching.

    One Chroma collection is created per job-posting field in FIELD_MAPPING.
    Each collection stores the corresponding seeker-field embedding keyed by
    jobseeker id, so per-field similarity queries can run independently and
    be recombined into an overall MatchResult.
    """

    def __init__(self, collection_name: str = "job_seekers"):
        """Set up the Chroma client, the encoders, and one collection per field.

        Args:
            collection_name: Prefix for the per-field collection names;
                each collection is named "{collection_name}_{field}".
        """
        # Persistent client so stored embeddings survive process restarts;
        # allow_reset permits client.reset() during development.
        self.client = chromadb.Client(Settings(
            allow_reset=True,
            is_persistent=True
        ))
        # Reuse the project's dual-encoder embedding pipeline.
        job_encoder, seeker_encoder = create_encoders('all-mpnet-base-v2')
        self.embedding_manager = EmbeddingManager(job_encoder, seeker_encoder)
        # Embeddings are supplied explicitly on add/query, so no Chroma-side
        # embedding function is attached to the collections.
        self.collections = {}
        for field in set(FIELD_MAPPING.keys()):
            self.collections[field] = self.client.get_or_create_collection(
                name=f"{collection_name}_{field}",
                embedding_function=None
            )

    def add_job_seeker(self, jobseeker_id: str, processed_seeker, unprocessed_seeker, metadata: Optional[Dict] = None):
        """Embed a job seeker and store one vector per mapped field.

        Args:
            jobseeker_id: Unique id, used as both the Chroma id and document.
            processed_seeker: Input consumed by EmbeddingManager.embed_jobseeker
                (project-defined shape).
            unprocessed_seeker: Companion input for the same call.
            metadata: Optional extra metadata attached to every per-field
                entry. The dict is copied, never mutated.
        """
        field_embeddings = self.embedding_manager.embed_jobseeker(processed_seeker, unprocessed_seeker)
        # Copy before adding the status field so the caller's dict is never
        # mutated (the previous implementation wrote into the caller's object).
        safe_metadata = dict(metadata) if metadata is not None else {}
        safe_metadata['status'] = 'unseen'  # every new seeker starts unseen
        for job_field, seeker_field in FIELD_MAPPING.items():
            # Only store fields the embedding step actually produced.
            if seeker_field in field_embeddings:
                self.collections[job_field].add(
                    embeddings=[field_embeddings[seeker_field].tolist()],
                    metadatas=[safe_metadata],
                    ids=[jobseeker_id],
                    documents=[jobseeker_id]
                )

    def get_matches(self, job_posting, n_results: int = 10, where_conditions: Optional[Dict] = None) -> List[MatchResult]:
        """Return the top matches for a job posting, regardless of seen status."""
        return self._get_matches_internal(job_posting, n_results, where_conditions)

    def get_unseen_matches(self, job_posting, n_results: int = 10, where_conditions: Optional[Dict] = None) -> List[MatchResult]:
        """Return top matches whose status is still 'unseen'.

        Caller-supplied where_conditions are merged in, but the unseen filter
        is applied last so a caller cannot accidentally override it (the
        previous implementation let a caller-supplied "status" key win).

        NOTE(review): on newer Chroma versions a multi-key `where` dict must
        be expressed with an explicit `$and` operator — confirm against the
        installed chromadb version.
        """
        combined_conditions = dict(where_conditions) if where_conditions else {}
        combined_conditions["status"] = "unseen"
        return self._get_matches_internal(job_posting, n_results, combined_conditions)

    def mark_matches_as_seen(self, match_ids: List[str]):
        """Best-effort: flip status to 'seen' for the given ids in every collection.

        Failures on individual ids/collections are logged and skipped so one
        bad id does not abort the rest of the batch.
        """
        for job_field, collection in self.collections.items():
            for match_id in match_ids:
                try:
                    # Read-modify-write: fetch current metadata, update the
                    # status field, and write the whole dict back.
                    result = collection.get(
                        ids=[match_id],
                        include=['metadatas']
                    )
                    if result and result['metadatas']:
                        metadata = result['metadatas'][0]
                        metadata['status'] = 'seen'
                        collection.update(
                            ids=[match_id],
                            metadatas=[metadata]
                        )
                except Exception as e:
                    print(f"Error updating status for {match_id} in collection {job_field}: {str(e)}")

    def _get_matches_internal(self, job_posting, n_results: int = 10, where_conditions: Optional[Dict] = None) -> List[MatchResult]:
        """Shared query logic: per-field ANN search, then overall scoring.

        Queries each field collection with the posting's field embedding,
        gathers every seeker id returned by any field, reassembles each
        seeker's per-field embeddings, and scores them with the embedding
        manager. Returns the top n_results by similarity_score (descending).
        """
        job_embeddings = self.embedding_manager.embed_jobposting(job_posting)
        matches = []

        # Phase 1: query each field collection independently (best-effort).
        field_results = {}
        for job_field in FIELD_MAPPING.keys():
            if job_field in job_embeddings:
                try:
                    results = self.collections[job_field].query(
                        query_embeddings=[job_embeddings[job_field].tolist()],
                        n_results=n_results,
                        where=where_conditions,
                        include=["embeddings", "metadatas", "distances", "documents"]
                    )
                    # NOTE(review): newer chromadb may return numpy arrays here,
                    # whose truthiness raises ValueError — confirm the installed
                    # client still returns plain lists.
                    if results and 'embeddings' in results and results['embeddings']:
                        field_results[job_field] = results
                except Exception as e:
                    print(f"Error querying {job_field}: {str(e)}")
                    continue

        # Phase 2: union of all seeker ids returned by any field query.
        jobseeker_ids = set()
        for results in field_results.values():
            if 'ids' in results and results['ids']:
                jobseeker_ids.update(results['ids'][0])

        # Phase 3: reassemble each seeker's per-field embeddings and score.
        for jobseeker_id in jobseeker_ids:
            seeker_embeddings = {}
            for job_field, seeker_field in FIELD_MAPPING.items():
                if job_field in field_results:
                    results = field_results[job_field]
                    if ('ids' in results and results['ids'] and
                        'embeddings' in results and results['embeddings']):
                        if jobseeker_id in results['ids'][0]:
                            # Embeddings are positionally aligned with ids
                            # within the first (only) query result.
                            idx = results['ids'][0].index(jobseeker_id)
                            if idx < len(results['embeddings'][0]):
                                embedding = results['embeddings'][0][idx]
                                seeker_embeddings[seeker_field] = np.array(embedding)

            if seeker_embeddings:
                match_result = self.embedding_manager.calculate_similarity(
                    job_embeddings,
                    seeker_embeddings
                )
                matches.append(match_result)

        matches.sort(key=lambda x: x.similarity_score, reverse=True)
        return matches[:n_results]