thennal's picture
Upload LLMEncoder
d139aed verified
raw
history blame
18.7 kB
import json
import logging
import os
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.multiprocessing as mp
from peft import PeftModel
from torch import Tensor, device, nn
from tqdm.autonotebook import tqdm, trange
from transformers import (
AutoModel,
AutoConfig,
PretrainedConfig,
PreTrainedModel,
AutoTokenizer,
LlamaConfig,
MistralConfig,
GemmaConfig,
Qwen2Config,
)
logger = logging.getLogger(__name__)
def batch_to_device(batch, target_device: device):
"""
send a pytorch batch to a device (CPU/GPU)
"""
for key in batch:
if isinstance(batch[key], Tensor):
batch[key] = batch[key].to(target_device)
return batch
class LLMEncoderConfig(PretrainedConfig):
def __init__(
self,
pooling_mode: str = "weighted_mean",
max_length: int = 512,
doc_max_length: int = 400,
skip_instruction: bool = True,
**kwargs,
):
if pooling_mode not in ["mean", "weighted_mean", "eos_token", "bos_token"]:
raise ValueError(
(f"Pooling mode {pooling_mode} is not supported.",
"Please choose one of 'mean', 'weighted_mean', 'eos_token', 'bos_token'.")
)
self.pooling_mode = pooling_mode
self.max_length = max_length
self.doc_max_length = doc_max_length
self.skip_instruction = skip_instruction
self.model_config = None
self.base_model = None
super().__init__(**kwargs)
class LLMEncoder(PreTrainedModel):
config_class = LLMEncoderConfig
def __init__(
self,
model: PreTrainedModel,
tokenizer: AutoTokenizer,
config: LLMEncoderConfig
):
super().__init__(config)
self.model = model
self.tokenizer = tokenizer
self.pooling_mode = config.pooling_mode
self.max_length = config.max_length
self.doc_max_length = config.doc_max_length
self.skip_instruction = config.skip_instruction
self.model_config = None
@classmethod
def from_pretrained(
self,
base_model_name_or_path,
peft_model_name_or_path=None,
config=None,
**kwargs,
):
"""
Load a pretrained model from a model identifier or path.
Args:
base_model_name_or_path: Model identifier or path to pretrained model.
peft_model_name_or_path: Path to any PEFT models to apply.
Returns: L3Prune model.
"""
if not config:
config = LLMEncoderConfig()
if not config.base_model:
config.base_model = base_model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
if config.model_config:
model_config = AutoConfig.from_pretrained(config.base_model)
model_config = model_config.from_dict(config.model_config)
else:
model_config = AutoConfig.from_pretrained(base_model_name_or_path)
config.model_config = model_config
model = AutoModel.from_pretrained(base_model_name_or_path, config=model_config, **kwargs)
if peft_model_name_or_path is not None:
model = PeftModel.from_pretrained(
model,
peft_model_name_or_path,
)
model = model.merge_and_unload()
return self(model=model, tokenizer=tokenizer, config=config)
def prune(self, percent_prune=0):
"""
Prune a model to a percentage of layers of the base model. If percent_prune is equal to or greater than 1,
it is taken as the specific layer number to prune to. For example, if percent_prune=0.3, 30% of the layers will be pruned. If
percent_prune=3, the model will be pruned to 3 layers.
"""
# take it as the specific layer number to prune to
if percent_prune >= 1:
new_num_layers = int(percent_prune)
else:
new_num_layers = int(self.model.config.num_hidden_layers * (1 - percent_prune))
print(f"Pruning to {new_num_layers} layer.")
self.model.layers = self.model.layers[:new_num_layers]
self.model.config.num_hidden_layers = new_num_layers
self.config.model_config.num_hidden_layers = new_num_layers
def prepare_for_tokenization(self, text):
if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
text = (
"<|start_header_id|>user<|end_header_id|>\n\n"
+ text.strip()
+ "<|eot_id|>"
)
return text
if self.model.config._name_or_path in [
"mistralai/Mistral-7B-Instruct-v0.2",
"meta-llama/Llama-2-7b-chat-hf",
]:
text = "[INST] " + text.strip() + " [/INST]"
if self.model.config._name_or_path in [
"google/gemma-2-9b-it",
]:
text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
if self.model.config._name_or_path in [
"Qwen/Qwen2-1.5B-Instruct",
"Qwen/Qwen2-7B-Instruct",
]:
text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
if self.pooling_mode == "eos_token":
if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
text = text.strip() + "<|end_of_text|>"
elif isinstance(self.model.config, LlamaConfig) or isinstance(
self.model.config, MistralConfig
):
text = text.strip() + " </s>"
elif isinstance(self.model.config, GemmaConfig):
text = text.strip() + "<eos>"
elif isinstance(self.model.config, Qwen2Config):
text = text.strip() + "<|endoftext|>"
return text
def tokenize(self, texts):
texts_2 = []
original_texts = []
for text in texts:
t = text.split("!@#$%^&*()")
texts_2.append(t[1] if len(t) > 1 else "")
original_texts.append("".join(t))
original = self.tokenizer(
original_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
)
embed_mask = None
for t_i, t in enumerate(texts_2):
ids = self.tokenizer(
[t],
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
add_special_tokens=False,
)
if embed_mask is None:
e_m = torch.zeros_like(original["attention_mask"][t_i])
if len(ids["input_ids"][0]) > 0:
e_m[-len(ids["input_ids"][0]) :] = torch.ones(
len(ids["input_ids"][0])
)
embed_mask = e_m.unsqueeze(0)
else:
e_m = torch.zeros_like(original["attention_mask"][t_i])
if len(ids["input_ids"][0]) > 0:
e_m[-len(ids["input_ids"][0]) :] = torch.ones(
len(ids["input_ids"][0])
)
embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
original["embed_mask"] = embed_mask
return original
def _skip_instruction(self, sentence_feature):
assert (
sentence_feature["attention_mask"].shape
== sentence_feature["embed_mask"].shape
)
sentence_feature["attention_mask"] = sentence_feature["embed_mask"]
def forward(self, sentence_feature: Dict[str, Tensor]):
embed_mask = None
if "embed_mask" in sentence_feature:
embed_mask = sentence_feature.pop("embed_mask")
reps = self.model(**sentence_feature)
sentence_feature["embed_mask"] = embed_mask
return self.get_pooling(sentence_feature, reps.last_hidden_state)
def get_pooling(self, features, last_hidden_states): # All models padded from left
assert (
self.tokenizer.padding_side == "left"
), "Pooling modes are implemented for padding from left."
if self.skip_instruction:
self._skip_instruction(features)
seq_lengths = features["attention_mask"].sum(dim=-1)
if self.pooling_mode == "mean":
return torch.stack(
[
last_hidden_states[i, -length:, :].mean(dim=0)
for i, length in enumerate(seq_lengths)
],
dim=0,
)
elif self.pooling_mode == "weighted_mean":
bs, l, _ = last_hidden_states.shape
complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
for i, seq_l in enumerate(seq_lengths):
if seq_l > 0:
complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
complete_weights[i] /= torch.clamp(
complete_weights[i].sum(), min=1e-9
)
return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
return last_hidden_states[:, -1]
elif self.pooling_mode == "bos_token":
return last_hidden_states[
features["input_ids"] == self.tokenizer.bos_token_id
]
else:
raise ValueError(f"{self.pooling_mode} is not implemented yet.")
def _convert_to_str(self, instruction, text):
tokenized_q = self.tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
add_special_tokens=False,
)
tokenized_q_length = len(tokenized_q["input_ids"][0])
while tokenized_q_length > self.doc_max_length:
reduction_ratio = self.doc_max_length / tokenized_q_length
reduced_length = int(len(text.split()) * reduction_ratio)
text = " ".join(text.split()[:reduced_length])
tokenized_q = self.tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
add_special_tokens=False,
)
tokenized_q_length = len(tokenized_q["input_ids"][0])
return (
f"{instruction.strip()} !@#$%^&*(){text}"
if instruction
else f"!@#$%^&*(){text}"
)
def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
show_progress_bar: bool = True,
convert_to_numpy: bool = False,
convert_to_tensor: bool = False,
device: Optional[str] = None,
):
"""
Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a string.
Args:
sentences: sentence or sentences to encode.
batch_size: batch size for turning sentence tokens into embeddings.
show_progress_bar: whether to show progress bars during encoding steps.
convert_to_numpy: If true, return numpy arrays instead of torch tensors.
convert_to_tensor: If true, return torch tensors (default).
device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified,
the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports
multiprocessing as currently implemented.
Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation).
"""
if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
sentences = [sentences]
# required for MEDI version of MTEB
if isinstance(sentences[0], str):
sentences = [[""] + [sentence] for sentence in sentences]
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
concatenated_input_texts = []
for sentence in sentences:
assert isinstance(sentence[0], str)
assert isinstance(sentence[1], str)
concatenated_input_texts.append(
self._convert_to_str(sentence[0], sentence[1])
)
sentences = concatenated_input_texts
self.eval()
if convert_to_tensor:
convert_to_numpy = False
length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
all_embeddings = []
if torch.cuda.device_count() <= 1:
# This branch also support mps devices
self.to(device)
for start_index in trange(
0,
len(sentences),
batch_size,
desc="Batches",
disable=not show_progress_bar,
):
sentences_batch = sentences_sorted[
start_index : start_index + batch_size
]
embeddings = self._encode(
sentences_batch, device=device, convert_to_numpy=convert_to_numpy
)
all_embeddings.append(embeddings)
else:
num_proc = torch.cuda.device_count()
cuda_compatible_multiprocess = mp.get_context("spawn")
with cuda_compatible_multiprocess.Pool(num_proc) as p:
sentences_batches = [
sentences_sorted[start_index : start_index + batch_size]
for start_index in range(0, len(sentences), batch_size)
]
progress_bar = tqdm(
total=len(sentences_batches),
desc="Batches",
disable=not show_progress_bar,
)
results = []
def update(*args):
progress_bar.update()
for batch in sentences_batches:
results.append(
p.apply_async(
self._encode,
args=(batch, None, convert_to_numpy, True),
callback=update,
)
)
all_embeddings = [result.get() for result in results]
progress_bar.close()
all_embeddings = torch.cat(all_embeddings, dim=0)
all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
all_embeddings = all_embeddings.to(torch.float32)
if convert_to_numpy:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
return all_embeddings
def save(self, output_path, merge_before_save=False, save_config=True):
if merge_before_save and isinstance(self.model, PeftModel):
self.model = self.model.merge_and_unload()
if hasattr(self.model, "_hf_peft_config_loaded"):
self.model._hf_peft_config_loaded = False
self.model.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)
l3prune_config = {
"pooling_mode": self.pooling_mode,
"max_length": self.max_length,
"doc_max_length": self.doc_max_length,
"skip_instruction": self.skip_instruction,
}
if save_config:
os.makedirs(output_path, exist_ok=True)
with open(f"{output_path}/l3prune_config.json", "w") as fOut:
json.dump(l3prune_config, fOut, indent=4)
def _encode(
self,
sentences_batch,
device: Optional[str] = None,
convert_to_numpy: bool = False,
multiprocessing=False,
):
if multiprocessing:
# multiprocessing only supports CUDA devices at this time, so we ignore the value of device
# and use cuda:rank for the device
rank = mp.current_process()._identity[0]
if device is None and torch.cuda.is_available():
device = f"cuda:{rank % torch.cuda.device_count()}"
self.to(device)
features = self.tokenize(
[self.prepare_for_tokenization(sentence) for sentence in sentences_batch]
)
features = batch_to_device(features, device)
with torch.no_grad():
embeddings = self.forward(features)
embeddings = embeddings.detach()
embeddings = embeddings.cpu()
return embeddings
def _text_length(self, text: Union[List[int], List[List[int]]]):
"""
Help function to get the length for the input text. Text can be either a string (which means a single text)
a list of ints (which means a single tokenized text), or a tuple of list of ints
(representing several text inputs to the model).
"""
if (
isinstance(text, str)
or (isinstance(text, list) and isinstance(text[0], int))
or len(text) == 0
): # Single text, list of ints, or empty
return len(text)
if isinstance(text, dict): # {key: value} case
return len(next(iter(text.values())))
elif not hasattr(text, "__len__"): # Object has no len() method
return 1
else:
return sum([len(t) for t in text])
def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
) -> nn.Embedding:
return self.model.resize_token_embeddings(
new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
)
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
self.model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
)
def save_pretrained(self, save_directory, **kwargs):
self.tokenizer.save_pretrained(save_directory, **kwargs)
super().save_pretrained(save_directory, **kwargs)
def push_to_hub(self, repo_id, **kwargs):
self.tokenizer.push_to_hub(repo_id, **kwargs)
super().push_to_hub(repo_id, **kwargs)