# File: observability.py
import sqlite3
import json
from datetime import datetime
from typing import Dict, Any, List, Optional, Callable
import logging
import functools
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def log_execution(func: Callable) -> Callable:
    """Decorator that logs a function's start, successful completion, and any raised exception."""
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
logger.info(f"Executing {func.__name__}")
try:
result = func(*args, **kwargs)
logger.info(f"{func.__name__} completed successfully")
return result
except Exception as e:
logger.error(f"Error in {func.__name__}: {e}")
raise
return wrapper
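
# Usage sketch (illustrative; `generate_answer` is a hypothetical function, not part
# of this module). Decorated calls log "Executing generate_answer" on entry and
# "generate_answer completed successfully" on return; exceptions are logged and re-raised:
#
#     @log_execution
#     def generate_answer(prompt: str) -> str:
#         ...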
class LLMObservabilityManager:
    """Persists per-request LLM telemetry (request/response, tokens, cost, latency) to SQLite."""

    def __init__(self, db_path: str = "/data/llm_observability_v2.db"):
        # The default path assumes a writable persistent volume mounted at /data
        # (e.g. on Hugging Face Spaces); pass a local path when running elsewhere.
self.db_path = db_path
self.create_table()
    def create_table(self):
        """Create the llm_observations table if it does not already exist."""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS llm_observations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id TEXT,
created_at DATETIME,
status TEXT,
request TEXT,
response TEXT,
model TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
total_tokens INTEGER,
cost FLOAT,
latency FLOAT,
user TEXT
)
''')
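
    # Note: `with sqlite3.connect(...)` commits on successful exit (and rolls back on
    # error), so no explicit conn.commit() is needed in this class. The context manager
    # does not close the connection, but each short-lived connection is released once
    # it goes out of scope.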
    def insert_observation(self, response: str, conversation_id: str, status: str,
                           request: str, model: str, prompt_tokens: int,
                           completion_tokens: int, total_tokens: int,
                           cost: float, latency: float, user: str):
created_at = datetime.now()
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
INSERT INTO llm_observations
                (conversation_id, created_at, status, request, response, model, prompt_tokens, completion_tokens, total_tokens, cost, latency, user)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
conversation_id,
created_at,
status,
request,
response,
model,
prompt_tokens,
completion_tokens,
total_tokens,
cost,
latency,
user
))
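
    # Example call (all values below are illustrative):
    #
    #     manager = LLMObservabilityManager()
    #     manager.insert_observation(
    #         response="Paris is the capital of France.",
    #         conversation_id="conv-123",
    #         status="success",
    #         request="What is the capital of France?",
    #         model="gpt-4o-mini",
    #         prompt_tokens=12, completion_tokens=9, total_tokens=21,
    #         cost=0.00003, latency=0.42, user="alice",
    #     )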
def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
if conversation_id:
cursor.execute('SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at', (conversation_id,))
else:
cursor.execute('SELECT * FROM llm_observations ORDER BY created_at')
rows = cursor.fetchall()
column_names = [description[0] for description in cursor.description]
return [dict(zip(column_names, row)) for row in rows]
def get_all_observations(self) -> List[Dict[str, Any]]:
return self.get_observations()
def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Get the latest observation for each unique conversation_id
query = '''
SELECT * FROM llm_observations o1
WHERE created_at = (
SELECT MAX(created_at)
FROM llm_observations o2
WHERE o2.conversation_id = o1.conversation_id
)
ORDER BY created_at DESC
'''
            params: tuple = ()
            if limit is not None:
                # Bind LIMIT as a parameter instead of interpolating it into the SQL string.
                query += ' LIMIT ?'
                params = (limit,)
            cursor.execute(query, params)
rows = cursor.fetchall()
column_names = [description[0] for description in cursor.description]
return [dict(zip(column_names, row)) for row in rows]
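
    # Example (illustrative): fetch the latest observation per conversation,
    # newest first, capped at 20 rows:
    #
    #     for obs in manager.get_all_unique_conversation_observations(limit=20):
    #         print(obs["conversation_id"], obs["created_at"], obs["model"])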
def get_dashboard_statistics(self, days: Optional[int] = None, time_series_interval: str = 'day') -> Dict[str, Any]:
"""
Get statistical metrics for LLM usage dashboard with time series data.
Args:
days (int, optional): Number of days to look back. If None, returns all-time statistics
time_series_interval (str): Interval for time series data ('hour', 'day', 'week', 'month')
Returns:
Dict containing dashboard statistics and time series data
"""
def safe_round(value: Any, decimals: int = 2) -> float:
"""Safely round a value, returning 0 if the value is None or invalid."""
try:
return round(float(value), decimals) if value is not None else 0.0
except (TypeError, ValueError):
return 0.0
def safe_divide(numerator: Any, denominator: Any, decimals: int = 2) -> float:
"""Safely divide two numbers, handling None and zero division."""
try:
                # `not denominator` already covers both None and 0.
                if not denominator:
return 0.0
return round(float(numerator or 0) / float(denominator), decimals)
except (TypeError, ValueError):
return 0.0
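
        # e.g. safe_round(None) -> 0.0, safe_round("1.237") -> 1.24,
        #      safe_divide(5, 0) -> 0.0, safe_divide(1, 3) -> 0.33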
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Build time filter
time_filter = ""
                if days is not None:
                    # `days` is typed as int; coerce defensively so the interpolated
                    # value cannot inject SQL.
                    time_filter = f"WHERE created_at >= datetime('now', '-{int(days)} days')"
# Get general statistics
cursor.execute(f"""
SELECT
COUNT(*) as total_requests,
COUNT(DISTINCT conversation_id) as unique_conversations,
COUNT(DISTINCT user) as unique_users,
SUM(total_tokens) as total_tokens,
SUM(cost) as total_cost,
AVG(latency) as avg_latency,
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as error_count
FROM llm_observations
{time_filter}
""")
row = cursor.fetchone()
if not row:
return self._get_empty_statistics()
general_stats = dict(zip([col[0] for col in cursor.description], row))
# Get model distribution
cursor.execute(f"""
SELECT model, COUNT(*) as count
FROM llm_observations
{time_filter}
GROUP BY model
ORDER BY count DESC
""")
                # fetchall() consumes the cursor, so it must be called exactly once;
                # an empty result naturally yields an empty dict.
                model_distribution = {row[0]: row[1] for row in cursor.fetchall()}
# Get average tokens per request
cursor.execute(f"""
SELECT
AVG(prompt_tokens) as avg_prompt_tokens,
AVG(completion_tokens) as avg_completion_tokens
FROM llm_observations
{time_filter}
""")
token_averages = dict(zip([col[0] for col in cursor.description], cursor.fetchone()))
# Get top users by request count
cursor.execute(f"""
SELECT user, COUNT(*) as request_count,
SUM(total_tokens) as total_tokens,
SUM(cost) as total_cost
FROM llm_observations
{time_filter}
GROUP BY user
ORDER BY request_count DESC
LIMIT 5
""")
top_users = [
{
"user": row[0],
"request_count": row[1],
"total_tokens": row[2] or 0,
"total_cost": safe_round(row[3])
}
for row in cursor.fetchall()
]
# Get time series data
time_series_format = {
'hour': "%Y-%m-%d %H:00:00",
'day': "%Y-%m-%d",
'week': "%Y-%W",
'month': "%Y-%m"
}
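                # Example buckets: 'hour' -> "2024-05-01 13:00:00", 'day' -> "2024-05-01",
                # 'week' -> "2024-18" (week-of-year via %W), 'month' -> "2024-05".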
format_string = time_series_format.get(time_series_interval, "%Y-%m-%d")
cursor.execute(f"""
SELECT
strftime('{format_string}', created_at) as time_bucket,
COUNT(*) as request_count,
SUM(total_tokens) as total_tokens,
SUM(cost) as total_cost,
AVG(latency) as avg_latency,
COUNT(DISTINCT user) as unique_users,
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as error_count
FROM llm_observations
{time_filter}
GROUP BY time_bucket
ORDER BY time_bucket
""")
time_series = [
{
"timestamp": row[0],
"request_count": row[1] or 0,
"total_tokens": row[2] or 0,
"total_cost": safe_round(row[3]),
"avg_latency": safe_round(row[4]),
"unique_users": row[5] or 0,
"error_count": row[6] or 0
}
for row in cursor.fetchall()
]
# Calculate trends safely
trends = self._calculate_trends(time_series)
return {
"general_stats": {
"total_requests": general_stats["total_requests"] or 0,
"unique_conversations": general_stats["unique_conversations"] or 0,
"unique_users": general_stats["unique_users"] or 0,
"total_tokens": general_stats["total_tokens"] or 0,
"total_cost": safe_round(general_stats["total_cost"]),
"avg_latency": safe_round(general_stats["avg_latency"]),
"error_rate": safe_round(
safe_divide(general_stats["error_count"], general_stats["total_requests"]) * 100
)
},
"model_distribution": model_distribution,
"token_metrics": {
"avg_prompt_tokens": safe_round(token_averages["avg_prompt_tokens"]),
"avg_completion_tokens": safe_round(token_averages["avg_completion_tokens"])
},
"top_users": top_users,
"time_series": time_series,
"trends": trends
}
except sqlite3.Error as e:
logger.error(f"Database error in get_dashboard_statistics: {e}")
return self._get_empty_statistics()
except Exception as e:
logger.error(f"Error in get_dashboard_statistics: {e}")
return self._get_empty_statistics()
def _get_empty_statistics(self) -> Dict[str, Any]:
"""Return an empty statistics structure when no data is available."""
return {
"general_stats": {
"total_requests": 0,
"unique_conversations": 0,
"unique_users": 0,
"total_tokens": 0,
"total_cost": 0.0,
"avg_latency": 0.0,
"error_rate": 0.0
},
"model_distribution": {},
"token_metrics": {
"avg_prompt_tokens": 0.0,
"avg_completion_tokens": 0.0
},
"top_users": [],
"time_series": [],
"trends": {
"request_trend": 0.0,
"cost_trend": 0.0,
"token_trend": 0.0
}
}
def _calculate_trends(self, time_series: List[Dict[str, Any]]) -> Dict[str, float]:
"""Calculate trends safely from time series data."""
if len(time_series) >= 2:
current = time_series[-1]
previous = time_series[-2]
return {
"request_trend": self._calculate_percentage_change(
previous["request_count"], current["request_count"]),
"cost_trend": self._calculate_percentage_change(
previous["total_cost"], current["total_cost"]),
"token_trend": self._calculate_percentage_change(
previous["total_tokens"], current["total_tokens"])
}
return {
"request_trend": 0.0,
"cost_trend": 0.0,
"token_trend": 0.0
}
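
    # Trend examples: previous bucket = 100 requests, current = 150 -> request_trend = 50.0;
    # previous = 0, current = 5 -> 100.0 by convention (see _calculate_percentage_change).
    # Only the two most recent buckets are compared, so trends reflect the latest interval.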
def _calculate_percentage_change(self, old_value: Any, new_value: Any) -> float:
"""Calculate percentage change between two values safely."""
try:
old_value = float(old_value or 0)
new_value = float(new_value or 0)
if old_value == 0:
return 100.0 if new_value > 0 else 0.0
return round(((new_value - old_value) / old_value) * 100, 2)
except (TypeError, ValueError):
return 0.0
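
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. Assumptions: "llm_observability_demo.db" is a
# throwaway local file (the default /data path may not exist outside the
# deployed container), and every inserted value below is made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = LLMObservabilityManager(db_path="llm_observability_demo.db")
    manager.insert_observation(
        response="Hello!",
        conversation_id="demo-1",
        status="success",
        request="Say hello",
        model="demo-model",
        prompt_tokens=3,
        completion_tokens=2,
        total_tokens=5,
        cost=0.0001,
        latency=0.15,
        user="demo-user",
    )
    stats = manager.get_dashboard_statistics(days=7, time_series_interval="day")
    print(json.dumps(stats, indent=2, default=str))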