# "Spaces: Paused Paused" — Hugging Face Spaces page-status residue from extraction; not code.
import yaml
import sys
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from .api import LLMApi
from .routes import router, init_router
from .utils.logging import setup_logger
from huggingface_hub import login
from pathlib import Path
from dotenv import load_dotenv
import os
def validate_hf():
    """Validate Hugging Face authentication.

    Loads environment variables from a local ``.env`` file if one exists,
    then attempts to log in to Hugging Face using the ``HF_TOKEN``
    environment variable.

    Returns:
        bool: True if login succeeded; False if no token was found or
        authentication failed.
    """
    # Bug fix: the original read a module-global `config` that is only bound
    # inside the `__main__` guard, so calling validate_hf() from an importing
    # module raised NameError. Load the config explicitly here instead.
    config = load_config()
    logger = setup_logger(config, "hf_validation")

    # Check for .env file; on Hugging Face Spaces the token usually comes
    # from the environment instead, so a missing file is only a warning.
    env_path = Path('.env')
    if env_path.exists():
        logger.info("Found .env file, loading environment variables")
        load_dotenv()
    else:
        logger.warning("No .env file found. Fine if you're on Huggingface, but you need one to run locally on your PC.")

    # Check for HF token
    hf_token = os.getenv('HF_TOKEN')
    if not hf_token:
        logger.error("No HF_TOKEN found in environment variables")
        return False

    try:
        # Attempt login
        login(token=hf_token)
        logger.info("Successfully authenticated with Hugging Face")
        return True
    except Exception as e:
        logger.error(f"Failed to authenticate with Hugging Face: {str(e)}")
        return False
def load_config(path="main/config.yaml"):
    """Load configuration from a YAML file.

    Args:
        path: Path of the YAML config file. Defaults to ``main/config.yaml``,
            preserving the original hard-coded behavior for existing callers.

    Returns:
        The parsed configuration (typically a dict).
    """
    # Explicit encoding so the config parses identically regardless of the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def create_app():
    """Build and configure the FastAPI application.

    Loads the YAML config, sets up logging, attaches the CORS middleware,
    and mounts the API router under the versioned URL prefix.
    """
    cfg = load_config()
    log = setup_logger(cfg, "main")
    log.info("Starting LLM API server")

    api_cfg = cfg["api"]
    application = FastAPI(
        title="LLM API",
        description="API for Large Language Model operations",
        version=api_cfg["version"],
    )

    # Cross-origin settings come straight from the config file.
    cors_cfg = api_cfg["cors"]
    application.add_middleware(
        CORSMiddleware,
        allow_origins=cors_cfg["origins"],
        allow_credentials=cors_cfg["credentials"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Hand the config to the route module, then mount the versioned router.
    init_router(cfg)
    application.include_router(router, prefix=f"{api_cfg['prefix']}/{api_cfg['version']}")
    log.info("FastAPI application created successfully")
    return application
def test_locally():
    """Exercise the LLM API end to end for development and debugging."""
    cfg = load_config()
    log = setup_logger(cfg, "test")
    log.info("Starting local tests")

    api = LLMApi(cfg)
    model_name = cfg["model"]["defaults"]["model_name"]
    log.info(f"Testing with model: {model_name}")

    # Download and initialize the configured model.
    log.info("Testing model download...")
    api.download_model(model_name)
    log.info("Download complete")

    log.info("Initializing model...")
    api.initialize_model(model_name)
    log.info("Model initialized")

    # Embedding generation over a Norwegian sample text.
    test_text = "Dette er en test av embeddings generering fra en teknisk tekst om HMS rutiner på arbeidsplassen."
    log.info("Testing embedding generation...")
    embedding = api.generate_embedding(test_text)
    log.info(f"Generated embedding of length: {len(embedding)}")
    log.info(f"First few values: {embedding[:5]}")

    test_prompts = [
        "Tell me what happens in a nuclear reactor.",
    ]

    # Non-streaming generation, one response per prompt.
    log.info("Testing regular generation:")
    for prompt in test_prompts:
        log.info(f"Prompt: {prompt}")
        reply = api.generate_response(
            prompt=prompt,
            system_message="You are a helpful assistant.",
        )
        log.info(f"Response: {reply}")

    # Streaming generation, printed chunk by chunk as it arrives.
    log.info("Testing streaming generation:")
    log.info(f"Prompt: {test_prompts[0]}")
    for chunk in api.generate_stream(
        prompt=test_prompts[0],
        system_message="You are a helpful assistant.",
    ):
        print(chunk, end="", flush=True)
    print("\n")

    log.info("Local tests completed")
# Module-level ASGI application object, referenced by uvicorn as "main.app:app".
app = create_app()

if __name__ == "__main__":
    config = load_config()
    # validate_hf()  # optional Hugging Face auth check, currently disabled
    run_tests = len(sys.argv) > 1 and sys.argv[1] == "test"
    if run_tests:
        test_locally()
    else:
        server_cfg = config["server"]
        uvicorn.run(
            "main.app:app",
            host=server_cfg["host"],
            port=server_cfg["port"],
            log_level="trace",
            reload=True,
            workers=1,
            access_log=False,
            use_colors=True,
        )