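# NOTE: the lines "Spaces:", "Sleeping", "Sleeping" were hosting-page status
# text captured together with this source; kept as a comment so the module parses.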
import json
import logging
import os
import secrets
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from xml.sax.saxutils import escape as xml_escape

import httpx
from dify_client_python.dify_client import models
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.security.api_key import APIKey, APIKeyHeader
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from json_parser import SSEParser
from logger_config import setup_logger
from response_formatter import ResponseFormatter
# Read settings (e.g. DIFY_API_KEY, CLIENT_API_KEY, API_BASE_URL, PORT) from a
# local .env file into the process environment, if one exists.
load_dotenv()

# Root logging configuration: INFO and above, one
# "<time> - <logger name> - <level> - <message>" line per record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Header clients must send to authenticate against this API.
API_KEY_NAME = "X-API-Key"
# Expected client key, taken from CLIENT_API_KEY (e.g. via .env). None when
# unset; get_api_key then answers every request with HTTP 500.
API_KEY = os.getenv("CLIENT_API_KEY")
# scheme_name matches the "ApiKeyAuth" scheme declared for the OpenAPI schema,
# so Swagger UI shows a single API-key scheme instead of an extra
# auto-named "APIKeyHeader" one. auto_error=True makes a missing header a 403.
api_key_header = APIKeyHeader(
    name=API_KEY_NAME,
    scheme_name="ApiKeyAuth",
    description="API key required for authentication",
    auto_error=True,
)
class AgentOutput(BaseModel):
    """Structured output from agent processing.

    NOTE(review): this model is not referenced anywhere in this file;
    presumably used by other modules or kept for future use -- confirm
    before changing or removing it.
    """
    # Presumably the text of an agent "thought" step (unverified here).
    thought_content: str
    # Tool observation text, if any; may be None.
    observation: Optional[str]
    # Untyped dicts; schema not visible in this file -- TODO confirm.
    tool_outputs: List[Dict]
    citations: List[Dict]
    metadata: Dict
    # Presumably the unprocessed upstream payload (unverified).
    raw_response: str
class AgentRequest(BaseModel):
    """Enhanced request model with additional parameters.

    AgentProcessor.process_stream forwards query, inputs, user,
    conversation_id and files to the upstream /chat-messages endpoint, and
    chooses the upstream response_mode from ``stream``.
    """
    # The user's message.
    query: str
    # Upstream conversation to continue; sent as null when omitted.
    conversation_id: Optional[str] = None
    # True -> upstream "streaming" mode, False -> "blocking".
    stream: bool = True
    # Pydantic copies field defaults per instance, so these literal {} / []
    # defaults are not shared between requests.
    inputs: Dict = {}
    files: List = []
    user: str = "default_user"
    # NOTE(review): not read by AgentProcessor in this file; the upstream
    # mode is derived from `stream` instead.
    response_mode: str = "streaming"
class AgentProcessor:
    """Relay chat requests to the upstream chat API and re-stream the replies.

    ``process_stream`` POSTs the request to ``{api_base}/chat-messages`` and
    returns a ``StreamingResponse`` whose SSE ``data:`` frames carry the XML
    produced by ``ResponseFormatter``. Each upstream event is also rendered
    for the terminal and written to the log.
    """

    def __init__(self, api_key: str):
        """Create a processor that authenticates upstream with *api_key*.

        Args:
            api_key: Sent upstream as ``Authorization: Bearer <api_key>``.
        """
        self.api_key = api_key
        # Upstream base URL; API_BASE_URL overrides the hosted default.
        self.api_base = os.getenv(
            "API_BASE_URL",
            "https://rag-engine.go-yamamoto.com/v1"
        )
        self.formatter = ResponseFormatter()
        # One client shared by all requests; released by cleanup().
        self.client = httpx.AsyncClient(timeout=60.0)
        self.logger = setup_logger("agent_processor")

    async def log_request_details(
        self,
        request: AgentRequest,
        start_time: datetime
    ) -> None:
        """Log detailed request information at DEBUG level."""
        self.logger.debug(
            "Request details: \n"
            f"Query: {request.query}\n"
            f"User: {request.user}\n"
            f"Conversation ID: {request.conversation_id}\n"
            f"Stream mode: {request.stream}\n"
            f"Start time: {start_time}\n"
            f"Inputs: {request.inputs}\n"
            f"Files: {len(request.files)} files attached"
        )

    async def log_error(
        self,
        error: Exception,
        context: Optional[Dict] = None
    ) -> None:
        """Log *error* with its type, message, stack trace and optional context.

        Call this from inside an ``except`` block: the stack trace comes from
        ``traceback.format_exc()``, which only reports the exception currently
        being handled.

        Args:
            error: The exception to report.
            context: Extra details, rendered as indented JSON.
        """
        error_msg = (
            f"Error type: {type(error).__name__}\n"
            f"Error message: {str(error)}\n"
            f"Stack trace:\n{traceback.format_exc()}\n"
        )
        if context:
            # default=str: a non-JSON-serializable context value must not
            # raise here and mask the error being reported.
            error_msg += f"Context:\n{json.dumps(context, indent=2, default=str)}"
        self.logger.error(error_msg)

    async def cleanup(self):
        """Cleanup method to properly close client"""
        await self.client.aclose()

    @staticmethod
    def _error_event(message: str) -> str:
        """Return one SSE frame carrying *message* inside an XML error element.

        The message is XML-escaped so text containing ``<``, ``>`` or ``&``
        (common in exception messages) cannot break the XML payload.
        """
        return (
            "data: <agent_response>"
            f"<error>{xml_escape(message)}</error>"
            "</agent_response>\n\n"
        )

    async def process_stream(self, request: AgentRequest):
        """Forward *request* upstream and stream the reply back as SSE.

        Args:
            request: The client's agent request.

        Returns:
            A ``text/event-stream`` StreamingResponse. Each frame is
            ``data: <xml>\\n\\n``. Errors (connection failures, upstream HTTP
            errors, parse failures) are delivered in-stream as
            ``<agent_response><error>...</error></agent_response>`` frames
            rather than raised.
        """
        start_time = datetime.now()
        await self.log_request_details(request, start_time)
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream"
        }
        chat_request = {
            "query": request.query,
            "inputs": request.inputs,
            "response_mode": "streaming" if request.stream else "blocking",
            "user": request.user,
            "conversation_id": request.conversation_id,
            "files": request.files
        }

        async def event_generator():
            parser = SSEParser()
            # Populated by the upstream "message_end" event and passed to the
            # terminal formatter for that event and any later ones.
            citations = []
            metadata = {}
            try:
                async with self.client.stream(
                    "POST",
                    f"{self.api_base}/chat-messages",
                    headers=headers,
                    json=chat_request
                ) as response:
                    self.logger.debug(
                        f"Stream connection established\n"
                        f"Status: {response.status_code}\n"
                        f"Headers: {dict(response.headers)}"
                    )
                    if response.is_error:
                        # An upstream 4xx/5xx carries an error body, not an
                        # event stream; report it instead of parsing it as SSE.
                        await response.aread()
                        self.logger.error(
                            f"Upstream returned HTTP {response.status_code}: "
                            f"{response.text}"
                        )
                        yield self._error_event(
                            f"Upstream error {response.status_code}: "
                            f"{response.text}"
                        )
                        return
                    buffer = ""
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue
                        self.logger.debug(f"Raw SSE line: {line}")
                        # Terminal/log rendering of the event (does not affect
                        # what is sent to the client).
                        if "data:" in line:
                            try:
                                data = line.split("data:", 1)[1].strip()
                                parsed = json.loads(data)
                                if parsed.get("event") == "message_end":
                                    citations = parsed.get("retriever_resources", [])
                                    metadata = parsed.get("metadata", {})
                                    self.logger.debug(
                                        f"Message end event:\n"
                                        f"Citations: {citations}\n"
                                        f"Metadata: {metadata}"
                                    )
                                formatted = self.format_terminal_output(
                                    parsed,
                                    citations=citations,
                                    metadata=metadata
                                )
                                if formatted:
                                    self.logger.info(formatted)
                            except Exception as e:
                                await self.log_error(
                                    e,
                                    {"line": line, "event": "parse_data"}
                                )
                        # Client-facing path: accumulate lines until a data
                        # line (or a JSON-looking "}" ending) completes an
                        # event, then hand the buffer to SSEParser.
                        buffer += line + "\n"
                        if line.startswith("data:") or buffer.strip().endswith("}"):
                            try:
                                processed_response = parser.parse_sse_event(buffer)
                                if processed_response and isinstance(processed_response, dict):
                                    cleaned_response = self.clean_response(processed_response)
                                    if cleaned_response:
                                        xml_content = cleaned_response.get("content", "")
                                        yield f"data: {xml_content}\n\n"
                            except Exception as parse_error:
                                await self.log_error(
                                    parse_error,
                                    {"buffer": buffer, "event": "process_buffer"}
                                )
                                yield self._error_event(str(parse_error))
                            finally:
                                buffer = ""
            except httpx.ConnectError as e:
                await self.log_error(e, {"event": "connection_error"})
                yield self._error_event(f"Connection error: {str(e)}")
            except Exception as e:
                await self.log_error(e, {"event": "stream_error"})
                yield self._error_event(f"Streaming error: {str(e)}")
            finally:
                end_time = datetime.now()
                duration = (end_time - start_time).total_seconds()
                self.logger.info(f"Request completed in {duration:.2f} seconds")

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Ask nginx-style proxies not to buffer the event stream.
                "X-Accel-Buffering": "no",
                "Access-Control-Allow-Origin": "*"
            }
        )

    def format_terminal_output(
        self,
        response: Dict,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None
    ) -> Optional[str]:
        """Format one upstream event for terminal/log output.

        Args:
            response: Parsed upstream event; dispatched on its "event" key.
            citations: Passed to the formatter for agent_thought events.
            metadata: Passed to the formatter for agent_thought events.

        Returns:
            The terminal rendering for "agent_thought", "agent_message" and
            "error" events; None for any other event type.
        """
        event_type = response.get("event")
        if event_type == "agent_thought":
            thought = response.get("thought", "")
            observation = response.get("observation", "")
            terminal_output, _ = self.formatter.format_thought(
                thought,
                observation,
                citations=citations,
                metadata=metadata
            )
            return terminal_output
        elif event_type == "agent_message":
            message = response.get("answer", "")
            terminal_output, _ = self.formatter.format_message(message)
            return terminal_output
        elif event_type == "error":
            error = response.get("error", "Unknown error")
            terminal_output, _ = self.formatter.format_error(error)
            return terminal_output
        return None

    def clean_response(self, response: Dict) -> Optional[Dict]:
        """Convert one parsed event into ``{"type": ..., "content": xml}``.

        Returns:
            A dict with type "thought", "message" or "error" and the
            formatter's XML as "content"; None for events without an
            "event" key, for other event types, or when formatting fails
            (the failure is logged).
        """
        try:
            event_type = response.get("event")
            if not event_type:
                return None
            # Handle different event types
            if event_type == "agent_thought":
                thought = response.get("thought", "")
                observation = response.get("observation", "")
                _, xml_output = self.formatter.format_thought(thought, observation)
                return {
                    "type": "thought",
                    "content": xml_output
                }
            elif event_type == "agent_message":
                message = response.get("answer", "")
                _, xml_output = self.formatter.format_message(message)
                return {
                    "type": "message",
                    "content": xml_output
                }
            elif event_type == "error":
                error = response.get("error", "Unknown error")
                _, xml_output = self.formatter.format_error(error)
                return {
                    "type": "error",
                    "content": xml_output
                }
            return None
        except Exception as e:
            # Use the processor's logger, like every other method here.
            self.logger.error(f"Error cleaning response: {str(e)}")
            return None
# The FastAPI application. The built-in docs page is switched off
# (docs_url=None); this file defines its own Swagger UI handler instead.
app = FastAPI(
    title="Agent API",
    version="1.0.0",
    description="API requiring X-API-Key authentication",
    openapi_tags=[{"name": "agent", "description": "Agent endpoints"}],
    docs_url=None,
)

# Process-wide AgentProcessor; created in startup_event, closed in shutdown_event.
agent_processor = None
# Add CORS middleware exactly once; registering the same CORSMiddleware a
# second time only wraps the app in a redundant extra layer.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; consider listing the allowed origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# OpenAPI security scheme for the X-API-Key header; custom_openapi() below
# merges it into the generated schema and applies it to every operation.
security_scheme = {
    "ApiKeyAuth": {
        "type": "apiKey",
        "in": "header",
        "name": "X-API-Key",
        "description": "API key required for authentication"
    }
}
# Make sure the schema is built by custom_openapi on first use.
app.openapi_schema = None


def custom_openapi() -> Dict[str, Any]:
    """Build, cache and return the OpenAPI schema with the API-key scheme.

    FastAPI has no ``add_security_requirement`` method (calling it raises
    AttributeError at import) and does not read an ``openapi_components``
    attribute, so the scheme is merged into the generated schema here.
    """
    if app.openapi_schema:
        return app.openapi_schema
    schema = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes,
        tags=app.openapi_tags,
    )
    components = schema.setdefault("components", {})
    components.setdefault("securitySchemes", {}).update(security_scheme)
    # Global requirement: every operation documents the API-key header.
    schema["security"] = [{"ApiKeyAuth": []}]
    app.openapi_schema = schema
    return schema


app.openapi = custom_openapi
@app.on_event("startup")
async def startup_event():
    """Create the shared AgentProcessor when the application starts.

    Raises:
        RuntimeError: if DIFY_API_KEY is not set. The service refuses to
            start rather than fall back to a credential embedded in source.
    """
    global agent_processor
    api_key = os.getenv("DIFY_API_KEY")
    if not api_key:
        # NOTE: an earlier version hard-coded a fallback key here; treat that
        # key as leaked and rotate it.
        raise RuntimeError("DIFY_API_KEY environment variable is not set")
    agent_processor = AgentProcessor(api_key=api_key)
@app.on_event("shutdown")
async def shutdown_event():
    """Close the AgentProcessor's HTTP client when the application stops."""
    global agent_processor
    if agent_processor is not None:
        await agent_processor.cleanup()
        # Drop the reference so nothing reuses the closed client.
        agent_processor = None
# Dependency that validates the X-API-Key header; use it as Security(get_api_key).
async def get_api_key(
    api_key_header: str = Security(api_key_header)
) -> APIKey:
    """Validate the presented API key against CLIENT_API_KEY.

    Args:
        api_key_header: Value of the X-API-Key header (injected by FastAPI).

    Returns:
        The presented key, when it matches.

    Raises:
        HTTPException: 500 if the server has no CLIENT_API_KEY configured;
            403 if the key is missing or does not match.
    """
    if not API_KEY:
        raise HTTPException(
            status_code=500,
            detail="API key configuration is missing on server"
        )
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    if api_key_header and secrets.compare_digest(
        api_key_header.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        return api_key_header
    raise HTTPException(
        status_code=403,
        detail="Invalid or missing API key"
    )
# Agent request handler. Authentication goes through get_api_key, which checks
# the key against CLIENT_API_KEY (the bare header dependency only checks that
# the header is present).
# NOTE(review): this copy of the file has no @app.post(...) decorator on this
# handler, so FastAPI never registers it; restore it with the original path.
async def process_agent_request(
    request: AgentRequest,
    api_key: APIKey = Security(get_api_key)
):
    """Stream the agent's reply to *request* as server-sent events.

    Args:
        request: Parsed request body.
        api_key: Validated client key (injected by FastAPI; not used here).

    Returns:
        The StreamingResponse built by AgentProcessor.process_stream.

    Raises:
        HTTPException: 503 if the processor has not been created yet;
            500 if building the stream fails.
    """
    if agent_processor is None:
        # startup_event has not run (or failed); report that clearly instead
        # of an AttributeError surfacing as a generic 500.
        raise HTTPException(
            status_code=503,
            detail="Agent processor is not initialized"
        )
    try:
        logger.info(f"Processing agent request: {request.query}")
        return await agent_processor.process_stream(request)
    except Exception as e:
        logger.error(f"Error in agent request processing: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Turn any exception escaping the app into a generic JSON 500.

    The exception and its traceback are logged; the client only sees a
    fixed message, so internal details are not exposed.
    """
    try:
        return await call_next(request)
    except Exception as e:
        logger.error(f"Unhandled error: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error occurred"}
        )
# Direct execution (`python <this file>`): serve with uvicorn, auto-reload on.
if __name__ == "__main__":
    import uvicorn

    # Port comes from $PORT and defaults to 7860; bind on all interfaces.
    listen_port = int(os.getenv("PORT", 7860))
    # reload=True requires an import string instead of the app object; this
    # assumes the module is importable as `api`.
    uvicorn.run("api:app", host="0.0.0.0", port=listen_port, reload=True)
# Dark-theme stylesheet injected into the Swagger UI page below.
_SWAGGER_DARK_THEME_CSS = """
<style>
    /* Dark theme with cool colors */
    :root {
        --primary-color: #00b4d8;
        --secondary-color: #90e0ef;
        --background-color: #0d1117;
        --text-color: #e6edf3;
        --border-color: #30363d;
    }
    body {
        background-color: var(--background-color);
        color: var(--text-color);
    }
    .swagger-ui {
        background-color: var(--background-color);
        color: var(--text-color);
    }
    /* Headers and text */
    .swagger-ui .info .title,
    .swagger-ui .info .base-url,
    .swagger-ui .info li,
    .swagger-ui .info p,
    .swagger-ui .info table {
        color: var(--text-color);
    }
    /* Operation buttons */
    .swagger-ui .opblock.opblock-post {
        background: rgba(0, 180, 216, 0.1);
        border-color: var(--primary-color);
    }
    .swagger-ui .opblock.opblock-post .opblock-summary-method {
        background: var(--primary-color);
    }
    /* Authorize button */
    .swagger-ui .btn.authorize {
        background: var(--primary-color);
        border-color: var(--primary-color);
        color: white;
    }
    .swagger-ui .btn.authorize svg {
        fill: white;
    }
    /* Schema sections */
    .swagger-ui .model-box {
        background: rgba(48, 54, 61, 0.4);
    }
    .swagger-ui .model {
        color: var(--text-color);
    }
    /* Try it out section */
    .swagger-ui textarea,
    .swagger-ui input[type=text] {
        background: var(--background-color);
        color: var(--text-color);
        border-color: var(--border-color);
    }
    /* Response section */
    .swagger-ui .responses-table th,
    .swagger-ui .responses-table td {
        color: var(--text-color);
        border-color: var(--border-color);
    }
    /* Scrollbar */
    ::-webkit-scrollbar {
        width: 8px;
        height: 8px;
    }
    ::-webkit-scrollbar-track {
        background: var(--background-color);
    }
    ::-webkit-scrollbar-thumb {
        background: var(--primary-color);
        border-radius: 4px;
    }
    /* Code blocks */
    .swagger-ui .highlight-code {
        background-color: #1b1f24;
    }
    /* Modal dialogs */
    .swagger-ui .dialog-ux .modal-ux {
        background: var(--background-color);
        border-color: var(--border-color);
    }
    .swagger-ui .dialog-ux .modal-ux-header h3 {
        color: var(--text-color);
    }
    /* Tables */
    .swagger-ui table thead tr td,
    .swagger-ui table thead tr th {
        color: var(--text-color);
        border-color: var(--border-color);
    }
    /* Links */
    .swagger-ui a {
        color: var(--primary-color);
    }
    /* Schema dropdowns */
    .swagger-ui select {
        background: var(--background-color);
        color: var(--text-color);
        border-color: var(--border-color);
    }
</style>
"""


# Custom docs endpoint. It takes the place of the default docs page, which is
# disabled on the app via docs_url=None.
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html() -> HTMLResponse:
    """Serve Swagger UI (pinned 5.9.0 assets) with the dark theme applied.

    ``get_swagger_ui_html`` has no ``extra_html`` parameter (passing one
    raises TypeError), so the stylesheet is spliced into the generated page
    just before ``</head>``. Because it follows Swagger's own CSS link there,
    its rules take precedence.
    """
    page = get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title + " - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        swagger_js_url="https://unpkg.com/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
        swagger_css_url="https://unpkg.com/swagger-ui-dist@5.9.0/swagger-ui.css",
        swagger_favicon_url="https://fastapi.tiangolo.com/img/favicon.png",
    )
    html = page.body.decode("utf-8")
    themed = html.replace("</head>", _SWAGGER_DARK_THEME_CSS + "</head>", 1)
    return HTMLResponse(content=themed)