# NOTE(review): the three lines below were scrape residue ("Spaces: / Sleeping")
# and not valid Python; commented out to keep the file importable.
# Spaces:
# Sleeping
# Sleeping
# aiclient.py
# Standard library
import json
import logging
import os
import time
from functools import lru_cache
from typing import AsyncGenerator, Dict, List, Optional, Union

# Third-party
import pandas as pd
import psycopg2
import requests
from openai import AsyncOpenAI
from starlette.responses import StreamingResponse

# Local
from observability import LLMObservabilityManager, log_execution

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
def _load_cached_model_info():
    """Return model metadata from the local JSON cache, or None if absent."""
    if os.path.exists('model_info.json'):
        with open('model_info.json', 'r') as json_file:
            return pd.DataFrame(json.load(json_file))
    logger.error("No model info file found")
    return None


def get_model_info():
    """Fetch OpenRouter model metadata and return it as a pandas DataFrame.

    Tries the live API first and caches the raw payload to
    ``model_info.json``; on any failure (network error, HTTP error,
    unexpected payload) falls back to the cached file.

    Returns:
        pd.DataFrame of model records, or None when both the network
        call and the local cache are unavailable.
    """
    try:
        response = requests.get(
            'https://openrouter.ai/api/v1/models',
            headers={'accept': 'application/json'},
            timeout=10,  # requests has no default timeout; avoid hanging forever
        )
        response.raise_for_status()  # surface HTTP errors instead of parsing an error body
        model_info_dict = response.json()["data"]
    except Exception as e:
        logger.error(f"Failed to fetch model info: {e}. Loading from file.")
        return _load_cached_model_info()
    # Cache the payload so later runs can work offline.
    with open('model_info.json', 'w') as json_file:
        json.dump(model_info_dict, json_file, indent=4)
    return pd.DataFrame(model_info_dict)
class AIClient:
    """Streams chat completions from OpenRouter and records usage/cost
    observations via LLMObservabilityManager."""

    def __init__(self):
        self.client = AsyncOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=os.environ['OPENROUTER_API_KEY']
        )
        self.observability_manager = LLMObservabilityManager()
        # DataFrame of OpenRouter model metadata; may be None when both the
        # API and the local cache are unavailable (see get_model_info).
        self.model_info = get_model_info()

    #@log_execution
    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        model: str = "openai/gpt-4o-mini",
        max_tokens: int = 32000,
        conversation_id: Optional[str] = None,
        user: str = "anonymous"
    ) -> AsyncGenerator[str, None]:
        """Yield completion text chunks for ``messages``, then log an observation.

        Args:
            messages: Chat history as ``[{"role": ..., "content": ...}, ...]``.
            model: OpenRouter model id; overwritten with the id reported by
                the stream's final usage chunk.
            max_tokens: Completion token cap passed to the API.
            conversation_id: Grouping key for observability ("default" if None).
            user: User label recorded with the observation.

        Yields:
            str: incremental content deltas from the streamed completion.
        """
        if not messages:
            return
        start_time = time.time()
        full_response = ""
        usage = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}
        status = "success"
        latency = 0.0  # ensure defined for the finally block on any path
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                stream=True,
                stream_options={"include_usage": True}
            )
            latency = time.time() - start_time  # time until the stream opened
            async for chunk in response:
                # With include_usage, the final chunk carries usage stats and
                # an EMPTY `choices` list — guard before indexing to avoid
                # an IndexError that would kill the stream.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
                    full_response += chunk.choices[0].delta.content
                if chunk.usage:
                    model = chunk.model  # the model that actually served the request
                    usage["completion_tokens"] = chunk.usage.completion_tokens
                    usage["prompt_tokens"] = chunk.usage.prompt_tokens
                    usage["total_tokens"] = chunk.usage.total_tokens
        except Exception as e:
            status = "error"
            full_response = str(e)
            latency = time.time() - start_time
            logger.error(f"Error in generate_response: {e}")
        finally:
            # Best-effort logging: observability failures must never break
            # the caller's stream, so everything here is wrapped.
            try:
                cost = 0.0
                # model_info may be None, and the served model id may be
                # missing from it — only price what we can look up.
                if self.model_info is not None:
                    rows = self.model_info[self.model_info.id == model]["pricing"].values
                    if len(rows):
                        pricing = rows[0]
                        cost = (float(pricing["completion"]) * usage["completion_tokens"]
                                + float(pricing["prompt"]) * usage["prompt_tokens"])
                self.observability_manager.insert_observation(
                    response=full_response,
                    model=model,
                    completion_tokens=usage["completion_tokens"],
                    prompt_tokens=usage["prompt_tokens"],
                    total_tokens=usage["total_tokens"],
                    cost=cost,
                    conversation_id=conversation_id or "default",
                    status=status,
                    # System prompts are excluded from the logged request.
                    request=json.dumps([msg for msg in messages if msg.get('role') != 'system']),
                    latency=latency,
                    user=user
                )
            except Exception as obs_error:
                logger.error(f"Error logging observation: {obs_error}")
class DatabaseManager:
    """Manages database operations against the Supabase Postgres instance."""

    def __init__(self):
        # Connection parameters; credentials come from the environment so
        # they are never committed to source control.
        self.db_params = {
            "dbname": "postgres",
            "user": os.environ['SUPABASE_USER'],
            "password": os.environ['SUPABASE_PASSWORD'],
            "host": "aws-0-us-west-1.pooler.supabase.com",
            "port": "5432"
        }

    def update_database(self, user_id: str, user_query: str, response: str) -> None:
        """Insert one (user_id, user_query, response) row into ai_document_generator.

        Opens a fresh connection per call and closes it explicitly:
        psycopg2's connection context manager only scopes a TRANSACTION
        (commit on success, rollback on error) — it does NOT close the
        connection, so without the finally this leaked one connection
        per call.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            with conn:  # transaction scope: commit/rollback
                with conn.cursor() as cur:
                    insert_query = """
                    INSERT INTO ai_document_generator (user_id, user_query, response)
                    VALUES (%s, %s, %s);
                    """
                    # Parameterized query — inputs are never interpolated into SQL.
                    cur.execute(insert_query, (user_id, user_query, response))
        finally:
            conn.close()