from pathlib import Path
from typing import Dict, List, Optional, Union

import psutil
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from .api import LLMApi
from .utils.helpers import get_system_info, format_memory_size
from .utils.logging import setup_logger
from .utils.validation import validate_model_path

router = APIRouter()

# NOTE: the @router decorators below use assumed paths; adjust them to the
# application's URL scheme when mounting this router.

# Module-level singletons, populated by init_router().
logger = None
api = None
config = None

def init_router(config_dict: dict):
    """Initialize router with config and LLM API instance"""
    global logger, api, config
    config = config_dict
    logger = setup_logger(config, "api_routes")
    api = LLMApi(config)
    logger.info("Router initialized with LLM API instance")

class GenerateRequest(BaseModel):
    prompt: str
    system_message: Optional[str] = None
    max_new_tokens: Optional[int] = None
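
# Example request body (illustrative values):
#   {"prompt": "Explain transformers in one paragraph.",
#    "system_message": "You are a concise assistant.",
#    "max_new_tokens": 256}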

class EmbeddingRequest(BaseModel):
    text: str

class EmbeddingResponse(BaseModel):
    embedding: List[float]
    dimension: int
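
# Example round trip (illustrative values; the dimension depends on the
# loaded model):
#   request:  {"text": "hello world"}
#   response: {"embedding": [0.12, -0.03, ...], "dimension": 768}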

class SystemStatusResponse(BaseModel):
    """Pydantic model for system status response"""
    cpu: Optional[Dict[str, Union[float, str]]] = None
    memory: Optional[Dict[str, Union[float, str]]] = None
    gpu: Optional[Dict[str, Union[bool, str, float]]] = None
    storage: Optional[Dict[str, str]] = None
    # current_model may be None when no model is loaded, so None is allowed
    model: Optional[Dict[str, Union[bool, str, None]]] = None
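
# Example serialized response (illustrative, abridged):
#   {"cpu": {"usage_percent": 12.5, "status": "healthy"},
#    "memory": {"usage_percent": 41.0, "status": "healthy", "available": "9.4 GB"},
#    "gpu": {"available": true, "memory_used": "2.1 GB", "memory_total": "8.0 GB", "utilization_percent": 26.3},
#    "storage": {"models_directory": "/models", "models_size": "13.5 GB", ...},
#    "model": {"is_loaded": true, "current_model": "org/model-name", ...}}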

class ValidationResponse(BaseModel):
    # Values are bools for individual checks, or an error string under the
    # "error" key when a validation step raises.
    config_validation: Dict[str, Union[bool, str]]
    model_validation: Dict[str, Union[bool, str]]
    folder_validation: Dict[str, Union[bool, str]]
    overall_status: str
    issues: List[str]
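
# Example serialized response (illustrative):
#   {"config_validation": {"has_required_fields": true, "valid_paths": true, "valid_parameters": true},
#    "model_validation": {"model_files_exist": true, "model_loadable": true, "tokenizer_valid": true},
#    "folder_validation": {"models_folder": true, "cache_folder": true, "logs_folder": true, "write_permissions": true},
#    "overall_status": "valid",
#    "issues": []}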

@router.get("/system/validate", response_model=ValidationResponse)  # path assumed
async def validate_system():
    """
    Validate the overall setup. Checks:
    - Configuration parameters
    - Model setup
    - Folder structure
    - Required permissions
    """
logger.info("Starting system validation") | |
issues = [] | |
# Validate configuration | |
try: | |
config_status = { | |
"has_required_fields": True, # Check if all required config fields exist | |
"valid_paths": True, # Check if paths are valid | |
"valid_parameters": True # Check if parameters are within acceptable ranges | |
} | |
# Example validation checks | |
if not api.models_path.exists(): | |
config_status["valid_paths"] = False | |
issues.append("Models directory does not exist") | |
if api.temperature < 0 or api.temperature > 2: | |
config_status["valid_parameters"] = False | |
issues.append("Temperature parameter out of valid range (0-2)") | |
except Exception as e: | |
logger.error(f"Configuration validation failed: {str(e)}") | |
config_status = {"error": str(e)} | |
issues.append(f"Config validation error: {str(e)}") | |
    # Validate model setup
    try:
        model_status = {
            "model_files_exist": False,
            "model_loadable": False,
            "tokenizer_valid": False
        }
        if api.model_name:
            model_path = api.models_path / api.model_name.split('/')[-1]
            model_status["model_files_exist"] = validate_model_path(model_path)
            if not model_status["model_files_exist"]:
                issues.append("Model files are missing or incomplete")
        model_status["model_loadable"] = api.model is not None
        model_status["tokenizer_valid"] = api.tokenizer is not None
    except Exception as e:
        logger.error(f"Model validation failed: {str(e)}")
        model_status = {"error": str(e)}
        issues.append(f"Model validation error: {str(e)}")
    # Validate folder structure and permissions
    try:
        folder_status = {
            "models_folder": api.models_path.exists(),
            "cache_folder": api.cache_path.exists(),
            "logs_folder": (api.base_path / "logs").exists(),
            "write_permissions": False
        }
        # Test write permissions by attempting to create a test file
        test_file = api.models_path / ".test_write"
        try:
            test_file.touch()
            test_file.unlink()
            folder_status["write_permissions"] = True
        except OSError:
            folder_status["write_permissions"] = False
            issues.append("Insufficient write permissions in models directory")
    except Exception as e:
        logger.error(f"Folder validation failed: {str(e)}")
        folder_status = {"error": str(e)}
        issues.append(f"Folder validation error: {str(e)}")
    # Determine overall status: "valid" = no issues, "warning" = 1-2 issues,
    # "invalid" = 3 or more
    if not issues:
        overall_status = "valid"
    elif len(issues) < 3:
        overall_status = "warning"
    else:
        overall_status = "invalid"
    validation_response = ValidationResponse(
        config_validation=config_status,
        model_validation=model_status,
        folder_validation=folder_status,
        overall_status=overall_status,
        issues=issues
    )
    logger.info(f"System validation completed with status: {overall_status}")
    return validation_response

@router.get("/system/status", response_model=SystemStatusResponse)  # path assumed
async def check_system():
    """
    Get system status including:
    - CPU usage
    - Memory usage
    - GPU availability and usage
    - Storage status for model and cache directories
    - Current model status
    """
    logger.info("Checking system status")
    status = SystemStatusResponse()
    system_info = None
    # Check CPU
    try:
        system_info = get_system_info()
        status.cpu = {
            "usage_percent": system_info["cpu_percent"],
            "status": "healthy" if system_info["cpu_percent"] < 90 else "high"
        }
        logger.debug(f"CPU status retrieved: {status.cpu}")
    except Exception as e:
        logger.error(f"Failed to get CPU info: {str(e)}")
        status.cpu = {"status": "error", "message": str(e)}
    # Check Memory (reuse system_info if an earlier check already fetched it)
    try:
        if not system_info:
            system_info = get_system_info()
        status.memory = {
            "usage_percent": system_info["memory_percent"],
            "status": "healthy" if system_info["memory_percent"] < 90 else "critical",
            "available": format_memory_size(psutil.virtual_memory().available)
        }
        logger.debug(f"Memory status retrieved: {status.memory}")
    except Exception as e:
        logger.error(f"Failed to get memory info: {str(e)}")
        status.memory = {"status": "error", "message": str(e)}
    # Check GPU
    try:
        if not system_info:
            system_info = get_system_info()
        gpu_total = system_info["gpu_memory_total"]
        status.gpu = {
            "available": system_info["gpu_available"],
            "memory_used": format_memory_size(system_info["gpu_memory_used"]),
            "memory_total": format_memory_size(gpu_total),
            # Guard against division by zero when no GPU memory is reported
            "utilization_percent": (system_info["gpu_memory_used"] / gpu_total * 100)
            if system_info["gpu_available"] and gpu_total else 0
        }
        logger.debug(f"GPU status retrieved: {status.gpu}")
    except Exception as e:
        logger.error(f"Failed to get GPU info: {str(e)}")
        status.gpu = {"status": "error", "message": str(e)}
    # Check Storage
    try:
        models_path = Path(api.models_path)
        cache_path = Path(api.cache_path)
        # Walk each directory tree and total the file sizes
        models_size = sum(f.stat().st_size for f in models_path.glob('**/*') if f.is_file())
        cache_size = sum(f.stat().st_size for f in cache_path.glob('**/*') if f.is_file())
        status.storage = {
            "models_directory": str(models_path),
            "models_size": format_memory_size(models_size),
            "cache_directory": str(cache_path),
            "cache_size": format_memory_size(cache_size)
        }
        logger.debug(f"Storage status retrieved: {status.storage}")
    except Exception as e:
        logger.error(f"Failed to get storage info: {str(e)}")
        status.storage = {"status": "error", "message": str(e)}
    # Check Model Status
    try:
        current_model_path = (api.models_path / api.model_name.split('/')[-1]) if api.model_name else None
        status.model = {
            "is_loaded": api.model is not None,
            "current_model": api.model_name,
            "is_valid": validate_model_path(current_model_path) if current_model_path else False,
            "has_chat_template": api.has_chat_template() if api.model else False
        }
        logger.debug(f"Model status retrieved: {status.model}")
    except Exception as e:
        logger.error(f"Failed to get model status: {str(e)}")
        status.model = {"status": "error", "message": str(e)}
    logger.info("System status check completed")
    return status

@router.post("/generate")  # path assumed
async def generate_text(request: GenerateRequest):
    """Generate a text response from a prompt"""
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        response = api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens or api.max_new_tokens
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/generate/stream")  # path assumed
async def generate_stream(request: GenerateRequest):
    """Generate a streaming text response from a prompt"""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        # Assumes LLMApi.generate_stream returns a response FastAPI can
        # stream (e.g. a StreamingResponse or an async generator wrapper)
        return api.generate_stream(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens or api.max_new_tokens
        )
    except Exception as e:
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/embedding", response_model=EmbeddingResponse)  # path assumed
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector from text"""
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    try:
        embedding = api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding)
        )
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
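
# Example call (uses the assumed path registered above; host and port are
# illustrative):
#   curl -X POST http://localhost:8000/embedding \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello world"}'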

@router.post("/model/download")  # path assumed
async def download_model(model_name: Optional[str] = None):
    """Download model files to local storage"""
    try:
        # Fall back to the model name from config if none is provided
        model_to_download = model_name or config["model"]["defaults"]["model_name"]
        logger.info(f"Received request to download model: {model_to_download}")
        api.download_model(model_to_download)
        logger.info(f"Successfully downloaded model: {model_to_download}")
        return {
            "status": "success",
            "message": f"Model {model_to_download} downloaded",
            "model_name": model_to_download
        }
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/model/initialize")  # path assumed
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use"""
    try:
        # Fall back to the model name from config if none is provided
        model_to_init = model_name or config["model"]["defaults"]["model_name"]
        logger.info(f"Received request to initialize model: {model_to_init}")
        api.initialize_model(model_to_init)
        logger.info(f"Successfully initialized model: {model_to_init}")
        return {
            "status": "success",
            "message": f"Model {model_to_init} initialized",
            "model_name": model_to_init
        }
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.get("/model/status")  # path assumed
async def get_model_status():
    """Get current model status"""
    try:
        status = {
            "model_loaded": api.model is not None,
            "current_model": api.model_name or None,
            "has_chat_template": api.has_chat_template() if api.model else False
        }
        logger.info(f"Retrieved model status: {status}")
        return status
    except Exception as e:
        logger.error(f"Error getting model status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
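
# Example smoke test (illustrative; assumes the router is mounted on `app`
# as sketched near init_router above):
#   from fastapi.testclient import TestClient
#   client = TestClient(app)
#   print(client.get("/model/status").json())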