Spaces:
Sleeping
Sleeping
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from typing import List, Dict | |
import faiss | |
class ContractVectorStore:
    """In-memory semantic index over the terms of pricing contracts.

    Each contract term is rendered as a short sentence, embedded with the
    supplied model and appended to a flat L2 FAISS index. ``texts[i]`` is the
    sentence whose embedding was added as row ``i`` of the index, so a label
    returned by a search can be used directly as an index into ``texts``.
    """

    def __init__(self, model: "SentenceTransformer"):
        """Create an empty store.

        Args:
            model: Embedding model; only its ``encode(list_of_str)`` method
                is used.
        """
        self.model = model
        self.index = None  # created lazily on the first non-empty add
        self.texts = []
        # Default for 'all-MiniLM-L6-v2'; replaced by the real embedding
        # width when the index is first created.
        self.dimension = 384

    def add_contract_terms(self, contract: Dict) -> None:
        """Render a contract's terms as sentences and add them to the store.

        Args:
            contract: Mapping with a ``"terms"`` mapping. Inside it,
                ``"volume_discounts"``, ``"tiered_pricing"`` and
                ``"special_conditions"`` are optional; ``"base_rate"`` is
                required.

        Raises:
            KeyError: If ``"terms"`` or ``"base_rate"`` is missing, or a
                discount/tier entry lacks one of its expected keys.
        """
        contract_terms = contract["terms"]
        terms = []

        # Volume discounts. The percentage is rounded so binary float noise
        # (0.29 * 100 == 28.999999999999996) does not leak into the text
        # that gets embedded and later shown to callers.
        for discount in contract_terms.get("volume_discounts", []):
            percent = round(discount["discount"] * 100, 6)
            terms.append(
                f"Volume discount: {percent}% off for quantities >= {discount['threshold']}"
            )

        # Tiered pricing.
        for tier in contract_terms.get("tiered_pricing", []):
            terms.append(
                f"Tier {tier['tier']}: Rate multiplier of {tier['rate']}x base rate"
            )

        # Special conditions are optional, like the two sections above.
        terms.extend(contract_terms.get("special_conditions", []))

        # Base rate.
        terms.append(f"Base rate is ${contract_terms['base_rate']} per unit")

        # Create embeddings and update the index.
        self._add_texts(terms)

    def _add_texts(self, texts: List[str]) -> None:
        """Embed ``texts`` and append them to the index; no-op when empty."""
        if not texts:
            return

        embeddings = np.asarray(self.model.encode(texts), dtype="float32")
        if embeddings.ndim == 1:
            # Defensive: some encoders return a single vector for one input.
            embeddings = embeddings.reshape(1, -1)

        if self.index is None:
            # Size the index from the actual embeddings instead of assuming
            # a fixed width, so any embedding model works.
            self.dimension = int(embeddings.shape[1])
            self.index = faiss.IndexFlatL2(self.dimension)

        self.index.add(embeddings)
        self.texts.extend(texts)

    def search_relevant_terms(self, query: str, k: int = 3) -> List[Dict]:
        """Return up to ``k`` stored terms most similar to ``query``.

        Args:
            query: Free-text query to embed and search with.
            k: Maximum number of results; capped at the number of stored
                terms.

        Returns:
            A list of ``{"text": str, "score": float}`` dicts in the order
            the index returns them. ``score`` is ``1 / (1 + distance)``, so
            it lies in ``(0, 1]``. Empty when the store is empty or
            ``k <= 0``.
        """
        if self.index is None or not self.texts:
            return []

        # Never ask for more neighbours than exist, and reject non-positive k.
        k = min(k, len(self.texts))
        if k <= 0:
            return []

        query_embedding = np.asarray(
            self.model.encode([query]), dtype="float32"
        ).reshape(1, -1)
        distances, indices = self.index.search(query_embedding, k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS marks "no result" slots with label -1; without the lower
            # bound that would silently alias to the last stored text.
            if 0 <= idx < len(self.texts):
                results.append({
                    "text": self.texts[int(idx)],
                    "score": float(1 / (1 + dist)),  # distance -> similarity
                })
        return results