import os
from pathlib import Path
from threading import Thread
from typing import Optional, Iterator, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from .utils.logging import setup_logger


class LLMApi:
    def __init__(self, config: dict):
        """Initialize the LLM API with configuration."""
        self.logger = setup_logger(config, "llm_api")
        self.logger.info("Initializing LLM API")

        # Set up paths
        self.base_path = Path(config["model"]["base_path"])
        self.models_path = self.base_path / config["folders"]["models"]
        self.cache_path = self.base_path / config["folders"]["cache"]

        self.model = None
        self.model_name = None
        self.tokenizer = None

        # Generation parameters from config
        gen_config = config["model"]["generation"]
        self.max_new_tokens = gen_config["max_new_tokens"]
        self.do_sample = gen_config["do_sample"]
        self.temperature = gen_config["temperature"]
        self.repetition_penalty = gen_config["repetition_penalty"]

        self.generation_config = {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.do_sample,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "eos_token_id": None,
            "pad_token_id": None,
        }

        # Create necessary directories
        self.models_path.mkdir(parents=True, exist_ok=True)
        self.cache_path.mkdir(parents=True, exist_ok=True)

        # Set cache directory for transformers
        os.environ['HF_HOME'] = str(self.cache_path)

        self.logger.info("LLM API initialized successfully")

    def download_model(self, model_name: str) -> None:
        """
        Download a model and its tokenizer to the models directory.

        Args:
            model_name: The name of the model to download (e.g., "norallm/normistral-11b-warm")
        """
        self.logger.info(f"Starting download of model: {model_name}")
        try:
            model_path = self.models_path / model_name.split('/')[-1]

            # Download the model and tokenizer, caching files under the
            # configured cache directory
            model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=self.cache_path)
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.cache_path)

            self.logger.info(f"Saving model to {model_path}")
            model.save_pretrained(model_path)
            tokenizer.save_pretrained(model_path)

            self.logger.info(f"Successfully downloaded model: {model_name}")
        except Exception as e:
            self.logger.error(f"Failed to download model {model_name}: {str(e)}")
            raise

    def initialize_model(self, model_name: str) -> None:
        """
        Initialize a model and tokenizer, either from local storage or by downloading.

        Args:
            model_name: The name of the model to initialize
        """
        self.logger.info(f"Initializing model: {model_name}")
        try:
            self.model_name = model_name
            local_model_path = self.models_path / model_name.split('/')[-1]

            # Check if the model exists locally
            if local_model_path.exists():
                self.logger.info(f"Loading model from local path: {local_model_path}")
                model_path = local_model_path
            else:
                self.logger.info(f"Loading model from source: {model_name}")
                model_path = model_name

            # 8-bit loading requires the bitsandbytes package
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                load_in_8bit=True,
                torch_dtype=torch.float16,
                cache_dir=self.cache_path,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=self.cache_path)

            # Update generation config with tokenizer-specific values
            self.generation_config["eos_token_id"] = self.tokenizer.eos_token_id
            self.generation_config["pad_token_id"] = self.tokenizer.eos_token_id

            self.logger.info(f"Successfully initialized model: {model_name}")
        except Exception as e:
            self.logger.error(f"Failed to initialize model {model_name}: {str(e)}")
            raise

    def has_chat_template(self) -> bool:
        """Check if the current model has a chat template."""
        try:
            self.tokenizer.apply_chat_template(
                [{"role": "user", "content": "test"}],
                tokenize=False,
            )
            return True
        except (ValueError, AttributeError):
            return False

    def _prepare_prompt(self, prompt: str, system_message: Optional[str] = None) -> str:
        """
        Prepare the prompt text, using the model's chat template if available,
        or falling back to a simple System/User/Assistant text format.
        """
        try:
            messages = []
            if system_message:
                messages.append({"role": "system", "content": system_message})
            messages.append({"role": "user", "content": prompt})
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except (ValueError, AttributeError):
            template = ""
            if system_message:
                template += f"System: {system_message}\n\n"
            template += f"User: {prompt}\n\nAssistant: "
            return template
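
    # Fallback formatting example: with no chat template,
    # _prepare_prompt("What is RAG?", "Be brief") returns
    # "System: Be brief\n\nUser: What is RAG?\n\nAssistant: "
    # (the prompt and system message here are illustrative inputs only).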

    def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response for the given prompt.
        """
        self.logger.debug(f"Generating response for prompt: {prompt[:50]}...")
        if self.model is None:
            raise RuntimeError("Model not initialized. Call initialize_model first.")

        try:
            text = self._prepare_prompt(prompt, system_message)
            inputs = self.tokenizer([text], return_tensors="pt")

            # Remove token_type_ids if present
            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()
                            if k != 'token_type_ids'}

            generation_config = self.generation_config.copy()
            if max_new_tokens:
                generation_config["max_new_tokens"] = max_new_tokens

            generated_ids = self.model.generate(
                **model_inputs,
                **generation_config
            )

            # Strip the prompt tokens so only the newly generated text remains
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs['input_ids'], generated_ids)
            ]
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

            self.logger.debug(f"Generated response: {response[:50]}...")
            return response
        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}")
            raise

    def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> Iterator[str]:
        """
        Generate a streaming response for the given prompt.
        """
        self.logger.debug(f"Starting streaming generation for prompt: {prompt[:50]}...")
        if self.model is None:
            raise RuntimeError("Model not initialized. Call initialize_model first.")

        try:
            text = self._prepare_prompt(prompt, system_message)
            inputs = self.tokenizer([text], return_tensors="pt")

            # Remove token_type_ids if present
            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()
                            if k != 'token_type_ids'}

            # Configure generation
            generation_config = self.generation_config.copy()
            if max_new_tokens:
                generation_config["max_new_tokens"] = max_new_tokens

            # Set up streaming; skip the prompt and special tokens so the
            # yielded chunks match what generate_response returns
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                **model_inputs,
                **generation_config,
                streamer=streamer
            )

            # Run generation in a background thread so the stream can be consumed here
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            # Yield the generated text in chunks
            for new_text in streamer:
                self.logger.debug(f"Generated chunk: {new_text[:50]}...")
                yield new_text
        except Exception as e:
            self.logger.error(f"Error in streaming generation: {str(e)}")
            raise

    def generate_embedding(self, text: str) -> List[float]:
        """
        Generate a single embedding vector for a chunk of text.
        Returns a list of floats representing the text embedding.
        """
        self.logger.debug(f"Generating embedding for text: {text[:50]}...")
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not initialized. Call initialize_model first.")

        try:
            # Tokenize the input text and move it to the model's device,
            # keeping the tokenizer's own attention mask
            inputs = self.tokenizer(text, return_tensors='pt')
            input_ids = inputs.input_ids.to(self.model.device)
            attention_mask = inputs.attention_mask.to(self.model.device)

            # Get model outputs with hidden states
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True
                )

            # Get the last hidden state
            last_hidden_state = outputs.hidden_states[-1]

            # Average the hidden state over all tokens (a single, unpadded
            # sequence, so a plain mean over the sequence dimension suffices)
            embedding = last_hidden_state[0].mean(dim=0)

            # Convert to a regular Python list
            embedding_list = embedding.cpu().tolist()

            self.logger.debug(f"Generated embedding of length: {len(embedding_list)}")
            return embedding_list
        except Exception as e:
            self.logger.error(f"Error generating embedding: {str(e)}")
            raise
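

# Minimal usage sketch. The config values and scratch path below are
# illustrative assumptions, and "norallm/normistral-11b-warm" is simply the
# example model named in download_model's docstring; setup_logger may expect
# additional logging keys in the config. Because of the relative import above,
# run this as a module, e.g. `python -m <your_package>.llm_api`.
if __name__ == "__main__":
    example_config = {
        "model": {
            "base_path": "/tmp/llm_api_demo",  # assumed scratch location
            "generation": {
                "max_new_tokens": 256,
                "do_sample": True,
                "temperature": 0.7,
                "repetition_penalty": 1.1,
            },
        },
        "folders": {"models": "models", "cache": "cache"},
    }

    api = LLMApi(example_config)
    api.initialize_model("norallm/normistral-11b-warm")

    # Full response in one call
    print(api.generate_response("What is retrieval-augmented generation?",
                                system_message="Answer briefly."))

    # Streamed response, chunk by chunk
    for chunk in api.generate_stream("Tell a short story."):
        print(chunk, end="", flush=True)
    print()

    # Mean-pooled embedding from the final hidden layer
    vector = api.generate_embedding("A sentence to embed.")
    print(f"Embedding dimension: {len(vector)}")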