poetica / main.py
abhisheksan's picture
Enhance lifespan management in FastAPI by initializing PoetryGenerationService and handling model preloading asynchronously
1c1ca6d
raw
history blame
2.27 kB
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.api.endpoints.poetry import router as poetry_router
import os
import logging
from typing import Tuple
from starlette.responses import Response
from starlette.staticfiles import StaticFiles
from huggingface_hub import login
from functools import lru_cache
from app.services.poetry_generation import PoetryGenerationService
# Root logging is configured a single time when this module is imported.
# basicConfig is a no-op if the root logger already has handlers attached.
logging.basicConfig(level=logging.INFO)
# Logger scoped to this module's name.
logger = logging.getLogger(name=__name__)
@lru_cache()
def get_hf_token() -> str:
    """Return the Hugging Face access token read from ``HF_TOKEN``.

    The first successful result is memoised by ``lru_cache``. Later changes
    to the environment variable are not observed unless
    ``get_hf_token.cache_clear()`` is called.

    Raises:
        EnvironmentError: If ``HF_TOKEN`` is unset or empty. Raised
            exceptions are not cached, so a later call re-reads the
            environment.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        return hf_token
    raise EnvironmentError(
        "HF_TOKEN environment variable not found. "
        "Please set your Hugging Face access token."
    )
def init_huggingface():
    """Log this process in to the Hugging Face Hub.

    The token comes from ``get_hf_token()`` and is handed to
    ``huggingface_hub.login``. Any failure, whether a missing token or a
    rejected login, is logged at ERROR level and then re-raised unchanged.
    """
    try:
        login(token=get_hf_token())
        logger.info("Successfully logged in to Hugging Face")
    except Exception as exc:
        logger.error(f"Failed to login to Hugging Face: {exc}")
        raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Perform one-time startup work before the application serves requests.

    Steps:
        1. Authenticate with Hugging Face via ``init_huggingface()``. Any
           exception propagates and aborts startup.
        2. Create a ``PoetryGenerationService``. If it exposes a callable
           ``preload_models``, call it exactly once. If the call returns an
           awaitable (e.g. ``preload_models`` is ``async def``), await it.
           Both sync and async implementations are therefore supported.
        3. Expose the instance as ``app.state.poetry_service``.

    Args:
        app: The FastAPI application being started.
    """
    import inspect  # local import: only needed for the awaitable check below

    init_huggingface()

    poetry_service = PoetryGenerationService()
    preload = getattr(poetry_service, "preload_models", None)
    if callable(preload):
        result = preload()
        # A synchronous preload has already finished at this point. `result`
        # is its return value (often None), not a function, so it must not be
        # called again. Only an awaitable result still needs to be driven.
        if inspect.isawaitable(result):
            await result

    # NOTE(review): request handlers may construct their own service; storing
    # the preloaded instance here lets them reuse it. Confirm against
    # app.api.endpoints.poetry.
    app.state.poetry_service = poetry_service

    yield  # application runs while suspended here; nothing to tear down
# ASGI application. `lifespan` performs Hugging Face login and model
# preloading before requests are served.
app = FastAPI(lifespan=lifespan)
# Poetry endpoints live under the versioned API prefix.
app.include_router(router=poetry_router, prefix="/api/v1/poetry")
@app.get("/healthz")
async def lifecheck():
    """Liveness probe: answer with a plain-text ``OK`` body."""
    body = "OK"
    return Response(content=body, media_type="text/plain")
def get_port(default: int = 8000) -> int:
    """Return the TCP port the server should listen on.

    Reads the ``PORT`` environment variable and ignores surrounding
    whitespace. An unset or blank value falls back to *default*.

    Args:
        default: Port used when ``PORT`` is unset or blank. Defaults to 8000.

    Returns:
        The configured port number, within 0-65535.

    Raises:
        ValueError: If ``PORT`` is set to a non-integer or an out-of-range
            value.
    """
    raw = os.getenv("PORT", "").strip()
    if not raw:
        # Treat an empty/blank PORT like an unset one instead of failing
        # inside int("").
        return default
    try:
        port = int(raw)
    except ValueError as err:
        raise ValueError(f"PORT must be an integer, got {raw!r}") from err
    if not 0 <= port <= 65535:
        raise ValueError(f"PORT must be between 0 and 65535, got {port}")
    return port
def _run_server() -> None:
    """Serve ``app`` with uvicorn on 0.0.0.0 using the port from ``get_port()``.

    Called only when this file is executed directly as a script.
    """
    import uvicorn  # deferred: only needed on the script entry path

    listen_port = get_port()
    # NOTE(review): /static is mounted only on this script path, so a server
    # started another way (e.g. `uvicorn main:app`) will not serve it.
    # StaticFiles presumably also requires ./static to exist. Confirm both
    # are intended.
    app.mount("/static", StaticFiles(directory="static"), name="static")
    logger.info(f"Starting FastAPI server on port {listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    _run_server()