File size: 4,445 Bytes
5d274cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import yaml
import sys
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from .api import LLMApi
from .routes import router, init_router
from .utils.logging import setup_logger
from huggingface_hub import login
from pathlib import Path
from dotenv import load_dotenv
import os

def validate_hf(config=None):
    """
    Validate Hugging Face authentication.

    Checks for a ``.env`` file in the current working directory, loads it
    into the environment when present, and attempts a Hugging Face login
    using the ``HF_TOKEN`` environment variable.

    Args:
        config: Configuration object handed to ``setup_logger``. When
            omitted (``None``), the configuration is read with
            ``load_config()`` so the function can be called from any
            context, not only when a module-level ``config`` happens to
            exist.

    Returns:
        bool: ``True`` if the login call succeeded, ``False`` if no token
        was found or the login raised an exception.
    """
    # The function previously relied on a global ``config`` that only exists
    # when the module runs as a script; resolve it locally instead.
    if config is None:
        config = load_config()
    logger = setup_logger(config, "hf_validation")

    # Check for .env file
    env_path = Path('.env')
    if env_path.exists():
        logger.info("Found .env file, loading environment variables")
        # Load exactly the file that was just checked rather than letting
        # python-dotenv run its own discovery, which may pick another file.
        load_dotenv(dotenv_path=env_path)
    else:
        logger.warning("No .env file found. Fine if you're on Huggingface, but you need one to run locally on your PC.")

    # Check for HF token
    hf_token = os.getenv('HF_TOKEN')
    if not hf_token:
        logger.error("No HF_TOKEN found in environment variables")
        return False

    try:
        # Attempt login
        login(token=hf_token)
    except Exception as e:
        # Best-effort validation: report the failure to the caller instead
        # of letting an authentication problem propagate.
        logger.error(f"Failed to authenticate with Hugging Face: {e}")
        return False
    else:
        logger.info("Successfully authenticated with Hugging Face")
        return True

def load_config(path="main/config.yaml"):
    """Load configuration from a YAML file.

    Args:
        path: Location of the YAML file. Defaults to ``"main/config.yaml"``,
            which is resolved relative to the current working directory.

    Returns:
        The parsed YAML document as returned by ``yaml.safe_load``. Note
        that ``yaml.safe_load`` yields ``None`` for an empty document.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 so non-ASCII values are read the same way
    # regardless of the platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def create_app():
    """Build and return the configured FastAPI application.

    Reads the configuration via ``load_config()``, creates the FastAPI
    instance, attaches the CORS middleware using the ``api.cors`` settings,
    initializes the router with the configuration and mounts it under
    ``<api.prefix>/<api.version>``.

    Returns:
        The fully wired ``FastAPI`` instance.
    """
    settings = load_config()
    log = setup_logger(settings, "main")
    log.info("Starting LLM API server")

    application = FastAPI(
        title="LLM API",
        description="API for Large Language Model operations",
        version=settings["api"]["version"],
    )

    # CORS: origins and credential handling come from the config file,
    # while every method and header is allowed.
    cors_options = {
        "allow_origins": settings["api"]["cors"]["origins"],
        "allow_credentials": settings["api"]["cors"]["credentials"],
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    application.add_middleware(CORSMiddleware, **cors_options)

    # The router needs the configuration before it is attached to the app.
    init_router(settings)

    api_section = settings["api"]
    mount_point = f"{api_section['prefix']}/{api_section['version']}"
    application.include_router(router, prefix=mount_point)

    log.info("FastAPI application created successfully")
    return application

def test_locally():
    """Run local tests for development and debugging.

    Loads the configuration and drives an ``LLMApi`` instance through model
    download, model initialization, embedding generation, regular
    generation and streaming generation. Progress is reported through the
    "test" logger; streamed chunks are printed directly to stdout.
    """
    settings = load_config()
    log = setup_logger(settings, "test")
    log.info("Starting local tests")

    llm = LLMApi(settings)
    default_model = settings["model"]["defaults"]["model_name"]
    log.info(f"Testing with model: {default_model}")

    # Step 1: download the model
    log.info("Testing model download...")
    llm.download_model(default_model)
    log.info("Download complete")

    # Step 2: initialize the model
    log.info("Initializing model...")
    llm.initialize_model(default_model)
    log.info("Model initialized")

    # Step 3: embedding generation
    sample_text = "Dette er en test av embeddings generering fra en teknisk tekst om HMS rutiner på arbeidsplassen."
    log.info("Testing embedding generation...")
    vector = llm.generate_embedding(sample_text)
    log.info(f"Generated embedding of length: {len(vector)}")
    log.info(f"First few values: {vector[:5]}")

    # Shared inputs for both generation modes
    system_message = "You are a helpful assistant."
    prompts = ["Tell me what happens in a nuclear reactor."]

    # Step 4: regular (non-streaming) generation
    log.info("Testing regular generation:")
    for question in prompts:
        log.info(f"Prompt: {question}")
        answer = llm.generate_response(prompt=question, system_message=system_message)
        log.info(f"Response: {answer}")

    # Step 5: streaming generation, echoing chunks as they arrive
    log.info("Testing streaming generation:")
    first_prompt = prompts[0]
    log.info(f"Prompt: {first_prompt}")
    stream = llm.generate_stream(prompt=first_prompt, system_message=system_message)
    for piece in stream:
        print(piece, end="", flush=True)
    print("\n")

    log.info("Local tests completed")

# Module-level application instance, built at import time. Presumably this is
# the object the "main.app:app" import string below resolves to — confirm
# against the package layout.
app = create_app()

if __name__ == "__main__":
    config = load_config()
    # NOTE: Hugging Face validation (validate_hf) is currently not invoked here.
    # A first command-line argument of "test" runs the local test routine;
    # anything else (or no argument) starts the server.
    if sys.argv[1:2] != ["test"]:
        uvicorn.run(
            "main.app:app",
            host=config["server"]["host"],
            port=config["server"]["port"],
            log_level="trace",
            reload=True,
            workers=1,
            access_log=False,
            use_colors=True,
        )
    else:
        test_locally()