AurelioAguirre committed on
Commit 19b1be5 · 1 Parent(s): 712d19c

Massive update, added download and convert options.

.idea/Inference-Server.iml CHANGED
@@ -4,6 +4,7 @@
   <exclude-output />
   <content url="file://$MODULE_DIR$">
     <excludeFolder url="file://$MODULE_DIR$/myenv" />
+    <excludeFolder url="file://$MODULE_DIR$/venv" />
   </content>
   <orderEntry type="inheritedJdk" />
   <orderEntry type="sourceFolder" forTests="false" />
README.md CHANGED
@@ -24,4 +24,8 @@ folders
   LLM-Engine
   Main
   main.py
+  routes.py
+  checkpoints
+  meta
+
  ```
client/__init__.py ADDED
File without changes
client/client.py ADDED
@@ -0,0 +1,275 @@
+import requests
+import json
+import sseclient
+import sys
+from pathlib import Path
+import yaml
+from typing import Optional
+import os
+
+from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint
+from litgpt.scripts.download import download_from_hub
+
+DEFAULT_CONFIG = {
+    'server': {'url': 'http://localhost:7860'},
+    'model': {
+        'name': 'Qwen2.5-Coder-7B-Instruct',
+        'download_location': 'huihui-ai/Qwen2.5-Coder-7B-Instruct-abliterated',
+        'folder_path': 'huihui-ai/Qwen2.5-Coder-7B-Instruct-abliterated',
+        'model_filename': 'model.safetensors'
+    }
+}
+
+def get_project_root(config: dict) -> Path:
+    client_dir = Path(__file__).parent
+    return (client_dir / config['project']['root_dir']).resolve()
+
+def get_checkpoints_dir(config: dict) -> Path:
+    root = get_project_root(config)
+    return root / config['project']['checkpoints_dir']
+
+class LLMClient:
+    def __init__(self, config: dict):
+        self.config = config
+        self.base_url = config['server']['url'].rstrip('/')
+        self.session = requests.Session()
+        self.checkpoints_dir = get_checkpoints_dir(config)
+
+    def download_model(
+        self,
+        repo_id: Optional[str] = None,
+        access_token: Optional[str] = os.getenv("HF_TOKEN"),
+    ) -> None:
+        repo_id = repo_id or self.config['model']['folder_path']
+
+        print(f"\nDownloading model from: {repo_id}")
+        download_from_hub(
+            repo_id=repo_id,
+            model_name=self.config['model']['name'],
+            access_token=access_token,
+            tokenizer_only=False,
+            checkpoint_dir=self.checkpoints_dir
+        )
+
+    def convert_model(
+        self,
+        folder_path: Optional[str] = None,
+        model_name: Optional[str] = None,
+    ) -> None:
+        """Convert downloaded model to LitGPT format."""
+        folder_path = folder_path or self.config['model']['folder_path']
+        model_name = model_name or self.config['model']['name']
+
+        model_dir = self.checkpoints_dir / folder_path
+        print(f"\nConverting model in: {model_dir}")
+        print(f"Using model name: {model_name}")
+
+        try:
+            convert_hf_checkpoint(
+                checkpoint_dir=model_dir,
+                model_name=model_name
+            )
+            print("Conversion complete!")
+        except ValueError as e:
+            if "is not a supported config name" in str(e):
+                print(f"\nNote: Model '{model_name}' isn't in LitGPT's predefined configs.")
+                print("You may need to use the model's safetensors files directly.")
+            raise
+
+    def initialize_model(
+        self,
+        folder_path: Optional[str] = None,
+        mode: Optional[str] = None,
+        **kwargs
+    ) -> dict:
+        """Initialize a converted model using the standard initialize endpoint."""
+        url = f"{self.base_url}/initialize"
+
+        folder_path = folder_path or self.config['model']['folder_path']
+        mode = mode or self.config['hardware']['mode']
+
+        # Debug prints
+        print(f"\nDebug - Attempting to initialize model with:")
+        print(f"Model path: {folder_path}")
+        print(f"Mode: {mode}")
+
+        payload = {
+            "model_path": folder_path,  # This is what the regular initialize endpoint expects
+            "mode": mode,
+            "precision": self.config['hardware'].get('precision'),
+            "quantize": self.config['hardware'].get('quantize'),
+            "gpu_count": self.config['hardware'].get('gpu_count', 'auto'),
+            **kwargs
+        }
+
+        response = self.session.post(url, json=payload)
+        response.raise_for_status()
+        return response.json()
+
+    def generate_stream(
+        self,
+        prompt: str,
+        max_new_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None
+    ):
+        url = f"{self.base_url}/generate/stream"
+
+        gen_config = self.config.get('generation', {})
+        payload = {
+            "prompt": prompt,
+            "max_new_tokens": max_new_tokens or gen_config.get('max_new_tokens', 50),
+            "temperature": temperature or gen_config.get('temperature', 1.0),
+            "top_k": top_k or gen_config.get('top_k'),
+            "top_p": top_p or gen_config.get('top_p', 1.0)
+        }
+
+        response = self.session.post(url, json=payload, stream=True)
+        response.raise_for_status()
+
+        client = sseclient.SSEClient(response)
+        for event in client.events():
+            yield json.loads(event.data)
+
+def clear_screen():
+    os.system('cls' if os.name == 'nt' else 'clear')
+
+def load_config(config_path: str = "client_config.yaml") -> dict:
+    try:
+        with open(config_path, 'r') as f:
+            config = yaml.safe_load(f)
+        return config
+    except Exception as e:
+        print(f"Warning: Could not load config file: {str(e)}")
+        print("Using default configuration.")
+        return DEFAULT_CONFIG
+
+
+
+def main():
+    config = load_config()
+    client = LLMClient(config)
+
+    while True:
+        clear_screen()
+        print("\nLLM Engine Client")
+        print("================")
+        print(f"Server: {client.base_url}")
+        print(f"Current Model: {config['model']['name']}")
+        print("\nOptions:")
+        print("1. Download Model")
+        print("2. Convert Model")
+        print("3. Initialize Model")
+        print("4. Generate Text (Streaming)")
+        print("5. Exit")
+
+        choice = input("\nEnter your choice (1-5): ").strip()
+
+        if choice == "1":
+            try:
+                print("\nDownload Model")
+                print("==============")
+                print(f"Default location: {config['model']['download_location']}")
+                if input("\nUse default? (Y/n): ").lower() != 'n':
+                    repo_id = config['model']['download_location']
+                else:
+                    repo_id = input("Enter download location: ").strip()
+
+                access_token = input("Enter HF access token (or press Enter to use HF_TOKEN env var): ").strip() or None
+                client.download_model(repo_id=repo_id, access_token=access_token)
+                print("\nModel downloaded successfully!")
+                input("\nPress Enter to continue...")
+
+            except Exception as e:
+                print(f"\nError: {str(e)}")
+                input("\nPress Enter to continue...")
+
+        elif choice == "2":
+            try:
+                print("\nConvert Model")
+                print("=============")
+                print(f"Default folder path: {config['model']['folder_path']}")
+                print(f"Default model name: {config['model']['name']}")
+                if input("\nUse defaults? (Y/n): ").lower() != 'n':
+                    folder_path = config['model']['folder_path']
+                    model_name = config['model']['name']
+                else:
+                    folder_path = input("Enter folder path: ").strip()
+                    model_name = input("Enter model name: ").strip()
+
+                client.convert_model(
+                    folder_path=folder_path,
+                    model_name=model_name
+                )
+                print("\nModel converted successfully!")
+                input("\nPress Enter to continue...")
+
+            except Exception as e:
+                print(f"\nError: {str(e)}")
+                input("\nPress Enter to continue...")
+
+        elif choice == "3":
+            try:
+                print("\nInitialize Model")
+                print("================")
+                print(f"Default folder path: {config['model']['folder_path']}")
+                if input("\nUse defaults? (Y/n): ").lower() != 'n':
+                    result = client.initialize_model()
+                else:
+                    folder_path = input("Enter model folder path: ").strip()
+                    mode = input("Enter mode (cpu/gpu): ").strip()
+                    result = client.initialize_model(
+                        folder_path=folder_path,
+                        mode=mode
+                    )
+                print("\nSuccess! Model initialized.")
+                print(json.dumps(result, indent=2))
+                input("\nPress Enter to continue...")
+
+            except Exception as e:
+                print(f"\nError: {str(e)}")
+                input("\nPress Enter to continue...")
+
+        elif choice == "4":
+            try:
+                print("\nGenerate Text (Streaming)")
+                print("========================")
+                prompt = input("Enter your prompt: ").strip()
+
+                print("\nGenerating (Ctrl+C to stop)...")
+                print("\nResponse:")
+                try:
+                    for chunk in client.generate_stream(prompt=prompt):
+                        if "error" in chunk:
+                            print(f"\nError: {chunk['error']}")
+                            break
+
+                        token = chunk.get("token", "")
+                        is_finished = chunk.get("metadata", {}).get("is_finished", False)
+
+                        if is_finished:
+                            print("\n[Generation Complete]")
+                            break
+
+                        print(token, end="", flush=True)
+
+                except KeyboardInterrupt:
+                    print("\n\n[Generation Stopped]")
+
+                input("\nPress Enter to continue...")
+
+            except Exception as e:
+                print(f"\nError: {str(e)}")
+                input("\nPress Enter to continue...")
+
+        elif choice == "5":
+            print("\nGoodbye!")
+            break
+
+        else:
+            print("\nInvalid choice. Please try again.")
+            input("\nPress Enter to continue...")
+
+if __name__ == "__main__":
+    main()
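
The same download → convert → initialize → stream flow that the menu above walks through can also be driven as a script. The following is a minimal sketch, assuming the FastAPI server from `main/` is already running at the URL in `client_config.yaml` and that the snippet is executed from inside the `client/` directory:

```python
# Minimal scripted use of LLMClient (sketch; the server must already be running).
from client import LLMClient, load_config

config = load_config("client_config.yaml")
client = LLMClient(config)

client.download_model()           # repo_id defaults to config['model']['folder_path']
client.convert_model()            # writes lit_model.pth next to the downloaded files
print(client.initialize_model())  # POST /initialize with the hardware settings from the config

for chunk in client.generate_stream("Hello!", max_new_tokens=32):
    if chunk.get("metadata", {}).get("is_finished"):
        break
    print(chunk.get("token", ""), end="", flush=True)
```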
client/client_config.yaml ADDED
@@ -0,0 +1,33 @@
+# Project Configuration
+project:
+  root_dir: ".."
+  checkpoints_dir: "checkpoints"
+
+# Server Configuration
+server:
+  url: "http://localhost:7860"
+
+# Model Configuration
+model:
+  name: "Llama-3.2-3B"
+  download_location: "huihui-ai/Llama-3.2-3B-Instruct-abliterated"
+  folder_path: "huihui-ai/Llama-3.2-3B-Instruct-abliterated"
+  model_filename: "lit_model.pth"
+  config_filename: "config.json"
+  tokenizer_filename: "tokenizer.json"
+
+# Hardware Configuration
+hardware:
+  mode: "gpu"
+  precision: "16-true"
+  # Precision Options: "32-true", "16-mixed", "16-true", "bf16-mixed", "bf16-true"
+  quantize: "bnb.int8"
+  # Quantization Options: "bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"
+  gpu_count: "auto"
+
+# Generation Parameters
+generation:
+  max_new_tokens: 500
+  temperature: 1.0
+  top_k: null
+  top_p: 1.0
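
The `project` block above is what ties the client to the engine: `root_dir` is resolved relative to the `client/` directory, so with these values both sides end up sharing the same `checkpoints` folder at the repository root. A small sketch of that resolution, mirroring `get_project_root()`/`get_checkpoints_dir()` in `client.py` (the working directory is assumed to be the repository root):

```python
from pathlib import Path
import yaml

# Mirror of get_project_root()/get_checkpoints_dir() in client/client.py.
config = yaml.safe_load(open("client/client_config.yaml"))
root = (Path("client") / config["project"]["root_dir"]).resolve()  # ".." -> repo root
checkpoints_dir = root / config["project"]["checkpoints_dir"]      # <repo-root>/checkpoints
print(checkpoints_dir)
```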
main/hf_downloader.py ADDED
@@ -0,0 +1,97 @@
+import os
+import argparse
+from transformers import AutoTokenizer, AutoModel
+from huggingface_hub import login, HfApi
+import logging
+from tqdm import tqdm
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+def setup_auth(token):
+    """Setup Hugging Face authentication"""
+    try:
+        login(token)
+        logger.info("Successfully authenticated with Hugging Face")
+    except Exception as e:
+        logger.error(f"Authentication failed: {str(e)}")
+        raise
+
+def list_models(pattern=None):
+    """List available models matching the pattern"""
+    try:
+        api = HfApi()
+        models = api.list_models(pattern=pattern, full=True)
+        return [(model.modelId, model.downloads) for model in models]
+    except Exception as e:
+        logger.error(f"Failed to list models: {str(e)}")
+        raise
+
+def download_model(model_name, output_dir):
+    """Download model and tokenizer"""
+    try:
+        logger.info(f"Downloading model: {model_name}")
+
+        # Create output directory if it doesn't exist
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Download tokenizer
+        logger.info("Downloading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.save_pretrained(os.path.join(output_dir, model_name))
+
+        # Download model
+        logger.info("Downloading model...")
+        model = AutoModel.from_pretrained(model_name)
+        model.save_pretrained(os.path.join(output_dir, model_name))
+
+        logger.info(f"Successfully downloaded {model_name} to {output_dir}")
+        return True
+    except Exception as e:
+        logger.error(f"Failed to download model {model_name}: {str(e)}")
+        raise
+
+def main():
+    parser = argparse.ArgumentParser(description='Download models from Hugging Face')
+    parser.add_argument('--token', type=str, help='Hugging Face API token')
+    parser.add_argument('--model', type=str, help='Model name to download')
+    parser.add_argument('--output', type=str, default='./models',
+                        help='Output directory for downloaded models')
+    parser.add_argument('--search', type=str, help='Search pattern for models')
+    parser.add_argument('--list', action='store_true',
+                        help='List available models matching the search pattern')
+
+    args = parser.parse_args()
+
+    try:
+        # Setup authentication if token provided
+        if args.token:
+            setup_auth(args.token)
+
+        # List models if requested
+        if args.list:
+            logger.info(f"Searching for models matching: {args.search}")
+            models = list_models(args.search)
+            print("\nAvailable models:")
+            for model_id, downloads in sorted(models, key=lambda x: x[1], reverse=True):
+                print(f"- {model_id} (Downloads: {downloads:,})")
+            return
+
+        # Download specific model
+        if args.model:
+            download_model(args.model, args.output)
+        else:
+            logger.error("Please specify a model to download using --model")
+            return
+
+    except KeyboardInterrupt:
+        logger.info("\nOperation cancelled by user")
+    except Exception as e:
+        logger.error(f"An error occurred: {str(e)}")
+
+if __name__ == "__main__":
+    main()
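
A hedged usage sketch for this standalone helper; the import path assumes `main/` is importable from the repository root, and the output directory is only an example. The equivalent CLI call would be `python main/hf_downloader.py --model <repo-id> --output ./models`.

```python
# Drive the helper above directly instead of through argparse.
import os
from main.hf_downloader import setup_auth, download_model

token = os.getenv("HF_TOKEN")
if token:
    setup_auth(token)  # only needed for gated/private repositories

# Saves the tokenizer and weights under ./models/<repo-id>/
download_model("huihui-ai/Llama-3.2-3B-Instruct-abliterated", "./models")
```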
main/main.py CHANGED
@@ -39,10 +39,12 @@ def main():
     logger.info("Available endpoints:")
     logger.info(" - /")
     logger.info(" - /health")
+    logger.info(" - /models")
     logger.info(" - /initialize")
     logger.info(" - /generate")
-    logger.info(" - /initialize/custom")
     logger.info(" - /generate/stream")
+    logger.info(" - /download")
+    logger.info(" - /convert")
     logger.info(" - /docs")
     logger.info(" - /redoc")
     logger.info(" - /openapi.json")
main/routes.py CHANGED
@@ -1,11 +1,14 @@
+
 from fastapi import APIRouter, HTTPException
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-from typing import Optional, Union, AsyncGenerator
+from pydantic import BaseModel, Field
+from typing import Optional, Union, AsyncGenerator, List
 import torch
 import logging
 from pathlib import Path
 from litgpt.api import LLM
+from litgpt.scripts.download import download_from_hub
+from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint
 import json
 import asyncio
 
@@ -19,224 +22,204 @@ router = APIRouter()
 llm_instance = None
 
 class InitializeRequest(BaseModel):
-    """
-    Configuration for model initialization including model path
-    """
-    mode: str = "cpu"
-    precision: Optional[str] = None
-    quantize: Optional[str] = None
-    gpu_count: Union[str, int] = "auto"
-    model_path: str
+    """Configuration for model initialization including model path"""
+    mode: str = Field(default="cpu", description="Execution mode ('cpu' or 'gpu')")
+    precision: Optional[str] = Field(None, description="Precision format (e.g., 'bf16-true', 'bf16-mixed')")
+    quantize: Optional[str] = Field(None, description="Quantization format (e.g., 'bnb.nf4')")
+    gpu_count: Union[str, int] = Field(default="auto", description="Number of GPUs to use or 'auto'")
+    model_path: str = Field(..., description="Path to the model relative to checkpoints directory")
 
 class GenerateRequest(BaseModel):
-    prompt: str
-    max_new_tokens: int = 50
-    temperature: float = 1.0
-    top_k: Optional[int] = None
-    top_p: float = 1.0
-    return_as_token_ids: bool = False
-    stream: bool = False
-
-# A Pydantic model for the streaming generation request
+    """Request parameters for text generation"""
+    prompt: str = Field(..., description="Input text prompt for generation")
+    max_new_tokens: int = Field(default=50, description="Maximum number of tokens to generate")
+    temperature: float = Field(default=1.0, description="Sampling temperature")
+    top_k: Optional[int] = Field(None, description="Top-k sampling parameter")
+    top_p: float = Field(default=1.0, description="Top-p sampling parameter")
+    return_as_token_ids: bool = Field(default=False, description="Whether to return token IDs instead of text")
+    stream: bool = Field(default=False, description="Whether to stream the response")
+
 class StreamGenerateRequest(BaseModel):
-    prompt: str
-    max_new_tokens: int = 50
-    temperature: float = 1.0
-    top_k: Optional[int] = None
-    top_p: float = 1.0
+    """Request parameters for streaming text generation"""
+    prompt: str = Field(..., description="Input text prompt for generation")
+    max_new_tokens: int = Field(default=50, description="Maximum number of tokens to generate")
+    temperature: float = Field(default=1.0, description="Sampling temperature")
+    top_k: Optional[int] = Field(None, description="Top-k sampling parameter")
+    top_p: float = Field(default=1.0, description="Top-p sampling parameter")
+
+class DownloadModelRequest(BaseModel):
+    """Request to download a model from HuggingFace"""
+    repo_id: str = Field(
+        ...,
+        description="HuggingFace repository ID (e.g., 'huihui-ai/Llama-3.2-3B-Instruct-abliterated')"
+    )
+    model_name: str = Field(
+        ...,
+        description="Model architecture name (e.g., 'Llama-3.2-3B-Instruct')"
+    )
+    access_token: Optional[str] = Field(
+        None,
+        description="HuggingFace access token for private models"
+    )
 
-class InitializeCustomRequest(BaseModel):
-    """
-    Configuration for custom model initialization using from_pretrained
-    """
-    mode: str = "cpu"
-    precision: Optional[str] = None
-    quantize: Optional[str] = None
-    gpu_count: Union[str, int] = "auto"
-    folder_path: str  # Path to the model folder relative to checkpoints
-    model_filename: str  # Name of the model file (e.g., "lit_model.pth")
-    config_filename: str = "config.json"  # Default config filename
-    tokenizer_filename: Optional[str] = "tokenizer.json"  # Optional tokenizer filename
-
-
-@router.post("/initialize/custom")
-async def initialize_custom_model(request: InitializeCustomRequest):
-    """
-    Initialize a custom model using from_pretrained method.
-    This is for models that are already downloaded and stored in the checkpoints directory.
+class ConvertModelRequest(BaseModel):
+    """Request to convert a downloaded model"""
+    folder_path: str = Field(
+        ...,
+        description="Path relative to checkpoints where model was downloaded"
+    )
+    model_name: str = Field(
+        ...,
+        description="Model architecture name for conversion"
+    )
+
+class ModelResponse(BaseModel):
+    """Model information response"""
+    name: str = Field(..., description="Full model name including organization")
+    path: str = Field(..., description="Relative path in checkpoints directory")
+    downloaded: bool = Field(..., description="Whether the model files are downloaded")
+    converted: bool = Field(..., description="Whether the model is converted to LitGPT format")
+    has_safetensors: bool = Field(..., description="Whether safetensors files are present")
+    files: List[str] = Field(..., description="List of files in model directory")
+
+class ModelsListResponse(BaseModel):
+    """Response for listing models"""
+    models: List[ModelResponse] = Field(..., description="List of available models")
+
+@router.post(
+    "/download",
+    response_model=dict,
+    summary="Download a model from HuggingFace Hub",
+    description="Downloads a model from HuggingFace to the LLM Engine's checkpoints directory",
+    response_description="Download status and location information"
+)
+async def download_model(request: DownloadModelRequest):
     """
-    global llm_instance
+    Download a model from HuggingFace Hub.
+
+    - Downloads model files to the checkpoints directory
+    - Creates necessary subdirectories
+    - Handles authentication for private models
 
+    Returns:
+        A JSON object containing download status and path information
+    """
     try:
         # Get the project root directory and construct paths
-        project_root = Path(__file__).parent
+        project_root = Path(__file__).parent.parent
         checkpoints_dir = project_root / "checkpoints"
-        model_dir = checkpoints_dir / request.folder_path
-
-        logger.info(f"Loading custom model from directory: {model_dir}")
+        logger.info(f"Downloading model {request.repo_id} to {checkpoints_dir}")
+
+        download_from_hub(
+            repo_id=request.repo_id,
+            model_name=request.model_name,
+            access_token=request.access_token,
+            checkpoint_dir=checkpoints_dir,
+            tokenizer_only=False
+        )
 
-        # Verify that all required files exist
-        model_path = model_dir / request.model_filename
-        config_path = model_dir / request.config_filename
+        return {
+            "status": "success",
+            "message": f"Model downloaded to {checkpoints_dir / request.repo_id}",
+            "path": str(request.repo_id)
+        }
 
-        if not model_path.exists():
-            raise HTTPException(
-                status_code=400,
-                detail=f"Model file not found: {request.model_filename}"
-            )
+    except Exception as e:
+        logger.error(f"Error downloading model: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error downloading model: {str(e)}")
+
+@router.post(
+    "/convert",
+    response_model=dict,
+    summary="Convert a model to LitGPT format",
+    description="Converts a downloaded model to the LitGPT format required for inference",
+    response_description="Conversion status and location information"
+)
+async def convert_model(request: ConvertModelRequest):
+    """
+    Convert a downloaded model to LitGPT format.
 
-        if not config_path.exists():
-            raise HTTPException(
-                status_code=400,
-                detail=f"Config file not found: {request.config_filename}"
-            )
+    - Converts model files to LitGPT's format
+    - Creates lit_model.pth file
+    - Maintains original files
 
-        # Check for tokenizer if specified
-        tokenizer_path = None
-        if request.tokenizer_filename:
-            tokenizer_path = model_dir / request.tokenizer_filename
-            if not tokenizer_path.exists():
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Tokenizer file not found: {request.tokenizer_filename}"
-                )
-
-        # Load the model using from_pretrained
-        llm_instance = LLM.from_pretrained(
-            path=str(model_dir),
-            model_file=request.model_filename,
-            config_file=request.config_filename,
-            tokenizer_file=request.tokenizer_filename if request.tokenizer_filename else None,
-            distribute=None if request.precision or request.quantize else "auto"
-        )
+    Returns:
+        A JSON object containing conversion status and path information
+    """
+    try:
+        project_root = Path(__file__).parent.parent
+        checkpoints_dir = project_root / "checkpoints"
+        model_dir = checkpoints_dir / request.folder_path
 
-        # If manual distribution is needed
-        if request.precision or request.quantize:
-            llm_instance.distribute(
-                accelerator="cuda" if request.mode == "gpu" else "cpu",
-                devices=request.gpu_count,
-                precision=request.precision,
-                quantize=request.quantize
+        if not model_dir.exists():
+            raise HTTPException(
+                status_code=404,
+                detail=f"Model directory not found: {request.folder_path}"
             )
 
-        # Log success and memory stats
-        logger.info(
-            f"Custom model initialized successfully with config:\n"
-            f"Mode: {request.mode}\n"
-            f"Precision: {request.precision}\n"
-            f"Quantize: {request.quantize}\n"
-            f"GPU Count: {request.gpu_count}\n"
-            f"Model Directory: {model_dir}\n"
-            f"Model File: {request.model_filename}\n"
-            f"Config File: {request.config_filename}\n"
-            f"Tokenizer File: {request.tokenizer_filename}\n"
-            f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
-            f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
+        logger.info(f"Converting model in {model_dir}")
+        convert_hf_checkpoint(
+            checkpoint_dir=model_dir,
+            model_name=request.model_name
         )
 
         return {
-            "success": True,
-            "message": "Custom model initialized successfully",
-            "model_info": {
-                "folder": str(model_dir),
-                "model_file": request.model_filename,
-                "config_file": request.config_filename,
-                "tokenizer_file": request.tokenizer_filename
-            }
+            "status": "success",
+            "message": f"Model converted successfully",
+            "path": str(request.folder_path)
         }
 
     except Exception as e:
-        logger.error(f"Error initializing custom model: {str(e)}")
-        # Print detailed memory statistics on failure
-        logger.error(f"GPU Memory Stats:\n"
-                     f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
-                     f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
-                     f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
-        raise HTTPException(status_code=500, detail=f"Error initializing custom model: {str(e)}")
-
-
-# Endpoint for streaming generation
-@router.post("/generate/stream")
-async def generate_stream(request: StreamGenerateRequest):
+        logger.error(f"Error converting model: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error converting model: {str(e)}")
+
+@router.get(
+    "/models",
+    response_model=ModelsListResponse,
+    summary="List available models",
+    description="Lists all models in the checkpoints directory with their status",
+    response_description="List of models with their details and status"
+)
+async def list_models():
     """
-    Generate text using the initialized model with streaming response.
-    Returns a StreamingResponse that yields JSON-formatted chunks of text.
+    List all models in the checkpoints directory.
+
+    Returns:
+        A JSON object containing:
+        - List of models
+        - Each model's download status
+        - Each model's conversion status
+        - Available files for each model
     """
-    global llm_instance
-
-    if llm_instance is None:
-        raise HTTPException(
-            status_code=400,
-            detail="Model not initialized. Call /initialize first."
-        )
-
-    async def event_generator() -> AsyncGenerator[str, None]:
-        try:
-            # Start the generation with streaming enabled
-            async for token in llm_instance.generate(
-                prompt=request.prompt,
-                max_new_tokens=request.max_new_tokens,
-                temperature=request.temperature,
-                top_k=request.top_k,
-                top_p=request.top_p,
-                stream=True  # Enable streaming
-            ):
-                # Create a JSON response for each token
-                chunk = {
-                    "token": token,
-                    "metadata": {
-                        "prompt": request.prompt,
-                        "is_finished": False
-                    }
-                }
-                # Format as SSE data
-                yield f"data: {json.dumps(chunk)}\n\n"
-
-                # Small delay to prevent overwhelming the client
-                await asyncio.sleep(0.01)
-
-            # Send final message indicating completion
-            final_chunk = {
-                "token": "",
-                "metadata": {
-                    "prompt": request.prompt,
-                    "is_finished": True
-                }
-            }
-            yield f"data: {json.dumps(final_chunk)}\n\n"
-
-        except Exception as e:
-            logger.error(f"Error in stream generation: {str(e)}")
-            error_chunk = {
-                "error": str(e),
-                "metadata": {
-                    "prompt": request.prompt,
-                    "is_finished": True
-                }
-            }
-            yield f"data: {json.dumps(error_chunk)}\n\n"
-
-    return StreamingResponse(
-        event_generator(),
-        media_type="text/event-stream",
-        headers={
-            'Cache-Control': 'no-cache',
-            'Connection': 'keep-alive',
-        }
-    )
+    try:
+        project_root = Path(__file__).parent.parent
+        checkpoints_dir = project_root / "checkpoints"
+        models = []
+
+        if checkpoints_dir.exists():
+            for org_dir in checkpoints_dir.iterdir():
+                if org_dir.is_dir():
+                    for model_dir in org_dir.iterdir():
+                        if model_dir.is_dir():
+                            files = [f.name for f in model_dir.iterdir()]
+                            has_safetensors = any(f.endswith('.safetensors') for f in files)
+                            has_lit_model = 'lit_model.pth' in files
+
+                            model_info = ModelResponse(
+                                name=f"{org_dir.name}/{model_dir.name}",
+                                path=str(model_dir.relative_to(checkpoints_dir)),
+                                downloaded=True,
+                                converted=has_lit_model,
+                                has_safetensors=has_safetensors,
+                                files=files
+                            )
+                            models.append(model_info)
+
+        return ModelsListResponse(models=models)
 
-@router.get("/")
-async def root():
-    """Root endpoint to verify service is running"""
-    return {
-        "status": "running",
-        "service": "LLM Engine",
-        "endpoints": {
-            "initialize": "/initialize",
-            "generate": "/generate",
-            "health": "/health"
-        }
-    }
+    except Exception as e:
+        logger.error(f"Error listing models: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}")
 
 @router.post("/initialize")
 async def initialize_model(request: InitializeRequest):
@@ -247,7 +230,7 @@ async def initialize_model(request: InitializeRequest):
 
     try:
         # Get the project root directory (where main.py is located)
-        project_root = Path(__file__).parent
+        project_root = Path(__file__).parent.parent
        checkpoints_dir = project_root / "checkpoints"
         logger.info(f"Checkpoint dir is: {checkpoints_dir}")
 
@@ -344,10 +327,80 @@ async def generate(request: GenerateRequest):
         logger.error(f"Error generating text: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
 
+@router.post("/generate/stream")
+async def generate_stream(request: StreamGenerateRequest):
+    """
+    Generate text using the initialized model with streaming response.
+    Returns a StreamingResponse that yields JSON-formatted chunks of text.
+    """
+    global llm_instance
+
+    if llm_instance is None:
+        raise HTTPException(
+            status_code=400,
+            detail="Model not initialized. Call /initialize first."
+        )
+
+    async def event_generator() -> AsyncGenerator[str, None]:
+        try:
+            # Start the generation with streaming enabled
+            for token in llm_instance.generate(
+                prompt=request.prompt,
+                max_new_tokens=request.max_new_tokens,
+                temperature=request.temperature,
+                top_k=request.top_k,
+                top_p=request.top_p,
+                stream=True  # Enable streaming
+            ):
+                # Create a JSON response for each token
+                chunk = {
+                    "token": token,
+                    "metadata": {
+                        "prompt": request.prompt,
+                        "is_finished": False
+                    }
+                }
+                # Format as SSE data
+                yield f"data: {json.dumps(chunk)}\n\n"
+
+                # Small delay to prevent overwhelming the client
+                await asyncio.sleep(0.01)
+
+            # Send final message indicating completion
+            final_chunk = {
+                "token": "",
+                "metadata": {
+                    "prompt": request.prompt,
+                    "is_finished": True
+                }
+            }
+            yield f"data: {json.dumps(final_chunk)}\n\n"
+
+        except Exception as e:
+            logger.error(f"Error in stream generation: {str(e)}")
+            error_chunk = {
+                "error": str(e),
+                "metadata": {
+                    "prompt": request.prompt,
+                    "is_finished": True
+                }
+            }
+            yield f"data: {json.dumps(error_chunk)}\n\n"
+
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+        headers={
+            'Cache-Control': 'no-cache',
+            'Connection': 'keep-alive',
+        }
+    )
+
 @router.get("/health")
 async def health_check():
     """
     Check if the service is running and model is loaded.
+    Returns status information including model details if loaded.
     """
     global llm_instance
 
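
Taken together, the new routes give the server a download → convert → list → initialize → stream lifecycle over plain HTTP. The following is a hedged end-to-end sketch against a local instance; the base URL, repo ID, and model name are the defaults from `client_config.yaml`, not fixed values:

```python
import json
import requests

BASE = "http://localhost:7860"  # server.url from client_config.yaml
REPO = "huihui-ai/Llama-3.2-3B-Instruct-abliterated"
NAME = "Llama-3.2-3B"

# 1. Pull the checkpoint from HuggingFace into the server's checkpoints/ directory.
requests.post(f"{BASE}/download", json={"repo_id": REPO, "model_name": NAME}).raise_for_status()

# 2. Convert the downloaded files to LitGPT format (lit_model.pth).
requests.post(f"{BASE}/convert", json={"folder_path": REPO, "model_name": NAME}).raise_for_status()

# 3. Check that the model now shows up as converted.
print(requests.get(f"{BASE}/models").json())

# 4. Load it, then stream tokens from /generate/stream (SSE lines prefixed with "data: ").
requests.post(f"{BASE}/initialize", json={"model_path": REPO, "mode": "gpu"}).raise_for_status()
with requests.post(f"{BASE}/generate/stream",
                   json={"prompt": "Hello", "max_new_tokens": 32}, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            if chunk.get("metadata", {}).get("is_finished"):
                break
            print(chunk.get("token", ""), end="", flush=True)
```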