# rag_chat_with_analytics / observability.py
# (HuggingFace Space upload header — uploaded by pvanand, "Upload 11 files", 6.74 kB)
# File: llm_observability.py
import sqlite3
import json
from datetime import datetime
from typing import Dict, Any, List, Optional, Callable
import logging
import functools
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def log_execution(func: Callable) -> Callable:
    """Decorator that logs entry, successful completion, and failure of *func*.

    Exceptions are logged and re-raised untouched; the wrapped function's
    metadata is preserved via functools.wraps.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"Executing {func.__name__}")
        try:
            outcome = func(*args, **kwargs)
        except Exception as e:
            logger.error(f"Error in {func.__name__}: {e}")
            raise
        else:
            logger.info(f"{func.__name__} completed successfully")
            return outcome
    return wrapper
class LLMObservabilityManager:
    """SQLite-backed store of per-call LLM telemetry (tokens, cost, latency).

    Every public method opens its own short-lived connection to ``db_path``,
    so one instance can be shared across requests without holding a
    connection open; each ``with sqlite3.connect(...)`` block commits on
    clean exit.
    """

    def __init__(self, db_path: str = "llm_observability_v2.db"):
        self.db_path = db_path
        self.create_table()

    def create_table(self) -> None:
        """Create the llm_observations table if it does not already exist."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS llm_observations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    conversation_id TEXT,
                    created_at DATETIME,
                    status TEXT,
                    request TEXT,
                    response TEXT,
                    model TEXT,
                    prompt_tokens INTEGER,
                    completion_tokens INTEGER,
                    total_tokens INTEGER,
                    cost FLOAT,
                    latency FLOAT,
                    user TEXT
                )
            ''')

    def insert_observation(self, response: str, conversation_id: str, status: str,
                           request: str, model: str, prompt_tokens: int,
                           completion_tokens: int, total_tokens: int,
                           cost: float, latency: float, user: str) -> None:
        """Append one observation row, timestamped with the current local time.

        The timestamp is stored explicitly as ISO-8601 text with a space
        separator ("YYYY-MM-DD HH:MM:SS.ffffff") — the exact format the
        sqlite3 default datetime adapter produced — because that adapter is
        deprecated since Python 3.12. Existing rows therefore keep sorting
        and comparing consistently with new ones.
        """
        created_at = datetime.now().isoformat(sep=" ")
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                INSERT INTO llm_observations
                (conversation_id, created_at, status, request, response, model,
                 prompt_tokens, completion_tokens, total_tokens, cost, latency, user)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', (conversation_id, created_at, status, request, response, model,
                  prompt_tokens, completion_tokens, total_tokens, cost, latency, user))

    def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return observations as dicts, oldest first.

        A truthy ``conversation_id`` filters to that conversation; ``None``
        (or an empty string, matching the original truthiness check) returns
        every row.
        """
        with sqlite3.connect(self.db_path) as conn:
            # sqlite3.Row gives name-keyed rows without juggling cursor.description.
            conn.row_factory = sqlite3.Row
            if conversation_id:
                rows = conn.execute(
                    'SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at',
                    (conversation_id,),
                ).fetchall()
            else:
                rows = conn.execute(
                    'SELECT * FROM llm_observations ORDER BY created_at'
                ).fetchall()
            return [dict(row) for row in rows]

    def get_all_observations(self) -> List[Dict[str, Any]]:
        """Return every observation, oldest first (convenience wrapper)."""
        return self.get_observations()

    def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the newest observation per conversation, newest first.

        ``limit`` caps how many conversations are returned. It is bound as a
        SQL parameter instead of being f-string-interpolated into the query,
        which closes the injection/format hole of building 'LIMIT {limit}'
        by hand.
        """
        query = '''
            SELECT * FROM llm_observations o1
            WHERE created_at = (
                SELECT MAX(created_at)
                FROM llm_observations o2
                WHERE o2.conversation_id = o1.conversation_id
            )
            ORDER BY created_at DESC
        '''
        params: tuple = ()
        if limit is not None:
            query += ' LIMIT ?'
            params = (limit,)
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            return [dict(row) for row in conn.execute(query, params).fetchall()]
## OBSERVABILITY
from uuid import uuid4
import csv
from io import StringIO
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse
router = APIRouter(
prefix="/observability",
tags=["observability"]
)
class ObservationResponse(BaseModel):
observations: List[Dict]
def create_csv_response(observations: List[Dict]) -> StreamingResponse:
    """Serialize observation dicts to CSV and wrap them in a download response.

    Column order follows the keys of the first row; an empty input yields
    only the (empty) header line. The header and each row are yielded as
    separate chunks so large result sets are never buffered in full — the
    previous version built the entire CSV in memory before yielding it once,
    which defeated the point of a StreamingResponse.
    """
    def iter_csv(data: List[Dict]):
        buffer = StringIO()
        writer = csv.DictWriter(buffer, fieldnames=data[0].keys() if data else [])
        writer.writeheader()
        yield buffer.getvalue()
        for row in data:
            # Reuse one small buffer per row instead of accumulating everything.
            buffer.seek(0)
            buffer.truncate()
            writer.writerow(row)
            yield buffer.getvalue()

    headers = {
        'Content-Disposition': 'attachment; filename="observations.csv"'
    }
    return StreamingResponse(iter_csv(observations), media_type="text/csv", headers=headers)
@router.get("/last-observations/{limit}")
async def get_last_observations(limit: int = 10, format: str = "json"):
    """Return up to *limit* observations from the most recent conversation.

    The newest conversation is the one whose latest observation has the
    greatest created_at value. Output is JSON (ObservationResponse) by
    default, or a CSV download when format=csv. Any failure is surfaced
    as an HTTP 500.
    """
    observability_manager = LLMObservabilityManager()
    try:
        # Newest observations first across every conversation.
        records = sorted(
            observability_manager.get_observations(),
            key=lambda obs: obs['created_at'],
            reverse=True,
        )
        selected: List[Dict] = []
        if records:
            newest_conversation = records[0]['conversation_id']
            selected = [
                obs for obs in records
                if obs['conversation_id'] == newest_conversation
            ][:limit]
        if format.lower() == "csv":
            return create_csv_response(selected)
        return ObservationResponse(observations=selected)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve observations: {str(e)}")