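# NOTE: the lines "Spaces: / Sleeping / Sleeping" captured above this source
# appear to be Hugging Face Spaces page chrome; kept only as this comment so
# the module parses.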
import json
import os
import time
from typing import Any, Dict, List

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from anthropic import Anthropic, AnthropicError
from dotenv import load_dotenv

# Import our modules
from src.invoice_generator import InvoiceGenerator
from src.vector_store import ContractVectorStore
# Load environment variables (e.g. ANTHROPIC_API_KEY, read by init_llm) from .env.
load_dotenv()

# Page configuration. The icon was mojibake ("π°") for the money-bag emoji.
st.set_page_config(
    page_title="Enterprise Pricing Audit Assistant",
    page_icon="💰",
    layout="wide",
)
# Load custom CSS
def load_css(path: str = "styles.css"):
    """Inject the stylesheet at *path* into the page inside a ``<style>`` tag.

    Args:
        path: CSS file to read. Relative paths resolve against the current
            working directory. Defaults to ``"styles.css"``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode any non-ASCII characters in the stylesheet.
    with open(path, encoding="utf-8") as f:
        css = f.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
# Initialize LLM client
@st.cache_resource
def init_llm():
    """Return an Anthropic client authenticated with ``ANTHROPIC_API_KEY``.

    Cached with ``st.cache_resource`` so one client is reused across script
    reruns and across invoices, instead of constructing a new client for every
    LLM call. A failed construction (e.g. missing key) is not cached.
    """
    return Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
# Initialize the sentence transformer model
@st.cache_resource(show_spinner=False)
def load_embedding_model():
    """Return the ``all-MiniLM-L6-v2`` SentenceTransformer, loaded once and reused.

    Streamlit re-executes the whole script on every widget interaction; without
    ``st.cache_resource`` the model would be reloaded on each rerun (e.g. every
    time a different contract is selected). The spinner is disabled because the
    caller already shows one. The import stays local so ``sentence_transformers``
    is only imported when the model is first requested.
    """
    from sentence_transformers import SentenceTransformer

    return SentenceTransformer('all-MiniLM-L6-v2')
def _build_discrepancy_prompt(context: Dict[str, Any], base_rate: Any) -> str:
    """Assemble the LLM prompt describing one invoice's pricing discrepancy."""
    details = context["invoice_details"]
    lines = [
        "Analyze this invoice for pricing accuracy:",
        "Invoice Details:",
        f"- Invoice ID: {details['invoice_id']}",
        f"- Quantity: {details['quantity']}",
        f"- Charged Amount: ${details['charged_amount']:.2f}",
        f"- Correct Amount: ${details['correct_amount']:.2f}",
        f"- Date: {details['date']}",
        f"- Contract Base Rate: ${base_rate}",
        "Relevant Contract Terms:",
        *("- " + term for term in context["relevant_terms"]),
        "Discrepancy found:",
        f"- Amount Difference: ${context['discrepancy']:.2f}",
        f"- Percentage Difference: {context['discrepancy_percentage']:.2f}%",
        "Please provide a detailed explanation of:",
        "1. Why there is a pricing discrepancy",
        "2. Which contract terms were violated",
        "3. How the correct price should have been calculated",
        "Keep the explanation clear and concise, focusing on the specific "
        "pricing rules that were not properly applied.",
    ]
    # Joined explicitly so source indentation never leaks into the prompt.
    return "\n".join(lines)


def analyze_invoice_with_rag(
    invoice: Dict[str, Any],
    contract: Dict[str, Any],
    vector_store: "ContractVectorStore",
) -> Dict[str, Any]:
    """Explain an invoice's pricing discrepancy using retrieved contract terms.

    Retrieves contract terms relevant to the invoice's quantity and charged
    amount from *vector_store* and computes the absolute and percentage
    difference between ``amount_charged`` and ``correct_amount``. Only when the
    absolute discrepancy exceeds one cent does it ask the Anthropic model for
    a natural-language explanation.

    Args:
        invoice: Mapping with ``invoice_id``, ``quantity``, ``amount_charged``,
            ``correct_amount`` and ``date`` keys.
        contract: Contract mapping; ``contract["terms"]["base_rate"]`` is
            included in the LLM prompt.
        vector_store: Object with ``search_relevant_terms(query)`` returning a
            list of dicts that each carry a ``"text"`` key.

    Returns:
        Dict with ``invoice_details``, ``discrepancy`` (rounded to cents),
        ``discrepancy_percentage`` (rounded to 2 dp; ``0.0`` when both amounts
        are zero, ``±inf`` when only the correct amount is zero),
        ``explanation`` and ``relevant_terms`` (the raw term dicts).
    """
    base_rate = contract["terms"]["base_rate"]
    quantity = invoice["quantity"]
    charged_amount = invoice["amount_charged"]
    correct_amount = invoice["correct_amount"]

    # Search for relevant contract terms
    relevant_terms = vector_store.search_relevant_terms(
        f"pricing rules for quantity {quantity} and amount {charged_amount}"
    )

    discrepancy = charged_amount - correct_amount
    if correct_amount:
        discrepancy_percentage = round(discrepancy / correct_amount * 100, 2)
    elif discrepancy:
        # Any charge against a zero expected amount is an unbounded deviation.
        discrepancy_percentage = float("inf") if discrepancy > 0 else float("-inf")
    else:
        discrepancy_percentage = 0.0

    context = {
        "invoice_details": {
            "invoice_id": invoice["invoice_id"],
            "quantity": quantity,
            "charged_amount": charged_amount,
            "correct_amount": correct_amount,
            "date": invoice["date"],
        },
        "relevant_terms": [term["text"] for term in relevant_terms],
        "discrepancy": round(discrepancy, 2),
        "discrepancy_percentage": discrepancy_percentage,
    }

    # Only spend an LLM call when the invoice is actually off by more than a cent.
    if abs(context["discrepancy"]) > 0.01:
        prompt = _build_discrepancy_prompt(context, base_rate)
        # NOTE(review): claude-3-sonnet-20240229 is listed as retired on
        # Anthropic's model-deprecations page; set ANTHROPIC_MODEL to a current
        # model id without a code change.
        model = os.getenv("ANTHROPIC_MODEL", "claude-3-sonnet-20240229")
        try:
            response = init_llm().messages.create(
                model=model,
                max_tokens=1000,
                messages=[{"role": "user", "content": prompt}],
            )
            explanation = response.content[0].text
        except AnthropicError as exc:
            # One failed API call should not abort the whole dashboard run.
            explanation = f"Could not generate an LLM explanation: {exc}"
    else:
        explanation = "Invoice pricing is correct according to contract terms."

    return {
        **context,
        "explanation": explanation,
        "relevant_terms": relevant_terms,
    }
def display_metrics(invoices_df):
    """Render portfolio-wide summary metrics as a row of four ``st.metric`` tiles.

    Args:
        invoices_df: Invoice DataFrame; reads the boolean ``has_error`` column
            and the ``amount_charged`` / ``correct_amount`` money columns.
    """
    with st.container():
        st.markdown('<div class="metrics-container">', unsafe_allow_html=True)
        columns = st.columns(4)

        charged = invoices_df['amount_charged']
        expected = invoices_df['correct_amount']
        tiles = [
            ("Total Invoices", len(invoices_df)),
            ("Incorrect Invoices", len(invoices_df[invoices_df['has_error']])),
            ("Total Invoice Value", f"${charged.sum():,.2f}"),
            ("Total Pricing Discrepancy", f"${(charged - expected).sum():,.2f}"),
        ]
        for column, (label, value) in zip(columns, tiles):
            with column:
                st.metric(label, value)

        st.markdown('</div>', unsafe_allow_html=True)
def display_invoice_tables(invoices_df):
    """Show correctly and incorrectly priced invoices in two tabs.

    Splits *invoices_df* on its boolean ``has_error`` column, renders the money
    columns that are present (``amount_charged``, ``correct_amount`` and, when
    present, ``error_amount``) as currency strings, and displays each subset
    with ``st.dataframe`` — or an info message when a subset is empty.

    Args:
        invoices_df: Invoice DataFrame; not modified (formatting is applied to
            copies).
    """
    def _format_currency(x):
        # Render negatives as "-$5.00" rather than "$-5.00".
        return f"-${-x:,.2f}" if x < 0 else f"${x:,.2f}"

    st.markdown('<div class="invoice-table">', unsafe_allow_html=True)

    # Separate correct and incorrect invoices (copies, so formatting below
    # never touches the caller's frame).
    correct_invoices = invoices_df[~invoices_df['has_error']].copy()
    incorrect_invoices = invoices_df[invoices_df['has_error']].copy()

    currency_cols = [
        col for col in ('amount_charged', 'correct_amount', 'error_amount')
        if col in invoices_df.columns
    ]
    for df in (correct_invoices, incorrect_invoices):
        for col in currency_cols:
            df[col] = df[col].apply(_format_currency)

    # Tab labels were mojibake ("π’" / "π΄") for the green/red circle emoji.
    tab1, tab2 = st.tabs(["🟢 Correct Invoices", "🔴 Incorrect Invoices"])

    with tab1:
        if not correct_invoices.empty:
            st.dataframe(
                correct_invoices,
                column_config={
                    "invoice_id": "Invoice ID",
                    "date": "Date",
                    "quantity": "Quantity",
                    "amount_charged": "Amount",
                },
                hide_index=True,
            )
        else:
            st.info("No correctly priced invoices found.")

    with tab2:
        if not incorrect_invoices.empty:
            st.dataframe(
                incorrect_invoices,
                column_config={
                    "invoice_id": "Invoice ID",
                    "date": "Date",
                    "quantity": "Quantity",
                    "amount_charged": "Charged Amount",
                    "correct_amount": "Correct Amount",
                },
                hide_index=True,
            )
        else:
            st.info("No pricing discrepancies found.")

    st.markdown('</div>', unsafe_allow_html=True)
def display_contract_details(contract):
    """Render a contract's identity, pricing rules and special conditions.

    Args:
        contract: Mapping with ``contract_id``, ``client`` and ``terms``.
            ``terms`` must contain ``base_rate``. It may also contain:
            ``volume_discounts`` (items with a ``discount`` fraction and a
            numeric ``threshold``), ``tiered_pricing`` (items with ``tier``
            and a ``rate`` multiplier of the base rate) and
            ``special_conditions``. Every optional section is skipped when
            absent.
    """
    # NOTE(review): headings' emoji were reconstructed from mojibake. "•", "≥"
    # and "🏷️" are unambiguous; "📄" and "📌" are best guesses — confirm.
    terms = contract["terms"]

    st.markdown('<div class="contract-details">', unsafe_allow_html=True)
    st.subheader("📄 Contract Details")

    # Basic contract information
    id_col, client_col, rate_col = st.columns(3)
    with id_col:
        st.write("**Contract ID:**", contract['contract_id'])
    with client_col:
        st.write("**Client:**", contract['client'])
    with rate_col:
        st.write("**Base Rate:**", f"${terms['base_rate']}")

    # Pricing rules
    with st.expander("🏷️ Pricing Rules"):
        if "volume_discounts" in terms:
            st.write("**Volume Discounts:**")
            for discount in terms["volume_discounts"]:
                # ':g' hides float noise such as 0.07 * 100 == 7.000000000000001.
                st.write(
                    f"• {discount['discount'] * 100:g}% off for quantities "
                    f"≥ {discount['threshold']:,}"
                )
        if "tiered_pricing" in terms:
            st.write("**Tiered Pricing:**")
            for tier in terms["tiered_pricing"]:
                st.write(f"• {tier['tier']}: {tier['rate']}x base rate")

    # Special conditions (optional, like the pricing sections above)
    with st.expander("📌 Special Conditions"):
        for condition in terms.get("special_conditions", []):
            st.write(f"• {condition}")

    st.markdown('</div>', unsafe_allow_html=True)
def initialize_data():
    """Initialize data and models.

    Steps:
        1. Load the embedding model.
        2. Make sure ``data/contracts.json`` and ``data/invoices.json`` exist,
           generating them via ``InvoiceGenerator`` when either is missing or
           loads empty.
        3. Index every contract's terms in a ``ContractVectorStore``.

    Returns:
        Tuple ``(contracts, invoices, vector_store)``.

    On any failure — including data that is still empty after regeneration —
    the error is shown with ``st.error`` and the script run is halted with
    ``st.stop()``.
    """
    data_dir = "data"
    try:
        embedding_model = load_embedding_model()
        generator = InvoiceGenerator(data_dir=data_dir)

        # Ensure we have both contracts and invoices on disk.
        required = ("contracts.json", "invoices.json")
        if not all(os.path.exists(os.path.join(data_dir, name)) for name in required):
            generator.generate_and_save()

        contracts = generator.load_contracts()
        invoices = generator.load_or_generate_invoices()

        if not contracts or not invoices:
            st.error("No data found. Generating new data...")
            generator.generate_and_save()
            contracts = generator.load_contracts()
            invoices = generator.load_or_generate_invoices()

        # Fail clearly here rather than letting the UI crash later on an empty
        # contract selector or invoice table.
        if not contracts or not invoices:
            raise RuntimeError("data generation produced no contracts or invoices")

        vector_store = ContractVectorStore(embedding_model)
        for contract in contracts:
            vector_store.add_contract_terms(contract)

        return contracts, invoices, vector_store
    except Exception as e:
        st.error(f"Error initializing data: {e}")
        st.stop()
def _render_overview_tab(contract_invoices_df: pd.DataFrame) -> None:
    """Show contract-level value/discrepancy/error-rate metrics and an amounts-over-time scatter."""
    total_contract_value = contract_invoices_df['amount_charged'].sum()
    total_contract_discrepancy = (
        contract_invoices_df['amount_charged'] - contract_invoices_df['correct_amount']
    ).sum()
    invoice_count = len(contract_invoices_df)
    error_count = len(contract_invoices_df[contract_invoices_df['has_error']])
    # A contract with no invoices used to raise ZeroDivisionError here.
    error_rate = error_count / invoice_count * 100 if invoice_count else 0.0

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Contract Value", f"${total_contract_value:,.2f}")
    with col2:
        st.metric("Total Discrepancy", f"${total_contract_discrepancy:,.2f}")
    with col3:
        st.metric("Error Rate", f"{error_rate:.1f}%")

    if contract_invoices_df.empty:
        return

    fig = go.Figure()
    has_error = contract_invoices_df['has_error']
    series = (
        (contract_invoices_df[~has_error], 'Correct Invoices', 'green'),
        (contract_invoices_df[has_error], 'Incorrect Invoices', 'red'),
    )
    for subset, name, color in series:
        if subset.empty:
            continue
        fig.add_trace(go.Scatter(
            x=subset['date'],
            y=subset['amount_charged'],
            mode='markers',
            name=name,
            marker=dict(color=color, size=10),
        ))
    fig.update_layout(
        title='Invoice Amounts Over Time',
        xaxis_title='Date',
        yaxis_title='Amount ($)',
        hovermode='closest',
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_detailed_analysis_tab(contract_invoices_df: pd.DataFrame, selected_contract, vector_store) -> None:
    """Show a RAG/LLM explanation per incorrect invoice, reusing results within the session."""
    incorrect_invoices = contract_invoices_df[contract_invoices_df['has_error']]
    if incorrect_invoices.empty:
        st.info("No pricing discrepancies found for this contract.")
        return

    # Streamlit reruns the script on every interaction and expander bodies run
    # even when collapsed, so without this cache every rerun re-queried the LLM
    # for every incorrect invoice.
    cache = st.session_state.setdefault("invoice_analyses", {})
    for _, invoice in incorrect_invoices.iterrows():
        with st.expander(f"Invoice {invoice['invoice_id']} Analysis"):
            key = (
                selected_contract["contract_id"],
                invoice["invoice_id"],
                invoice["amount_charged"],
                invoice["correct_amount"],
            )
            if key not in cache:
                cache[key] = analyze_invoice_with_rag(
                    invoice.to_dict(), selected_contract, vector_store
                )
            analysis = cache[key]

            st.write("**Discrepancy Amount:**",
                     f"${analysis['discrepancy']:.2f} "
                     f"({analysis['discrepancy_percentage']}%)")
            st.write("**Relevant Contract Terms:**")
            for term in analysis['relevant_terms']:
                st.write(f"• {term['text']}")
            st.write("**Analysis:**")
            st.write(analysis['explanation'])


def main():
    """Streamlit entry point: render the pricing-audit dashboard.

    Steps:
        1. Load custom CSS (best effort).
        2. Initialize data and models, then show portfolio metrics.
        3. Let the user pick a contract and show its details.
        4. Show the overview, invoice-table and detailed-analysis tabs.

    Unexpected errors are reported with ``st.error`` and the run is stopped.
    """
    # NOTE(review): heading emoji were mojibake ("π"); the replacements below
    # are best guesses consistent with the surviving bytes — confirm.
    try:
        load_css()
    except (OSError, UnicodeDecodeError) as e:
        st.warning(f"Could not load custom CSS: {e}")

    st.title("🔍 Enterprise Pricing Audit Assistant")

    try:
        with st.spinner('Loading data and initializing models...'):
            contracts, invoices, vector_store = initialize_data()

        if not contracts or not invoices:
            st.warning("No contracts or invoices available to audit.")
            return

        invoices_df = pd.DataFrame(invoices)
        display_metrics(invoices_df)

        # First occurrence wins on duplicate ids, matching the old next(...) lookup;
        # the dict makes format_func O(1) instead of a scan per option.
        contracts_by_id = {}
        for contract in contracts:
            contracts_by_id.setdefault(contract["contract_id"], contract)

        selected_contract_id = st.selectbox(
            "Select Contract",
            options=list(contracts_by_id),
            format_func=lambda x: f"{x} - {contracts_by_id[x]['client']}",
        )
        selected_contract = contracts_by_id[selected_contract_id]
        display_contract_details(selected_contract)

        # Explicit copy: a derived column is added below, and assigning into a
        # filtered slice triggers pandas' SettingWithCopyWarning.
        contract_invoices_df = invoices_df[
            invoices_df['contract_id'] == selected_contract_id
        ].copy()
        contract_invoices_df['error_amount'] = (
            contract_invoices_df['amount_charged'] - contract_invoices_df['correct_amount']
        )

        st.subheader("📊 Invoice Analysis")
        tab1, tab2, tab3 = st.tabs(["📈 Overview", "📋 Invoice Details", "🔎 Detailed Analysis"])
        with tab1:
            _render_overview_tab(contract_invoices_df)
        with tab2:
            display_invoice_tables(contract_invoices_df)
        with tab3:
            _render_detailed_analysis_tab(contract_invoices_df, selected_contract, vector_store)
    except Exception as e:
        st.error(f"An error occurred: {e}")
        st.stop()
if __name__ == "__main__":
    # Render the dashboard only when this file is the top-level script.
    main()