# NOTE(review): the three lines below were a Hugging Face Spaces status banner
# ("Spaces: Running Running") accidentally captured in the paste — kept as a
# comment so the file parses.
# File: llm_observability.py
import functools
import json
import logging
import sqlite3
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def log_execution(func: Callable) -> Callable:
    """Decorator that logs entry, success, and failure of the wrapped callable.

    Exceptions raised by ``func`` are logged and re-raised unchanged.
    """
    @functools.wraps(func)  # preserve __name__/__doc__ of the wrapped function
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Lazy %-style args: the message is only formatted if the level is enabled.
        logger.info("Executing %s", func.__name__)
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            logger.error("Error in %s: %s", func.__name__, e)
            raise
        logger.info("%s completed successfully", func.__name__)
        return result
    return wrapper
class LLMObservabilityManager:
    """Persist and query LLM call observations in a SQLite database.

    Each observation records one request/response pair together with token
    usage, cost, latency, and the user who issued the request.
    """

    def __init__(self, db_path: str = "llm_observability_v2.db"):
        """Remember the database path and ensure the schema exists."""
        self.db_path = db_path
        self.create_table()

    def create_table(self):
        """Create the llm_observations table if it does not already exist."""
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commits the DDL on success, rolls back on error
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS llm_observations (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        conversation_id TEXT,
                        created_at DATETIME,
                        status TEXT,
                        request TEXT,
                        response TEXT,
                        model TEXT,
                        prompt_tokens INTEGER,
                        completion_tokens INTEGER,
                        total_tokens INTEGER,
                        cost FLOAT,
                        latency FLOAT,
                        user TEXT
                    )
                ''')
        finally:
            # `with sqlite3.connect(...)` manages the transaction but does
            # NOT close the connection — close it explicitly.
            conn.close()

    def insert_observation(self, response: str, conversation_id: str, status: str,
                           request: str, model: str, prompt_tokens: int,
                           completion_tokens: int, total_tokens: int,
                           cost: float, latency: float, user: str):
        """Insert one observation row, timestamped with the current local time."""
        # Store an ISO-8601 string (space-separated, same text the default
        # adapter produced): passing a datetime object relies on an adapter
        # deprecated since Python 3.12.
        created_at = datetime.now().isoformat(sep=" ")
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, rollback on exception
                conn.execute('''
                    INSERT INTO llm_observations
                    (conversation_id, created_at, status, request, response, model, prompt_tokens, completion_tokens, total_tokens, cost, latency, user)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (conversation_id, created_at, status, request, response, model,
                      prompt_tokens, completion_tokens, total_tokens, cost, latency, user))
        finally:
            conn.close()

    def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return observations as dicts, oldest first.

        When `conversation_id` is given, only that conversation's rows are
        returned; otherwise every row.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            if conversation_id:
                cursor.execute(
                    'SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at',
                    (conversation_id,))
            else:
                cursor.execute('SELECT * FROM llm_observations ORDER BY created_at')
            columns = [description[0] for description in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_all_observations(self) -> List[Dict[str, Any]]:
        """Return every observation in the database, oldest first."""
        return self.get_observations()

    def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the latest observation per conversation, newest first.

        `limit`, when given, caps the number of conversations returned.
        """
        query = '''
            SELECT * FROM llm_observations o1
            WHERE created_at = (
                SELECT MAX(created_at)
                FROM llm_observations o2
                WHERE o2.conversation_id = o1.conversation_id
            )
            ORDER BY created_at DESC
        '''
        params: tuple = ()
        if limit is not None:
            # Bind the limit instead of f-string interpolation: parameterized
            # SQL even though callers are expected to pass an int.
            query += ' LIMIT ?'
            params = (limit,)
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(query, params)
            columns = [description[0] for description in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]
        finally:
            conn.close()
## OBSERVABILITY
import csv
from io import StringIO
from uuid import uuid4

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse

# All observability endpoints hang off this router.
router = APIRouter(
    tags=["observability"],
    prefix="/observability",
)
class ObservationResponse(BaseModel):
    """Response envelope carrying a list of observation records."""

    observations: List[Dict]
def create_csv_response(observations: List[Dict]) -> StreamingResponse:
    """Build a streaming ``text/csv`` download response from observation dicts.

    Column names are taken from the first record's keys; an empty input
    yields a CSV with no header columns.
    """
    def stream_rows(records):
        buffer = StringIO()
        fieldnames = records[0].keys() if records else []
        writer = csv.DictWriter(buffer, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(records)
        buffer.seek(0)
        yield buffer.read()

    return StreamingResponse(
        stream_rows(observations),
        media_type="text/csv",
        headers={'Content-Disposition': 'attachment; filename="observations.csv"'},
    )
async def get_last_observations(limit: int = 10, format: str = "json"):
    """Return up to `limit` observations from the most recent conversation.

    `format` selects the payload: ``"csv"`` streams a CSV attachment,
    anything else returns an ``ObservationResponse``. Failures surface as
    HTTP 500.
    """
    manager = LLMObservabilityManager()
    try:
        # Newest first across every conversation.
        records = manager.get_observations()
        records.sort(key=lambda obs: obs['created_at'], reverse=True)

        if records:
            newest_conversation = records[0]['conversation_id']
            selected = [
                rec for rec in records
                if rec['conversation_id'] == newest_conversation
            ][:limit]
        else:
            selected = []

        if format.lower() == "csv":
            return create_csv_response(selected)
        return ObservationResponse(observations=selected)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve observations: {str(e)}")