chat / app /main.py
ariansyahdedy's picture
Without database
1157bd1
raw
history blame
9 kB
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from fastapi.responses import Response
from fastapi.exceptions import HTTPException
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from typing import Dict, List
from prometheus_client import Counter, Histogram, start_http_server
from pydantic import BaseModel, ValidationError
from app.services.message import generate_reply, send_reply
import logging
from datetime import datetime
import time
from contextlib import asynccontextmanager
# from app.db.database import create_indexes, init_db
from app.services.webhook_handler import verify_webhook
from app.handlers.message_handler import MessageHandler
from app.handlers.webhook_handler import WebhookHandler
from app.handlers.media_handler import WhatsAppMediaHandler
from app.services.cache import MessageCache
from app.services.chat_manager import ChatManager
from app.utils.load_env import ACCESS_TOKEN
# Root logging: INFO and above, one timestamped line per record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger, named after this module.
logger = logging.getLogger(name=__name__)
# Shared handler instances. Both remain None until lifespan() builds them at startup.
message_handler = webhook_handler = None
async def setup_message_handler():
    """Build the MessageHandler used by the webhook pipeline.

    Each call creates fresh MessageCache, ChatManager and
    WhatsAppMediaHandler instances and passes them, together with this
    module's logger, to MessageHandler.

    Returns:
        MessageHandler: the newly constructed handler.
    """
    return MessageHandler(
        message_cache=MessageCache(),
        chat_manager=ChatManager(),
        media_handler=WhatsAppMediaHandler(),
        logger=logging.getLogger(__name__),
    )
# Application lifespan: build the shared handlers before the app serves requests
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the message and webhook handlers before the app starts serving.

    Populates the module-level ``message_handler`` and ``webhook_handler``
    globals used by the ``/webhook`` route. It then yields control to FastAPI
    for the lifetime of the application.

    Raises:
        Exception: any error raised while constructing the handlers is logged
            with its traceback and re-raised. Startup therefore fails with the
            real cause, instead of a generic "generator didn't yield" error,
            and the app never serves requests with ``webhook_handler`` still
            ``None``.
    """
    global message_handler, webhook_handler
    try:
        # Database initialisation is disabled in this build:
        # await init_db()
        message_handler = await setup_message_handler()
        webhook_handler = WebhookHandler(message_handler)
    except Exception:
        logger.exception("Failed to initialize application handlers")
        raise
    logger.info("Message and webhook handlers initialized")
    # The yield sits outside the try, so errors raised while the app is
    # running are no longer silently swallowed by this context manager.
    yield
# Initialize Limiter and Prometheus Metrics
# Rate limiter keyed on the client's remote address (get_remote_address).
limiter = Limiter(key_func=get_remote_address)
# `lifespan` builds the message/webhook handlers when the app starts.
app = FastAPI(lifespan=lifespan)
# NOTE: slowapi expects the limiter on app.state; the exception handler turns
# RateLimitExceeded (raised by @limiter.limit routes) into an HTTP error response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# app.include_router(users.router, prefix="/users", tags=["Users"])
# Prometheus metrics, created once at import time and exposed by the /metrics route
webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
webhook_processing_time = Histogram('webhook_processing_seconds', 'Time spent processing webhook')
# Start Prometheus metrics server on port 8002
# start_http_server(8002)
# Register webhook routes (legacy; the POST route is registered via decorator below)
# app.post("/webhook")(webhook)
# Define Pydantic schema for request validation
class WebhookPayload(BaseModel):
    """Pydantic schema for the top level of an incoming webhook body.

    Requires an ``entry`` field holding a list of dicts; each entry is only
    checked to be a dict, not validated further. The ``/webhook`` route
    currently does not apply this schema (its validation call is commented out).
    """

    entry: List[Dict]
@app.post("/webhook")
@limiter.limit("100/minute")
async def webhook(request: Request):
    """Receive a webhook POST and hand it to the shared ``WebhookHandler``.

    Rate limited to 100 requests per minute per client address. Every request
    that passes the limiter is counted in ``webhook_requests``, and its
    handling time is observed in ``webhook_processing_time``.

    Returns:
        JSONResponse: one of the following.
            200 with the handler result's attributes on success.
            400 if the body is not valid JSON.
            422 if processing raises a pydantic ``ValidationError``.
            503 if the handlers have not been initialized.
            500 on any other error; details are logged server-side only.
    """
    webhook_requests.inc()
    with webhook_processing_time.time():
        try:
            payload = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError;
            # a malformed body is a client error, not a server failure.
            logger.warning("Rejected webhook with a malformed JSON body")
            return JSONResponse(
                content={"status": "error", "detail": "Invalid JSON body"},
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        if webhook_handler is None:
            # lifespan() has not (successfully) built the handlers.
            logger.error("Webhook received before handlers were initialized")
            return JSONResponse(
                content={"status": "error", "detail": "Service not ready"},
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            )

        try:
            response = await webhook_handler.process_webhook(
                payload=payload,
                access_token=ACCESS_TOKEN,
            )
            return JSONResponse(
                content=response.__dict__,
                status_code=status.HTTP_200_OK,
            )
        except ValidationError as ve:
            logger.error(f"Validation error: {ve}")
            return JSONResponse(
                content={"status": "error", "detail": ve.errors()},
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            )
        except Exception:
            # Keep the full traceback in the server log, but do not echo the
            # exception text (which may expose internals) back to the caller.
            logger.exception("Unexpected error while processing webhook")
            return JSONResponse(
                content={"status": "error", "detail": "Internal server error"},
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
# GET /webhook: register the verification handler imported from
# app.services.webhook_handler. It is presumably the hub.mode / hub.challenge
# subscription check (compare the legacy version commented out at the end of
# this file) — TODO confirm. A direct call is used instead of a decorator
# because the function is defined in another module.
app.get("/webhook")(verify_webhook)
# Expose Prometheus metrics on the main app (optional, if not using a separate Prometheus server)
@app.get("/metrics")
async def metrics():
    """Return the default Prometheus registry in the text exposition format.

    Returns:
        Response: the output of ``generate_latest()``, served with the
        content type Prometheus scrapers expect.
    """
    from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

    # CONTENT_TYPE_LATEST includes the exposition-format version and charset
    # parameters. A bare "text/plain" omits them.
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
# In-memory cache with timestamp cleanup
# class MessageCache:
# def __init__(self, max_age_hours: int = 24):
# self.messages: Dict[str, float] = {}
# self.max_age_seconds = max_age_hours * 3600
# def add(self, message_id: str) -> None:
# self.cleanup()
# self.messages[message_id] = time.time()
# def exists(self, message_id: str) -> bool:
# self.cleanup()
# return message_id in self.messages
# def cleanup(self) -> None:
# current_time = time.time()
# self.messages = {
# msg_id: timestamp
# for msg_id, timestamp in self.messages.items()
# if current_time - timestamp < self.max_age_seconds
# }
# message_cache = MessageCache()
# user_chats = {}
# @app.post("/webhook")
# async def webhook(request: Request):
# request_id = f"req_{int(time.time()*1000)}"
# logger.info(f"Processing webhook request {request_id}")
# payload = await request.json()
# print("Webhook received:", payload)
# processed_count = 0
# error_count = 0
# results = []
# entries = payload.get("entry", [])
# for entry in entries:
# entry_id = entry.get("id")
# logger.info(f"Processing entry_id: {entry_id}")
# changes = entry.get("changes", [])
# for change in changes:
# messages = change.get("value", {}).get("messages", [])
# for message in messages:
# message_id = message.get("id")
# timestamp = message.get("timestamp")
# content = message.get("text", {}).get("body")
# sender_id = message.get("from")
# msg_type = message.get('type')
# # Deduplicate messages based on message_id
# if message_cache.exists(message_id):
# logger.info(f"Duplicate message detected and skipped: {message_id}")
# continue
# if sender_id not in user_chats:
# user_chats[sender_id] = []
# user_chats[sender_id].append({
# "role": "user",
# "content": content
# })
# history = "".join([f"{item['role']}: {item['content']}\n" for item in user_chats[sender_id]])
# print(f"history: {history}")
# try:
# # Process message with retry logic
# result = await process_message_with_retry(
# sender_id,content,
# history,
# timestamp,
# )
# user_chats[sender_id].append({
# "role": "assistant",
# "content": result
# })
# # Add the message ID to the cache
# message_cache.add(message_id)
# processed_count += 1
# results.append(result)
# except Exception as e:
# error_count += 1
# logger.error(
# f"Failed to process message {message_id}: {str(e)}",
# exc_info=True
# )
# results.append({
# "status": "error",
# "message_id": message_id,
# "error": str(e)
# })
# response_data = {
# "request_id": request_id,
# "processed": processed_count,
# "errors": error_count,
# "results": results
# }
# logger.info(
# f"Webhook processing completed - "
# f"Processed: {processed_count}, Errors: {error_count}"
# )
# return JSONResponse(
# content=response_data,
# status_code=status.HTTP_200_OK
# )
# @app.get("/webhook")
# async def verify_webhook(request: Request):
# mode = request.query_params.get('hub.mode')
# token = request.query_params.get('hub.verify_token')
# challenge = request.query_params.get('hub.challenge')
# # Replace 'your_verification_token' with the token you set in Facebook Business Manager
# if mode == 'subscribe' and token == 'test':
# # Return the challenge as plain text
# return Response(content=challenge, media_type="text/plain")
# else:
# raise HTTPException(status_code=403, detail="Verification failed")