from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any, Union
import torch
import logging
from pathlib import Path
from litgpt.api import LLM
import os
import uvicorn

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="LLM Engine Service")

# Global variable to store the LLM instance
llm_instance = None

class InitializeRequest(BaseModel):
    """
    Configuration for model initialization, including the model path.
    """
    mode: str = "cpu"
    precision: Optional[str] = None
    quantize: Optional[str] = None
    gpu_count: Union[str, int] = "auto"
    model_path: str


class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: float = 1.0
    return_as_token_ids: bool = False
    stream: bool = False
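
# Illustrative /generate request body (the values below are examples, not
# taken from the original service):
#
#   {"prompt": "Once upon a time", "max_new_tokens": 50,
#    "temperature": 0.8, "top_k": 50, "top_p": 0.9}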

@app.post("/initialize")
async def initialize_model(request: InitializeRequest):
    """
    Initialize the LLM model with the specified configuration.
    """
    global llm_instance
    try:
        if request.precision is None and request.quantize is None:
            # Use auto distribution from load when no specific precision or quantization is set
            llm_instance = LLM.load(
                model=request.model_path,
                distribute="auto"  # Let the load function handle distribution automatically
            )
            logger.info(
                f"Model initialized with auto settings:\n"
                f"Model Path: {request.model_path}\n"
                f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
                f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
            )
        else:
            # Load without distributing; we'll distribute manually below
            llm_instance = LLM.load(
                model=request.model_path,
                distribute=None
            )
            # Distribute the model according to the requested configuration
            llm_instance.distribute(
                accelerator="cuda" if request.mode == "gpu" else "cpu",
                devices=request.gpu_count,
                precision=request.precision,
                quantize=request.quantize
            )
            logger.info(
                f"Model initialized successfully with config:\n"
                f"Mode: {request.mode}\n"
                f"Precision: {request.precision}\n"
                f"Quantize: {request.quantize}\n"
                f"GPU Count: {request.gpu_count}\n"
                f"Model Path: {request.model_path}\n"
                f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
                f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
            )
        return {"success": True, "message": "Model initialized successfully"}
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        # Log detailed memory statistics on failure
        logger.error(
            f"GPU Memory Stats:\n"
            f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
            f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
            f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB"
        )
        raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
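
# Illustrative /initialize request bodies (hypothetical values). The first
# takes the auto-distribution branch above because precision and quantize are
# both unset; the second takes the manual distribute() branch. The precision
# and quantize strings follow litgpt's conventions (e.g. "bf16-true",
# "bnb.nf4"), and the model path is a placeholder:
#
#   {"mode": "cpu", "model_path": "checkpoints/meta-llama/Llama-2-7b-hf"}
#   {"mode": "gpu", "gpu_count": 1, "precision": "bf16-true",
#    "quantize": "bnb.nf4", "model_path": "checkpoints/meta-llama/Llama-2-7b-hf"}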

@app.post("/generate")
async def generate(request: GenerateRequest):
    """
    Generate text using the initialized model.
    """
    global llm_instance
    if llm_instance is None:
        raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
    try:
        if request.stream:
            # Streaming responses would need StreamingResponse from FastAPI;
            # reject streaming requests explicitly until that is implemented.
            raise HTTPException(
                status_code=400,
                detail="Streaming is not currently supported through the API"
            )
        generated_text = llm_instance.generate(
            prompt=request.prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            return_as_token_ids=request.return_as_token_ids,
            stream=False  # Force stream to False for now
        )
        response = {
            "generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
            "metadata": {
                "prompt": request.prompt,
                "max_new_tokens": request.max_new_tokens,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p
            }
        }
        return response
    except HTTPException:
        # Re-raise HTTPExceptions (e.g. the streaming rejection above) unchanged
        raise
    except Exception as e:
        logger.error(f"Error generating text: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
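
# A minimal sketch of how streaming could be exposed later, assuming that
# llm_instance.generate(..., stream=True) yields text fragments; the route
# path /generate/stream and this whole endpoint are illustrative, not part of
# the current service:
#
# from fastapi.responses import StreamingResponse
#
# @app.post("/generate/stream")
# async def generate_stream(request: GenerateRequest):
#     if llm_instance is None:
#         raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
#     def token_iterator():
#         for fragment in llm_instance.generate(
#             prompt=request.prompt,
#             max_new_tokens=request.max_new_tokens,
#             temperature=request.temperature,
#             top_k=request.top_k,
#             top_p=request.top_p,
#             stream=True,
#         ):
#             yield fragment
#     return StreamingResponse(token_iterator(), media_type="text/plain")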

@app.get("/health")
async def health_check():
    """
    Check if the service is running and the model is loaded.
    """
    global llm_instance
    status = {
        "status": "healthy",
        "model_loaded": llm_instance is not None,
    }
    if llm_instance is not None:
        status["model_info"] = {
            "model_path": llm_instance.config.name,
            "device": str(next(llm_instance.model.parameters()).device)
        }
    return status

def main():
    # Load environment variables or configuration here
    host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
    port = int(os.getenv("LLM_ENGINE_PORT", "8001"))

    # Start the server
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        reload=False
    )


if __name__ == "__main__":
    main()
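
# Example calls once the service is running, using the default host/port set
# above and the routes registered in this file (request bodies are
# illustrative):
#
#   curl -X POST http://localhost:8001/initialize \
#        -H "Content-Type: application/json" \
#        -d '{"mode": "cpu", "model_path": "checkpoints/meta-llama/Llama-2-7b-hf"}'
#
#   curl -X POST http://localhost:8001/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_new_tokens": 32}'
#
#   curl http://localhost:8001/health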