Spaces:
Sleeping
Sleeping
Enhance lifespan management in FastAPI by initializing PoetryGenerationService and handling model preloading asynchronously
1c1ca6d
import asyncio | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI | |
from app.api.endpoints.poetry import router as poetry_router | |
import os | |
import logging | |
from typing import Tuple | |
from starlette.responses import Response | |
from starlette.staticfiles import StaticFiles | |
from huggingface_hub import login | |
from functools import lru_cache | |
from app.services.poetry_generation import PoetryGenerationService | |
# Module-wide logging setup: create this module's logger and configure the
# root logger at INFO level, once, at import time.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def get_hf_token() -> str:
    """Return the Hugging Face access token from the ``HF_TOKEN`` env var.

    Surrounding whitespace (for example a trailing newline picked up when the
    secret was pasted or read from a file) is stripped, and a value that is
    empty after stripping is treated the same as an unset variable.

    Returns:
        The non-empty, whitespace-stripped token.

    Raises:
        EnvironmentError: If ``HF_TOKEN`` is unset, empty, or whitespace-only.
    """
    token = os.getenv("HF_TOKEN", "").strip()
    if not token:
        raise EnvironmentError(
            "HF_TOKEN environment variable not found. "
            "Please set your Hugging Face access token."
        )
    return token
def init_huggingface():
    """Log in to Hugging Face using the token from ``get_hf_token``.

    Logs an info message on success. On any failure (including a missing
    token) the error is logged and the original exception is re-raised.
    """
    try:
        login(token=get_hf_token())
        logger.info("Successfully logged in to Hugging Face")
    except Exception as login_error:
        failure_message = f"Failed to login to Hugging Face: {str(login_error)}"
        logger.error(failure_message)
        raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work for the application, then hand control to FastAPI.

    Startup order:
      1. Authenticate with Hugging Face via ``init_huggingface``; any failure
         propagates and aborts startup.
      2. Build a ``PoetryGenerationService`` and, if it exposes a callable
         ``preload_models``, invoke it exactly once. A coroutine result is
         awaited; a synchronous implementation has already done its work by
         the time the call returns.

    Nothing runs after the ``yield``, so there is no shutdown cleanup.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Initialize Hugging Face authentication first
    init_huggingface()

    # Initialize poetry service and preload models
    poetry_service = PoetryGenerationService()
    preload = getattr(poetry_service, "preload_models", None)
    if callable(preload):
        result = preload()
        # Only a coroutine needs awaiting. A synchronous preload_models() has
        # already run on the line above, so its return value (typically None)
        # must not be called again.
        if asyncio.iscoroutine(result):
            await result

    yield  # Continue to application startup
# Create the application with the startup lifespan hook attached, then expose
# the poetry endpoints under the versioned API prefix.
app = FastAPI(lifespan=lifespan)
app.include_router(router=poetry_router, prefix="/api/v1/poetry")
async def lifecheck():
    """Return a plain-text ``OK`` response."""
    # NOTE(review): no route decorator is visible on this coroutine in this
    # file; confirm it is registered as an endpoint somewhere.
    ok_response = Response("OK", media_type="text/plain")
    return ok_response
def get_port() -> int:
    """Return the port to listen on, read from the ``PORT`` env var.

    Falls back to 8000 when ``PORT`` is unset, empty, or whitespace-only, so
    a blank value no longer fails with an opaque ``int("")`` error.

    Returns:
        The port number as an int.

    Raises:
        ValueError: If ``PORT`` is set to something that is not an integer.
    """
    raw_port = os.getenv("PORT", "").strip()
    if not raw_port:
        return 8000
    try:
        return int(raw_port)
    except ValueError as err:
        raise ValueError(
            f"PORT environment variable must be an integer, got {raw_port!r}"
        ) from err
if __name__ == "__main__":
    # Script entry point: serve ``app`` directly with uvicorn on all interfaces.
    import uvicorn

    port = get_port()
    # NOTE(review): the /static mount is registered only on this script path;
    # importing ``app`` from another runner skips it. Confirm that is intended.
    static_files = StaticFiles(directory="static")
    app.mount("/static", static_files, name="static")
    logger.info(f"Starting FastAPI server on port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)