Spaces:
Build error
Build error
from fastapi import FastAPI, Request, status | |
from fastapi.responses import JSONResponse | |
from fastapi.responses import Response | |
from fastapi.exceptions import HTTPException | |
from fastapi.background import BackgroundTasks | |
from slowapi import Limiter, _rate_limit_exceeded_handler | |
from slowapi.errors import RateLimitExceeded | |
from slowapi.util import get_remote_address | |
from slowapi.middleware import SlowAPIMiddleware | |
from typing import Dict, List | |
from prometheus_client import Counter, Histogram, start_http_server | |
from pydantic import BaseModel, ValidationError | |
from app.services.message import generate_reply, send_reply | |
import logging | |
import httpx | |
from datetime import datetime | |
from sentence_transformers import SentenceTransformer | |
from app.search.rag_pipeline import RAGSystem | |
from contextlib import asynccontextmanager | |
# from app.db.database import create_indexes, init_db | |
# from app.services.webhook_handler import verify_webhook | |
from app.handlers.message_handler import MessageHandler | |
from app.handlers.webhook_handler import WebhookHandler | |
from app.handlers.media_handler import WhatsAppMediaHandler | |
from app.services.cache import MessageCache | |
from app.services.chat_manager import ChatManager | |
from app.api.api_prompt import prompt_router | |
from app.api.api_file import file_router, load_file_with_markdown_function | |
from app.utils.load_env import ACCESS_TOKEN, WHATSAPP_API_URL, GEMINI_API | |
from markitdown import MarkItDown | |
# Process-wide logging: INFO and above, each record prefixed with time, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Shared handlers; both stay None until the application lifespan assigns them at startup.
message_handler = None
webhook_handler = None

# Pages passed to load_file_with_markdown_function during startup
# (all under https://sswalfa.surabaya.go.id/info/detail/).
indexed_links = [
    f"https://sswalfa.surabaya.go.id/info/detail/{slug}"
    for slug in (
        "izin-pengumpulan-sumbangan-bencana",
        "izin-pemakaian-ruang-terbuka-hijau",
        "pengganti-ipt",
        "arahan-sistem-drainase",
        "rangkaian-pelayanan-surat-pernyataan-belum-menikah-lagi-bagi-jandaduda",
    )
]
async def setup_message_handler():
    """Build the MessageHandler used to process incoming messages.

    Every collaborator (message cache, chat manager, media handler) is a
    freshly constructed instance, and the handler logs through this module's
    logger.
    """
    return MessageHandler(
        message_cache=MessageCache(),
        chat_manager=ChatManager(),
        media_handler=WhatsAppMediaHandler(),
        logger=logging.getLogger(__name__),
    )
async def setup_rag_system(model_name: str = 'all-MiniLM-L6-v2'):
    """Create the RAG system backed by a SentenceTransformer embedding model.

    Args:
        model_name: SentenceTransformer model identifier. Defaults to
            ``'all-MiniLM-L6-v2'``, the model previously hard-coded here, so
            existing callers are unaffected.

    Returns:
        A new ``RAGSystem`` wrapping the loaded embedding model.

    NOTE(review): the model is constructed synchronously (not awaited), so
    this coroutine blocks the event loop while the model loads.
    """
    embedding_model = SentenceTransformer(model_name)
    return RAGSystem(embedding_model)
# Application lifespan: startup wiring for the FastAPI app defined below.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared services when the application starts.

    On startup this:
      * creates the RAG system and stores it on ``app.state.rag_system``;
      * assigns the module-level ``message_handler`` and ``webhook_handler``;
      * loads ``indexed_links`` into the RAG system.
    Nothing is torn down on shutdown (no code after ``yield``).

    Raises:
        Exception: any startup failure is logged with its traceback and
            re-raised so the server refuses to start with the real cause.
            Swallowing it would leave the generator without a ``yield`` and
            surface only a generic "generator didn't yield" error.
    """
    global message_handler, webhook_handler
    try:
        # await init_db()
        # NOTE(review): init_db() is disabled, so no database connection is
        # actually made even though the following line reports one.
        logger.info("Connected to the MongoDB database!")
        rag_system = await setup_rag_system()
        app.state.rag_system = rag_system
        message_handler = await setup_message_handler()
        webhook_handler = WebhookHandler(message_handler)
        await load_file_with_markdown_function(rag_system=rag_system, filepaths=indexed_links)
    except Exception:
        logger.exception("Application startup failed")
        raise
    # The `yield` sits outside the try so errors raised into the lifespan at
    # shutdown propagate instead of being silently swallowed.
    yield
# Rate limiter using slowapi's get_remote_address as the key function.
limiter = Limiter(key_func=get_remote_address)

# The FastAPI application; startup work runs in `lifespan`.
app = FastAPI(lifespan=lifespan)

# slowapi wiring: limiter on app.state, rate-limit exception handler, middleware.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

# API routers.
# app.include_router(users.router, prefix="/users", tags=["Users"])
app.include_router(prompt_router, tags=["Prompts"], prefix="/prompts")
app.include_router(file_router, tags=["File Load"], prefix="/file_load")
# Prometheus metrics for the webhook endpoint.
webhook_requests = Counter(
    name='webhook_requests_total',
    documentation='Total webhook requests',
)
webhook_processing_time = Histogram(
    name='webhook_processing_seconds',
    documentation='Time spent processing webhook',
)
# Standalone Prometheus metrics server on port 8002 (currently disabled):
# start_http_server(8002)
# Register webhook routes | |
# app.post("/webhook")(webhook) | |
# Pydantic schema for validating webhook request bodies.
class WebhookPayload(BaseModel):
    """Webhook request body: a required top-level ``entry`` list of dicts.

    Not used by any live code in this module; payload validation in
    ``webhook`` is currently disabled.
    """

    # One dict per webhook entry; the inner structure is not validated.
    entry: List[Dict]
# @limiter.limit("20/minute")
async def webhook(request: Request, background_tasks: BackgroundTasks):
    """Receive a webhook call and queue it for background processing.

    The JSON body is passed to ``webhook_handler.process_webhook`` as a
    background task. The task also receives the credentials imported from
    ``app.utils.load_env`` and the RAG system stored on ``app.state``.
    ``{"status": "received"}`` is returned without waiting for the task.

    NOTE(review): no route decorator is visible for this function, and the
    ``app.post("/webhook")(webhook)`` registration is commented out. Confirm
    it is actually mounted.

    Returns:
        JSONResponse:
          * 200 once the task is queued;
          * 400 if the body is not valid JSON;
          * 422 on a pydantic ``ValidationError``;
          * 500 with a generic message for any other error. Details go to
            the log, not to the (untrusted) caller.
    """
    webhook_requests.inc()
    # NOTE(review): webhook_processing_time is not observed here because the
    # real work runs later in a background task.
    try:
        try:
            payload = await request.json()
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError subclass ValueError.
            # A malformed body is a client error, not a server fault.
            logger.warning("Rejected webhook with malformed body: %s", exc)
            return JSONResponse(
                content={"status": "error", "detail": "Invalid JSON body"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        rag_system = request.app.state.rag_system
        # TODO: re-enable schema validation, e.g. WebhookPayload(**payload).
        # Payloads presumably carry user identifiers and message text, so
        # keep them out of the default INFO-level output.
        logger.debug("payload: %s", payload)
        # Credentials come from app.utils.load_env, not the query string, so
        # they never appear in request URLs or access logs.
        background_tasks.add_task(
            webhook_handler.process_webhook,
            payload=payload,
            whatsapp_token=ACCESS_TOKEN,
            whatsapp_url=WHATSAPP_API_URL,
            gemini_api=GEMINI_API,
            rag_system=rag_system,
        )
        # Acknowledge immediately; processing continues after the response.
        return JSONResponse(
            content={"status": "received"},
            status_code=status.HTTP_200_OK,
        )
    except ValidationError as ve:
        # Only reachable once payload validation is re-enabled.
        logger.error("Validation error: %s", ve)
        return JSONResponse(
            content={"status": "error", "detail": ve.errors()},
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        )
    except Exception:
        # Keep the traceback in the log but don't expose internals to callers.
        logger.exception("Unexpected error while handling webhook")
        return JSONResponse(
            content={"status": "error", "detail": "Internal server error"},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
async def verify_webhook(request: Request):
    """Answer the webhook verification handshake sent via ``hub.*`` query params.

    If ``hub.mode`` is ``"subscribe"`` and ``hub.verify_token`` matches the
    expected token, the ``hub.challenge`` value is echoed back as plain text.

    The expected token is read from the ``WHATSAPP_VERIFY_TOKEN`` environment
    variable. It falls back to ``"test"``, the value previously hard-coded
    here, so existing deployments keep working.

    NOTE(review): no route decorator is visible for this function; confirm
    it is mounted.

    Raises:
        HTTPException: 403 when the mode or token does not match.
    """
    import hmac
    import os

    params = request.query_params
    mode = params.get('hub.mode')
    token = params.get('hub.verify_token')
    challenge = params.get('hub.challenge')
    expected_token = os.environ.get("WHATSAPP_VERIFY_TOKEN", "test")
    # Use a constant-time comparison so the token cannot be probed via timing.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    token_ok = token is not None and hmac.compare_digest(
        token.encode("utf-8"), expected_token.encode("utf-8")
    )
    if mode == 'subscribe' and token_ok:
        return Response(content=challenge, media_type="text/plain")
    raise HTTPException(status_code=403, detail="Verification failed")
async def load_file_with_markitdown(file_path: str, llm_client=None, model: str = None):
    """Convert *file_path* with MarkItDown and return the conversion result.

    Args:
        file_path: Source handed unchanged to ``MarkItDown.convert``.
        llm_client: Optional LLM client forwarded to MarkItDown. Used only
            when *model* is also provided.
        model: Optional LLM model name, forwarded as ``llm_model``.

    Returns:
        Whatever ``MarkItDown.convert`` returns.

    NOTE(review): ``convert`` is called synchronously, so this coroutine
    blocks the event loop while converting. Consider offloading it to a
    worker thread.
    """
    if llm_client and model:
        # MarkItDown takes these as the keyword arguments llm_client /
        # llm_model. Passed positionally, they bind to other parameters or
        # raise TypeError, depending on the MarkItDown version.
        markitdown = MarkItDown(llm_client=llm_client, llm_model=model)
    else:
        markitdown = MarkItDown()
    documents = markitdown.convert(file_path)
    logger.debug("documents: %s", documents)
    return documents
# Handler exposing Prometheus metrics over HTTP (optional if a separate Prometheus server is used).
async def metrics():
    """Return the output of ``prometheus_client.generate_latest()`` as ``text/plain``.

    NOTE(review): no route decorator is visible for this function; confirm it
    is mounted (e.g. at ``/metrics``).
    """
    from prometheus_client import generate_latest as render_metrics

    body = render_metrics()
    return Response(media_type="text/plain", content=body)
# In-memory cache with timestamp cleanup | |
# class MessageCache: | |
# def __init__(self, max_age_hours: int = 24): | |
# self.messages: Dict[str, float] = {} | |
# self.max_age_seconds = max_age_hours * 3600 | |
# def add(self, message_id: str) -> None: | |
# self.cleanup() | |
# self.messages[message_id] = time.time() | |
# def exists(self, message_id: str) -> bool: | |
# self.cleanup() | |
# return message_id in self.messages | |
# def cleanup(self) -> None: | |
# current_time = time.time() | |
# self.messages = { | |
# msg_id: timestamp | |
# for msg_id, timestamp in self.messages.items() | |
# if current_time - timestamp < self.max_age_seconds | |
# } | |
# message_cache = MessageCache() | |
# user_chats = {} | |