import deepspeed
import torch
import transformers
from peft import get_peft_model, PromptTuningConfig, TaskType, PrefixTuningConfig
from torch import nn, autocast
from torch.nn import functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.deepspeed import HfDeepSpeedConfig

from utils.format_inputs import TASK_TYPE
from utils.format_inputs import format_causal_personachat_input, format_personachat_input, format_generate_persona_input
from utils.model_helpers import print_trainable_parameters


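# SelectLLMChat wraps a single causal LM with K parameter-efficient adapters
# (prompt tuning or prefix tuning) plus a lightweight retriever that scores,
# per input, which adapter should handle it. Training couples the per-adapter
# LM loss with a KL term aligning the retriever scores with the adapter losses.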
class SelectLLMChat(nn.Module):
    def __init__(self, config, batch_size, ds_config=None):
        super(SelectLLMChat, self).__init__()
        if ds_config is not None:
            # Keep a reference to HfDeepSpeedConfig so transformers can detect ZeRO-3.
            _hfdsc = HfDeepSpeedConfig(ds_config)
        peft_type = config.model.peft_type
        self.peft_type = peft_type
        assert config.model.peft_type in ['prompt_tuning', 'prefix_tuning',
                                          ], 'only prompt_tuning and prefix_tuning are supported!'
        K = config.model.K
        self.K = K
        self.ensemble_training = config.training.ensemble
        self.model_name = config.model.model_name
        self.load_bit = config.model.load_bit
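        # Two tokenizer copies over the same vocab: the left-padded / left-truncated
        # one is used for generation-style (prompt) inputs, the right-padded one for
        # target sequences.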
        self.left_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
        if config.training.mode != 'causal':
            self.left_tokenizer.add_special_tokens({'pad_token': '[PAD]',
                                                    'bos_token': '[BOS]',
                                                    'eos_token': '[EOS]',
                                                    'unk_token': '[UNK]',
                                                    'sep_token': '[SEP]',
                                                    'cls_token': '[CLS]',
                                                    'mask_token': '[MASK]'})
        self.left_tokenizer.padding_side = 'left'
        self.left_tokenizer.truncation_side = 'left'
        self.right_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
        if config.training.mode != 'causal':
            self.right_tokenizer.add_special_tokens({'pad_token': '[PAD]',
                                                     'bos_token': '[BOS]',
                                                     'eos_token': '[EOS]',
                                                     'unk_token': '[UNK]',
                                                     'sep_token': '[SEP]',
                                                     'cls_token': '[CLS]',
                                                     'mask_token': '[MASK]'})
        self.right_tokenizer.padding_side = 'right'
        self.right_tokenizer.truncation_side = 'right'
        if self.left_tokenizer.pad_token is None and config.model.pad_token == 'bos':
            self.left_tokenizer.pad_token = self.left_tokenizer.bos_token
            self.right_tokenizer.pad_token = self.right_tokenizer.bos_token
        elif self.left_tokenizer.pad_token_id is None:
            self.left_tokenizer.pad_token = self.left_tokenizer.eos_token
            self.right_tokenizer.pad_token = self.right_tokenizer.eos_token
        self.batch_size = batch_size
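
        # Map the configured precision to from_pretrained kwargs. The 4/8-bit
        # (bitsandbytes) entries are listed for completeness but rejected by the
        # assert below, since quantised loading does not play well with DeepSpeed.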
        load_bit_map = {4: {'load_in_4bit': True,
                            'bnb_4bit_compute_dtype': torch.bfloat16},
                        8: {'load_in_8bit': True},
                        16: {'torch_dtype': torch.float16},
                        'bf16': {'torch_dtype': torch.bfloat16},
                        32: {'torch_dtype': torch.float32}}
        assert config.model.load_bit in [16, 32, 'bf16'], 'deepspeed is not friendly with bnb!'
        model = AutoModelForCausalLM.from_pretrained(
            config.model.model_name,
            **load_bit_map[config.model.load_bit]
        )
        if config.training.mode != 'causal':
            model.resize_token_embeddings(len(self.left_tokenizer))
        model.gradient_checkpointing_enable()
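
        # Standard PEFT preparation: freeze the base model, keep 1-D (norm / bias)
        # parameters in fp32 for numerical stability, let inputs require grads for
        # gradient checkpointing, and cast lm_head outputs back to fp32.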
        if config.model.peft_config is not None:
            for param in model.parameters():
                param.requires_grad = False
                if param.ndim == 1:
                    param.data = param.data.to(torch.float32)
            model.enable_input_require_grads()

            class CastOutputToFloat(nn.Sequential):
                def forward(self, x):
                    return super().forward(x).to(torch.float32)

            model.lm_head = CastOutputToFloat(model.lm_head)
        self.model = model
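
        # Build K PEFT wrappers around the same frozen base model; only the prompt /
        # prefix parameters differ between the K copies.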
        models = []
        peft_config = None
        for _ in range(K):
            if config.model.peft_type == 'prompt_tuning':
                peft_config = PromptTuningConfig(
                    **config.model.peft_config,
                    task_type=TaskType.CAUSAL_LM,
                )
            elif config.model.peft_type == 'prefix_tuning':
                peft_config = PrefixTuningConfig(
                    **config.model.peft_config,
                    task_type=TaskType.CAUSAL_LM,
                )
            else:
                raise NotImplementedError()
            _peft_model = get_peft_model(model, peft_config)
            models.append(_peft_model)
        self.models = nn.ModuleList(models)
        self.learning_rate = config.training.learning_rate
        self.warmup_steps = config.training.warmup_steps
        self.config = config
        self.find_batch = False
        self.retriever = None
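
        # Optional retriever: a Transformer encoder applied to the prompt embeddings
        # and/or the LM token embeddings before cosine-similarity scoring in
        # retrieve_prompts().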
        if config.model.retriever.retriever_type == 'transformer_encoder':
            encoder_layer = nn.TransformerEncoderLayer(d_model=self.models[0].word_embeddings.weight.shape[1],
                                                       nhead=config.model.retriever.n_head)
            transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.model.retriever.num_layers)
            self.retriever = transformer_encoder
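
        # Normaliser applied to prompt / LM embeddings before scoring: an identity
        # module when prompt tuning is used without an explicit normalizer setting,
        # or a learned linear map when config.model.normalizer == 'linear'.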
        if config.model.peft_type in ['prompt_tuning'] and config.model.normalizer.__class__ is not str:
            class DoNothing(nn.Sequential):
                def forward(self, x):
                    return x

            self.prompt_normalizer = DoNothing()
        elif config.model.normalizer == 'linear':
            if config.model.peft_type in ['prompt_tuning', 'prefix_tuning']:
                _d_peft = self.models[0].prompt_encoder.default.embedding.weight.shape[1]
            else:
                raise NotImplementedError('check here!')
            self.prompt_normalizer = nn.Linear(_d_peft, _d_peft)
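
        # Activation applied to the raw cosine-similarity scores in retrieve_prompts();
        # defaults to Softplus when the configured name is not recognised.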
        if config.model.score_activation == 'softplus':
            self.score_activation = nn.Softplus(threshold=1, beta=10)
        elif config.model.score_activation == 'relu':
            self.score_activation = nn.ReLU()
        elif config.model.score_activation == 'leaky_relu':
            self.score_activation = nn.LeakyReLU()
        else:
            self.score_activation = nn.Softplus(threshold=1, beta=10)

        self.retriever_on = ['extra']
        if config.model.retriever.retriever_on.__class__ is list:
            self.retriever_on = config.model.retriever.retriever_on
        if config.training.all_tunable.__class__ is bool and config.training.all_tunable:
            for param in self.parameters():
                param.requires_grad = True
        print_trainable_parameters(self)
        self.contrastive_metric = None
        if config.training.contrastive_metric.__class__ is str:
            self.contrastive_metric = config.training.contrastive_metric
        self.contrastive_threshold = 0.0
        if config.training.contrastive_threshold.__class__ is float:
            self.contrastive_threshold = config.training.contrastive_threshold
        self.config = config
        self.annealing_nll = False
        self.annealing_scalar = 0.0
        if self.config.training.annealing_nll.__class__ == bool:
            self.annealing_nll = self.config.training.annealing_nll
            self.annealing_scalar = self.config.training.annealing_scalar

    def print_llm_trainable_parameters(self):
        print_trainable_parameters(self.model)

    def retrieve_based_on_input_x(self, x, K):
        return self.retrieve_prompts(x, K)

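    # Scores each of the K prompts against the batch: the input ids are repeated K
    # times, embedded by the (frozen) base model and mean-pooled over the sequence,
    # then compared with the mean-pooled prompt embeddings via cosine similarity.
    # Returns a flat tensor of K * batch_size activated scores.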
    @autocast('cuda')
    def retrieve_prompts(self, x, K):
        batch_size = x['input_ids'].shape[0]
        input_ids = x['input_ids']
        spawned_x = input_ids.repeat(K, 1)
        if self.models[0].base_model.__class__ == transformers.models.llama.modeling_llama.LlamaForCausalLM:
            spawned_x_emb = self.models[0].base_model.model.embed_tokens(spawned_x)
        else:
            spawned_x_emb = self.models[0].base_model.model.decoder.embed_tokens(spawned_x)
        if spawned_x_emb.shape[-1] != self.models[0].config.hidden_size:
            spawned_x_emb = self.models[0].base_model.model.decoder.project_in(spawned_x_emb)
        prompt_embeddings = torch.stack([_model.prompt_encoder.default.embedding.weight for _model in self.models])
        if self.retriever is not None:
            if 'extra' in self.retriever_on:
                prompt_embeddings = self.retriever(self.prompt_normalizer(prompt_embeddings))
            if 'lm' in self.retriever_on:
                spawned_x_emb = self.retriever(spawned_x_emb)
        spawned_x_emb_mean = spawned_x_emb.mean(dim=1)
        prompt_embeddings_mean = prompt_embeddings.mean(dim=1)
        if self.retriever is None:
            normalizer_on = self.config.model.normalizer_on
            if normalizer_on.__class__ is not list:
                prompt_embeddings_mean = self.prompt_normalizer(prompt_embeddings_mean)
            else:
                if 'prompt' in normalizer_on:
                    prompt_embeddings_mean = self.prompt_normalizer(prompt_embeddings_mean)
                if 'lm' in normalizer_on:
                    spawned_x_emb_mean = self.prompt_normalizer(spawned_x_emb_mean)
        prompt_embeddings_mean_spawn = torch.repeat_interleave(prompt_embeddings_mean, batch_size, dim=0)
        sim_scores = self.score_activation(
            torch.nn.CosineSimilarity()(prompt_embeddings_mean_spawn, spawned_x_emb_mean))
        return sim_scores
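
    # Training-mode forward: move the batch to the local GPU, compute retrieval
    # scores (unless skipped), run all K adapters on the same batch, and return the
    # stacked per-adapter logits together with the similarity scores.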
    @autocast('cuda')
    def forward(self, x, mode='training'):
        for k in x.keys():
            x[k] = x[k].cuda(device=deepspeed.comm.get_local_rank())
        if self.find_batch:
            x['attention_mask'] = x['attention_mask'].new_ones(x['attention_mask'].shape)
        if mode == 'training':
            if self.config.training.skip_retrieval.__class__ is bool and self.config.training.skip_retrieval:
                sim_scores = None
            else:
                sim_scores = self.retrieve_based_on_input_x(x, self.K)

            _outputs = [_model(**x) for _model in self.models]
            _logits = torch.stack([_output['logits'] for _output in _outputs])
            return {'logits': _logits, 'sim_scores': sim_scores}
        else:
            raise NotImplementedError('validation and testing not implemented')

    def on_train_start(self) -> None:
        self.print_llm_trainable_parameters()
        deepspeed.zero.Init()

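    # One training step over all K adapters. The total loss combines:
    #   (1) token-level NLL over the K adapter outputs,
    #   (2) a KL term pushing the retriever's score distribution towards the
    #       softmaxed negative per-adapter losses,
    #   (3) optionally an ensemble loss over score-weighted logits, and
    #   (4) optionally a contrastive (cosine-embedding) loss on the score
    #       distributions of persona pairs, gated by a BLEU threshold.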
    @staticmethod
    def training_step(model, batch, left_tokenizer, right_tokenizer, config, mode='normal',
                      task_type=TASK_TYPE.GENERATE_RESPONSE, training_process=0.0):
        assert mode in ['normal', 'causal']
        if task_type == TASK_TYPE.GENERATE_PERSONA and mode == 'normal':
            lm_input, lm_target = format_generate_persona_input(batch, left_tokenizer, right_tokenizer,
                                                                config)
        elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'causal':
            lm_input, lm_target = format_causal_personachat_input(batch, left_tokenizer, right_tokenizer,
                                                                  config)
        elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'normal':
            lm_input, lm_target = format_personachat_input(batch, left_tokenizer, right_tokenizer, config)
        else:
            raise NotImplementedError('mode and task_type not implemented')
        output = model.module(dict(lm_input))

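        # Stacked logits have shape (K, batch, seq, vocab); flatten the first two
        # dims so each adapter's predictions sit along the batch dimension. For
        # prompt tuning the logits also cover the virtual tokens, so the targets
        # are prepended with pad_token_id to match.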
        logits = output['logits']
        logits = logits.view(-1, logits.shape[2], logits.shape[3])
        sim_scores = output['sim_scores']
        batch_size = lm_target.size()[0]
        if config.model.peft_type == 'prompt_tuning':
            virtual_tokens = config.model.peft_config.num_virtual_tokens
            _lm_target = torch.cat(
                (lm_target.new_ones((batch_size, virtual_tokens)) * left_tokenizer.pad_token_id, lm_target), dim=1)
        else:
            _lm_target = lm_target
        _lm_target_spawn = _lm_target.repeat(config.model.K, 1)
        losses = nn.CrossEntropyLoss(ignore_index=left_tokenizer.pad_token_id, reduction='none')(
            logits.view(-1, logits.shape[-1]),
            _lm_target_spawn.cuda(device=deepspeed.comm.get_local_rank()).view(-1))
        if config.training.only_nll.__class__ == bool and config.training.only_nll:
            return losses[losses != 0].mean()
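
        # Build the retrieval target: per-example, per-adapter mean NLL (non-pad
        # positions only), softmaxed with temperature tau_gold so that lower-loss
        # adapters receive higher probability. The retriever's scores are trained
        # towards this distribution with a batch-mean KL divergence (temperature
        # tau_sim on the score side).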
        reshaped_losses = losses.view(logits.shape[0], logits.shape[1]).detach().clone()
        reshaped_losses = torch.stack([_losses[_losses != 0].mean() for _losses in reshaped_losses.detach().clone()])

        softmaxed_neg_losses = nn.Softmax(dim=0)(
            -reshaped_losses.view(config.model.K, batch_size) / config.training.tau_gold).permute(1, 0)
        if config.training.adding_noise.__class__ is float:
            noise = torch.randn_like(softmaxed_neg_losses, device=softmaxed_neg_losses.device)
            softmaxed_neg_losses = softmaxed_neg_losses + config.training.adding_noise * noise
        logsoftmaxed_sim_scores = F.log_softmax(sim_scores.view(config.model.K, batch_size) / config.training.tau_sim,
                                                dim=0).permute(1, 0)
        kldiv_loss = nn.KLDivLoss(reduction='batchmean')(logsoftmaxed_sim_scores,
                                                         softmaxed_neg_losses)
        selective_loss_weight = 1.0
        if config.training.annealing_nll.__class__ is bool and config.training.annealing_nll:
            _ann_scalar = config.training.annealing_scalar * (1 - training_process)
            _sim_score = torch.clamp(_ann_scalar * nn.Softmax(-1)(sim_scores),
                                     config.training.annealing_min, config.training.annealing_max).detach()
            losses = torch.einsum('ab,a->ab', losses.view(logits.shape[0], logits.shape[1]), _sim_score).view(-1)

        if config.training.selective_loss_weight.__class__ == float:
            selective_loss_weight = config.training.selective_loss_weight
        if config.training.selective_loss.__class__ == bool and (config.training.selective_loss == False):
            loss = losses[losses != 0].mean()
        elif config.training.disable_nll.__class__ is bool and config.training.disable_nll:
            loss = selective_loss_weight * kldiv_loss
        else:
            loss = losses[losses != 0].mean() + selective_loss_weight * kldiv_loss
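
        # Optional ensemble loss: for each example, weight the K adapters' logits by
        # their similarity scores, sum them, and apply cross-entropy against the
        # (shared) target sequence.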
        if model.module.ensemble_training:
            K = config.model.K
            enb_losses = []
            for data_idx in range(batch_size):
                data_indices = [data_idx + (batch_size * inc) for inc in range(K)]
                ensemble_preds = logits[data_indices, :, :]
                ensemble_sims = sim_scores[data_indices]
                normed_preds = ensemble_sims.unsqueeze(-1).unsqueeze(-1).mul(ensemble_preds)
                normed_preds = normed_preds.sum(dim=0)
                _target = _lm_target_spawn[data_indices, :]
                assert _target.unique(dim=0).shape[0] == 1, 'error in assembling the ensemble preds'
                enb_loss = nn.CrossEntropyLoss(ignore_index=left_tokenizer.pad_token_id)(normed_preds,
                                                                                         _target[0].cuda(
                                                                                             device=deepspeed.comm.get_local_rank()))
                enb_losses.append(enb_loss)
            loss += torch.stack(enb_losses).mean()
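
        # Optional contrastive regulariser: for every pair of personas in the batch,
        # compare their score distributions with a cosine-embedding loss, pulling
        # them together when the personas are similar (sentence BLEU above the
        # threshold) and pushing them apart otherwise.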
        if model.module.contrastive_metric:
            ctr_losses = []
            from sacrebleu import BLEU
            ctr_metrics = BLEU(effective_order=True)
            batch_persona = [' '.join(row) for row in batch['persona_list']]
            statics = {}

            permuted_sim_scores = sim_scores.unsqueeze(0).view(model.module.K, batch_size)
            if model.module.contrastive_metric == 'bleu':
                for idx in range(len(batch_persona) - 1):
                    for jdx in range(idx + 1, len(batch_persona)):
                        iele = batch_persona[idx]
                        jele = batch_persona[jdx]
                        scores = ctr_metrics.sentence_score(iele, [jele]).score
                        idist = permuted_sim_scores[:, idx]
                        jdist = permuted_sim_scores[:, jdx]
                        cosine_emb_loss = nn.CosineEmbeddingLoss()
                        if scores > model.module.contrastive_threshold:
                            cosine_target = 1
                        else:
                            cosine_target = -1
                        cos_loss = cosine_emb_loss(idist, jdist, torch.tensor(cosine_target))
                        ctr_losses.append(cos_loss)
                        statics[(idx, jdx)] = {'iele': iele, 'jele': jele, 'scores': scores,
                                               'idist': idist,
                                               'jdist': jdist, 'cos_emb_loss': cos_loss}
            if len(ctr_losses) != 0:
                ctr_losses_pt = torch.stack(ctr_losses).mean()
                loss += config.training.contrastive_weight * ctr_losses_pt
            else:
                print(f'CTR ERROR: {statics}')
        return loss

    @staticmethod
    def validation_step(model, batch, left_tokenizer, right_tokenizer, config, task_type, mode='normal'):
        loss = SelectLLMChat.training_step(model, batch, left_tokenizer, right_tokenizer, config, task_type=task_type,
                                           mode=mode, training_process=0.0)
        return loss

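    # Inference: compute similarity scores for the batch, pick one adapter per
    # example via argmax (optionally after adding Gaussian noise to the scores),
    # generate with each selected adapter, and return the generated continuations
    # (prompt prefix stripped) together with their decoded text and the selection.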
    @staticmethod
    @autocast('cuda')
    def test_step(model, batch, left_tokenizer, right_tokenizer, config, max_new_tokens=16, tqdm_instance: tqdm = None,
                  selection_noise=None, **gen_kwargs):
        model.eval()
        with torch.no_grad():
            if config.training.mode == 'causal':
                lm_input, lm_target, inference_tokenized = format_causal_personachat_input(batch,
                                                                                           left_tokenizer,
                                                                                           right_tokenizer,
                                                                                           config,
                                                                                           for_test=True)
            else:
                lm_input, lm_target, inference_tokenized = format_personachat_input(batch, left_tokenizer,
                                                                                    right_tokenizer,
                                                                                    config,
                                                                                    for_test=True)
            inference_tokenized.to('cuda')
            if 'deepspeed' in str(model.__class__):
                batch_size = inference_tokenized['input_ids'].shape[0]
                sim_scores = model.module.retrieve_based_on_input_x(inference_tokenized, config.model.K)
                sim_scores = sim_scores.reshape(config.model.K, batch_size).permute(1, 0)
                if selection_noise:
                    noise = torch.randn_like(sim_scores, device=sim_scores.device)
                    sim_scores = sim_scores + selection_noise * noise
                selected_prompts = torch.argmax(sim_scores, dim=1)
                if tqdm_instance is not None:
                    tqdm_instance.set_postfix_str(f"selected prompts: {selected_prompts}")
                detached_selected_prompts = selected_prompts.detach().cpu().numpy()
                selected_prompts_set = set(detached_selected_prompts)
                output_dicts = {}

                for key in selected_prompts_set:
                    outputs = model.module.models[key].generate(
                        input_ids=inference_tokenized['input_ids'],
                        attention_mask=inference_tokenized['attention_mask'],
                        max_new_tokens=max_new_tokens,
                        do_sample=False,
                        **gen_kwargs
                    )
                    output_dicts[key] = outputs.detach().cpu()
                raw_output = []
                for idx, prompt_idx in enumerate(detached_selected_prompts):
                    raw_output.append(output_dicts[prompt_idx][idx][inference_tokenized['input_ids'].shape[1]:])

                trunc_output = raw_output
                text_output = right_tokenizer.batch_decode(trunc_output, skip_special_tokens=True)
                return trunc_output, text_output, selected_prompts
            else:
                raise NotImplementedError('not implemented')