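"""Interactive command-line client for an LLM Engine server.

Provides a simple menu for downloading a Hugging Face checkpoint, converting it
to LitGPT format, initializing it on the server, and streaming generated text
over Server-Sent Events.
"""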
import json
import os
from pathlib import Path
from typing import Optional

import requests
import sseclient  # provided by the sseclient-py package
import yaml

from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint
from litgpt.scripts.download import download_from_hub

# Fallback configuration used when client_config.yaml cannot be read. It only
# covers the server and model sections; menu actions that read 'project',
# 'hardware', or 'generation' settings still expect those keys from the YAML file.
DEFAULT_CONFIG = {
    'server': {'url': 'http://localhost:7860'},
    'model': {
        'name': 'Qwen2.5-Coder-7B-Instruct',
        'download_location': 'huihui-ai/Qwen2.5-Coder-7B-Instruct-abliterated',
        'folder_path': 'huihui-ai/Qwen2.5-Coder-7B-Instruct-abliterated',
        'model_filename': 'model.safetensors'
    }
}
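
# Example client_config.yaml (illustrative sketch). The keys below are the ones
# this script reads; the 'project', 'hardware', and 'generation' values are
# assumptions and should be adjusted to your setup:
#
#   server:
#     url: http://localhost:7860
#   model:
#     name: Qwen2.5-Coder-7B-Instruct
#     download_location: huihui-ai/Qwen2.5-Coder-7B-Instruct-abliterated
#     folder_path: huihui-ai/Qwen2.5-Coder-7B-Instruct-abliterated
#     model_filename: model.safetensors
#   project:
#     root_dir: ..              # project root, relative to this file
#     checkpoints_dir: checkpoints
#   hardware:
#     mode: gpu                 # or cpu
#     precision: null           # optional
#     quantize: null            # optional
#     gpu_count: auto
#   generation:
#     max_new_tokens: 50
#     temperature: 1.0
#     top_k: null
#     top_p: 1.0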


def get_project_root(config: dict) -> Path:
    """Resolve the project root relative to this file."""
    client_dir = Path(__file__).parent
    return (client_dir / config['project']['root_dir']).resolve()


def get_checkpoints_dir(config: dict) -> Path:
    """Return the directory where model checkpoints are stored."""
    root = get_project_root(config)
    return root / config['project']['checkpoints_dir']


class LLMClient:
    def __init__(self, config: dict):
        self.config = config
        self.base_url = config['server']['url'].rstrip('/')
        self.session = requests.Session()
        self.checkpoints_dir = get_checkpoints_dir(config)

    def download_model(
        self,
        repo_id: Optional[str] = None,
        access_token: Optional[str] = None,
    ) -> None:
        """Download a model checkpoint from the Hugging Face Hub."""
        repo_id = repo_id or self.config['model']['folder_path']
        # Fall back to the HF_TOKEN environment variable when no token is given,
        # including when access_token is explicitly passed as None.
        access_token = access_token or os.getenv("HF_TOKEN")
        print(f"\nDownloading model from: {repo_id}")
        download_from_hub(
            repo_id=repo_id,
            model_name=self.config['model']['name'],
            access_token=access_token,
            tokenizer_only=False,
            checkpoint_dir=self.checkpoints_dir
        )

    def convert_model(
        self,
        folder_path: Optional[str] = None,
        model_name: Optional[str] = None,
    ) -> None:
        """Convert downloaded model to LitGPT format."""
        folder_path = folder_path or self.config['model']['folder_path']
        model_name = model_name or self.config['model']['name']
        model_dir = self.checkpoints_dir / folder_path
        print(f"\nConverting model in: {model_dir}")
        print(f"Using model name: {model_name}")
        try:
            convert_hf_checkpoint(
                checkpoint_dir=model_dir,
                model_name=model_name
            )
            print("Conversion complete!")
        except ValueError as e:
            if "is not a supported config name" in str(e):
                print(f"\nNote: Model '{model_name}' isn't in LitGPT's predefined configs.")
                print("You may need to use the model's safetensors files directly.")
            raise

    def initialize_model(
        self,
        folder_path: Optional[str] = None,
        mode: Optional[str] = None,
        **kwargs
    ) -> dict:
        """Initialize a converted model using the standard initialize endpoint."""
        url = f"{self.base_url}/initialize"
        folder_path = folder_path or self.config['model']['folder_path']
        mode = mode or self.config['hardware']['mode']
        # Debug prints
        print("\nDebug - Attempting to initialize model with:")
        print(f"Model path: {folder_path}")
        print(f"Mode: {mode}")
        payload = {
            "model_path": folder_path,  # This is what the regular initialize endpoint expects
            "mode": mode,
            "precision": self.config['hardware'].get('precision'),
            "quantize": self.config['hardware'].get('quantize'),
            "gpu_count": self.config['hardware'].get('gpu_count', 'auto'),
            **kwargs
        }
        response = self.session.post(url, json=payload)
        response.raise_for_status()
        return response.json()

    def generate_stream(
        self,
        prompt: str,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None
    ):
        """Stream generated tokens from the server as parsed SSE events."""
        url = f"{self.base_url}/generate/stream"
        gen_config = self.config.get('generation', {})
        # Explicit None checks so legitimate zero values (e.g. temperature=0.0)
        # are not silently replaced by the config defaults.
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens if max_new_tokens is not None else gen_config.get('max_new_tokens', 50),
            "temperature": temperature if temperature is not None else gen_config.get('temperature', 1.0),
            "top_k": top_k if top_k is not None else gen_config.get('top_k'),
            "top_p": top_p if top_p is not None else gen_config.get('top_p', 1.0)
        }
        response = self.session.post(url, json=payload, stream=True)
        response.raise_for_status()
        client = sseclient.SSEClient(response)
        for event in client.events():
            yield json.loads(event.data)
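
# Example programmatic use of LLMClient (sketch; assumes the server configured
# under config['server']['url'] is running and a model has already been initialized):
#
#   client = LLMClient(load_config())
#   for chunk in client.generate_stream("Hello", max_new_tokens=32):
#       print(chunk.get("token", ""), end="", flush=True)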


def clear_screen():
    os.system('cls' if os.name == 'nt' else 'clear')


def load_config(config_path: str = "client_config.yaml") -> dict:
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        return config
    except Exception as e:
        print(f"Warning: Could not load config file: {str(e)}")
        print("Using default configuration.")
        return DEFAULT_CONFIG


def main():
    config = load_config()
    client = LLMClient(config)
    while True:
        clear_screen()
        print("\nLLM Engine Client")
        print("================")
        print(f"Server: {client.base_url}")
        print(f"Current Model: {config['model']['name']}")
        print("\nOptions:")
        print("1. Download Model")
        print("2. Convert Model")
        print("3. Initialize Model")
        print("4. Generate Text (Streaming)")
        print("5. Exit")
        choice = input("\nEnter your choice (1-5): ").strip()
if choice == "1": | |
try: | |
print("\nDownload Model") | |
print("==============") | |
print(f"Default location: {config['model']['download_location']}") | |
if input("\nUse default? (Y/n): ").lower() != 'n': | |
repo_id = config['model']['download_location'] | |
else: | |
repo_id = input("Enter download location: ").strip() | |
access_token = input("Enter HF access token (or press Enter to use HF_TOKEN env var): ").strip() or None | |
client.download_model(repo_id=repo_id, access_token=access_token) | |
print("\nModel downloaded successfully!") | |
input("\nPress Enter to continue...") | |
except Exception as e: | |
print(f"\nError: {str(e)}") | |
input("\nPress Enter to continue...") | |
elif choice == "2": | |
try: | |
print("\nConvert Model") | |
print("=============") | |
print(f"Default folder path: {config['model']['folder_path']}") | |
print(f"Default model name: {config['model']['name']}") | |
if input("\nUse defaults? (Y/n): ").lower() != 'n': | |
folder_path = config['model']['folder_path'] | |
model_name = config['model']['name'] | |
else: | |
folder_path = input("Enter folder path: ").strip() | |
model_name = input("Enter model name: ").strip() | |
client.convert_model( | |
folder_path=folder_path, | |
model_name=model_name | |
) | |
print("\nModel converted successfully!") | |
input("\nPress Enter to continue...") | |
except Exception as e: | |
print(f"\nError: {str(e)}") | |
input("\nPress Enter to continue...") | |
elif choice == "3": | |
try: | |
print("\nInitialize Model") | |
print("================") | |
print(f"Default folder path: {config['model']['folder_path']}") | |
if input("\nUse defaults? (Y/n): ").lower() != 'n': | |
result = client.initialize_model() | |
else: | |
folder_path = input("Enter model folder path: ").strip() | |
mode = input("Enter mode (cpu/gpu): ").strip() | |
result = client.initialize_model( | |
folder_path=folder_path, | |
mode=mode | |
) | |
print("\nSuccess! Model initialized.") | |
print(json.dumps(result, indent=2)) | |
input("\nPress Enter to continue...") | |
except Exception as e: | |
print(f"\nError: {str(e)}") | |
input("\nPress Enter to continue...") | |
elif choice == "4": | |
try: | |
print("\nGenerate Text (Streaming)") | |
print("========================") | |
prompt = input("Enter your prompt: ").strip() | |
print("\nGenerating (Ctrl+C to stop)...") | |
print("\nResponse:") | |
try: | |
for chunk in client.generate_stream(prompt=prompt): | |
if "error" in chunk: | |
print(f"\nError: {chunk['error']}") | |
break | |
token = chunk.get("token", "") | |
is_finished = chunk.get("metadata", {}).get("is_finished", False) | |
if is_finished: | |
print("\n[Generation Complete]") | |
break | |
print(token, end="", flush=True) | |
except KeyboardInterrupt: | |
print("\n\n[Generation Stopped]") | |
input("\nPress Enter to continue...") | |
except Exception as e: | |
print(f"\nError: {str(e)}") | |
input("\nPress Enter to continue...") | |
elif choice == "5": | |
print("\nGoodbye!") | |
break | |
else: | |
print("\nInvalid choice. Please try again.") | |
input("\nPress Enter to continue...") | |
if __name__ == "__main__": | |
main() |