PricingHelper / src /vector_store.py
AshwinP's picture
Pricing Helper
5707fbc
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict
import faiss
class ContractVectorStore:
    """In-memory semantic index over human-readable contract terms.

    Contract terms are rendered to short sentences, embedded with the supplied
    SentenceTransformer model, and stored in a flat (exact) faiss L2 index so
    they can be retrieved by natural-language queries.
    """

    # Fallback embedding size; 384 is the output dimension of 'all-MiniLM-L6-v2'.
    _DEFAULT_DIMENSION = 384

    def __init__(self, model: SentenceTransformer):
        self.model = model
        # Created lazily on the first non-empty insert (see _add_texts).
        self.index = None
        # texts[i] is the sentence stored at faiss row i.
        self.texts = []
        # Prefer the model's real embedding size so models other than
        # MiniLM work; fall back to the historical hard-coded value.
        dim = None
        get_dim = getattr(model, "get_sentence_embedding_dimension", None)
        if callable(get_dim):
            dim = get_dim()
        self.dimension = dim or self._DEFAULT_DIMENSION

    def add_contract_terms(self, contract: Dict) -> None:
        """Render a contract's terms as sentences and add them to the store.

        Args:
            contract: mapping with a required ``"terms"`` dict. Within it,
                ``"base_rate"`` is required; ``"volume_discounts"`` (items with
                ``"discount"`` as a fraction and ``"threshold"``),
                ``"tiered_pricing"`` (items with ``"tier"`` and ``"rate"``) and
                ``"special_conditions"`` (list of str) are optional.

        Raises:
            KeyError: if ``"terms"`` or ``"terms"["base_rate"]`` is missing.
        """
        section = contract["terms"]
        lines: List[str] = []

        for discount in section.get("volume_discounts", []):
            # Round away float artifacts (0.07 * 100 == 7.000000000000001)
            # while keeping the original "10.0%" style for clean values.
            pct = round(discount["discount"] * 100, 6)
            lines.append(
                f"Volume discount: {pct}% off for quantities >= {discount['threshold']}"
            )

        for tier in section.get("tiered_pricing", []):
            lines.append(
                f"Tier {tier['tier']}: Rate multiplier of {tier['rate']}x base rate"
            )

        # Optional like the other sections; previously a missing key raised.
        lines.extend(section.get("special_conditions", []))

        lines.append(f"Base rate is ${section['base_rate']} per unit")

        self._add_texts(lines)

    def _add_texts(self, texts: List[str]) -> None:
        """Embed ``texts`` and append them to the index (no-op when empty)."""
        if not texts:
            return

        embeddings = np.asarray(self.model.encode(texts), dtype="float32")
        if embeddings.ndim == 1:
            # Defensive: faiss requires a 2-D (n, d) matrix.
            embeddings = embeddings.reshape(1, -1)

        if self.index is None:
            # Size the index from the actual embeddings so a model/dimension
            # mismatch cannot occur on the first insert.
            self.dimension = int(embeddings.shape[1])
            self.index = faiss.IndexFlatL2(self.dimension)

        self.index.add(embeddings)
        self.texts.extend(texts)

    def search_relevant_terms(self, query: str, k: int = 3) -> List[Dict]:
        """Return up to ``k`` stored terms closest to ``query``.

        Args:
            query: free-text search query.
            k: maximum number of results; values <= 0 yield an empty list.

        Returns:
            List of ``{"text": str, "score": float}`` ordered nearest-first,
            where ``score = 1 / (1 + distance)`` and faiss IndexFlatL2 reports
            squared L2 distances. Each stored term appears at most once.
        """
        # Compare to None explicitly; truthiness of a SWIG faiss object is
        # not a reliable emptiness test.
        if self.index is None or not self.texts or k <= 0:
            return []

        # Never ask for more neighbours than exist: faiss pads the missing
        # slots with label -1, which would otherwise index texts[-1].
        k = min(k, self.index.ntotal)

        query_embedding = np.asarray(
            self.model.encode([query]), dtype="float32"
        ).reshape(1, -1)

        distances, indices = self.index.search(query_embedding, k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # Reject faiss's -1 "no result" label as well as out-of-range ids.
            if 0 <= idx < len(self.texts):
                results.append({
                    "text": self.texts[int(idx)],
                    "score": float(1 / (1 + dist)),
                })
        return results