from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Union
import torch
import logging
from litgpt.api import LLM
import os
import uvicorn
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="LLM Engine Service")
# Global variable to store the LLM instance
llm_instance = None
class InitializeRequest(BaseModel):
"""
Configuration for model initialization including model path
"""
mode: str = "cpu"
precision: Optional[str] = None
quantize: Optional[str] = None
gpu_count: Union[str, int] = "auto"
model_path: str
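    # Example request body (illustrative values; the model path below is just a
    # placeholder for any checkpoint identifier accepted by litgpt's LLM.load):
    #   {
    #       "mode": "gpu",
    #       "precision": "bf16-true",
    #       "quantize": "bnb.nf4",
    #       "gpu_count": 1,
    #       "model_path": "microsoft/phi-2"
    #   }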
class GenerateRequest(BaseModel):
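    """
    Parameters for a single text generation request.
    """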
prompt: str
max_new_tokens: int = 50
temperature: float = 1.0
top_k: Optional[int] = None
top_p: float = 1.0
return_as_token_ids: bool = False
stream: bool = False
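    # Example request body (illustrative values):
    #   {
    #       "prompt": "What is the capital of France?",
    #       "max_new_tokens": 64,
    #       "temperature": 0.7,
    #       "top_k": 50,
    #       "top_p": 0.9
    #   }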
@app.post("/initialize")
async def initialize_model(request: InitializeRequest):
"""
Initialize the LLM model with specified configuration.
"""
global llm_instance
try:
if request.precision is None and request.quantize is None:
# Use auto distribution from load when no specific precision or quantization is set
llm_instance = LLM.load(
model=request.model_path,
distribute="auto" # Let the load function handle distribution automatically
)
            gpu_memory = (
                f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
                f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
                if torch.cuda.is_available() else "CUDA not available"
            )
            logger.info(
                f"Model initialized with auto settings:\n"
                f"Model Path: {request.model_path}\n"
                f"{gpu_memory}"
            )
else:
# Original initialization path for when specific settings are requested
llm_instance = LLM.load(
model=request.model_path,
distribute=None # We'll distribute manually
)
# Distribute the model according to the configuration
llm_instance.distribute(
accelerator="cuda" if request.mode == "gpu" else "cpu",
devices=request.gpu_count,
precision=request.precision,
quantize=request.quantize
)
            gpu_memory = (
                f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
                f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
                if torch.cuda.is_available() else "CUDA not available"
            )
            logger.info(
                f"Model initialized successfully with config:\n"
                f"Mode: {request.mode}\n"
                f"Precision: {request.precision}\n"
                f"Quantize: {request.quantize}\n"
                f"GPU Count: {request.gpu_count}\n"
                f"Model Path: {request.model_path}\n"
                f"{gpu_memory}"
            )
return {"success": True, "message": "Model initialized successfully"}
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
        # Log detailed GPU memory statistics on failure (only meaningful with CUDA)
        if torch.cuda.is_available():
            logger.error(
                f"GPU Memory Stats:\n"
                f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
                f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
                f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB"
            )
raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
@app.post("/generate")
async def generate(request: GenerateRequest):
"""
Generate text using the initialized model.
"""
global llm_instance
if llm_instance is None:
raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
try:
if request.stream:
            # Streaming responses would need to be handled differently,
            # e.g. with FastAPI's StreamingResponse; this endpoint does not
            # support streaming yet.
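            # A possible implementation (untested sketch, assuming litgpt's
            # LLM.generate(..., stream=True) yields text chunks) could wrap the
            # generator in a StreamingResponse, roughly:
            #
            #     from fastapi.responses import StreamingResponse
            #
            #     def token_stream():
            #         for chunk in llm_instance.generate(
            #             prompt=request.prompt,
            #             max_new_tokens=request.max_new_tokens,
            #             temperature=request.temperature,
            #             top_k=request.top_k,
            #             top_p=request.top_p,
            #             stream=True,
            #         ):
            #             yield chunk
            #
            #     return StreamingResponse(token_stream(), media_type="text/plain")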
raise HTTPException(
status_code=400,
detail="Streaming is not currently supported through the API"
)
generated_text = llm_instance.generate(
prompt=request.prompt,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
return_as_token_ids=request.return_as_token_ids,
stream=False # Force stream to False for now
)
response = {
"generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
"metadata": {
"prompt": request.prompt,
"max_new_tokens": request.max_new_tokens,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p
}
}
return response
except Exception as e:
logger.error(f"Error generating text: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
@app.get("/health")
async def health_check():
"""
Check if the service is running and model is loaded.
"""
global llm_instance
status = {
"status": "healthy",
"model_loaded": llm_instance is not None,
}
if llm_instance is not None:
status["model_info"] = {
"model_path": llm_instance.config.name,
"device": str(next(llm_instance.model.parameters()).device)
}
return status
def main():
    # Read host and port from environment variables, falling back to defaults
host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
port = int(os.getenv("LLM_ENGINE_PORT", "8001"))
# Start the server
uvicorn.run(
app,
host=host,
port=port,
log_level="info",
reload=False
)
if __name__ == "__main__":
main()
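# Example usage once the server is running (assuming the default host/port and a
# placeholder model path):
#   curl -X POST http://localhost:8001/initialize \
#        -H "Content-Type: application/json" \
#        -d '{"mode": "cpu", "model_path": "microsoft/phi-2"}'
#   curl -X POST http://localhost:8001/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_new_tokens": 32}'
#   curl http://localhost:8001/health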