SPT / models /selective_llm_chat.py
hqsiswiliam's picture
Upload 43 files
8359bb1 verified
import deepspeed
import deepspeed
import torch
import transformers
from peft import get_peft_model, PromptTuningConfig, TaskType, PrefixTuningConfig
from torch import nn, autocast
from torch.functional import F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.deepspeed import HfDeepSpeedConfig
from utils.format_inputs import TASK_TYPE
from utils.format_inputs import format_causal_personachat_input, format_personachat_input, format_generate_persona_input
from utils.model_helpers import print_trainable_parameters
class SelectLLMChat(nn.Module):
def __init__(self, config, batch_size, ds_config=None):
super(SelectLLMChat, self).__init__()
if ds_config is not None:
_hfdsc = HfDeepSpeedConfig(ds_config)
peft_type = config.model.peft_type
self.peft_type = peft_type
assert config.model.peft_type in ['prompt_tuning', 'prefix_tuning',
], 'only prompt tuning is supported!'
K = config.model.K
self.K = K
self.ensemble_training = config.training.ensemble
self.model_name = config.model.model_name
self.load_bit = config.model.load_bit
self.left_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
if config.training.mode != 'causal':
self.left_tokenizer.add_special_tokens({'pad_token': '[PAD]',
'bos_token': '[BOS]',
'eos_token': '[EOS]',
'unk_token': '[UNK]',
'sep_token': '[SEP]',
'cls_token': '[CLS]',
'mask_token': '[MASK]'})
self.left_tokenizer.padding_side = 'left'
self.left_tokenizer.truncation_side = 'left'
self.right_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
if config.training.mode != 'causal':
self.right_tokenizer.add_special_tokens({'pad_token': '[PAD]',
'bos_token': '[BOS]',
'eos_token': '[EOS]',
'unk_token': '[UNK]',
'sep_token': '[SEP]',
'cls_token': '[CLS]',
'mask_token': '[MASK]'})
self.right_tokenizer.padding_side = 'right'
self.right_tokenizer.truncation_side = 'right'
if self.left_tokenizer.pad_token is None and config.model.pad_token=='bos':
self.left_tokenizer.pad_token = self.left_tokenizer.bos_token
self.right_tokenizer.pad_token = self.right_tokenizer.bos_token
elif self.left_tokenizer.pad_token_id is None:
self.left_tokenizer.pad_token = self.left_tokenizer.eos_token
self.right_tokenizer.pad_token = self.right_tokenizer.eos_token
self.batch_size = batch_size
load_bit_map = {4: {'load_in_4bit': True,
'bnb_4bit_compute_dtype': torch.bfloat16},
8: {'load_in_8bit': True},
16: {'torch_dtype': torch.float16},
'bf16': {'torch_dtype': torch.bfloat16},
32: {'torch_dtype': torch.float32}}
assert config.model.load_bit in [16, 32, 'bf16'], 'deepspeed is not friendly with bnb!'
model = AutoModelForCausalLM.from_pretrained(
config.model.model_name,
**load_bit_map[config.model.load_bit]
)
if config.training.mode != 'causal':
model.resize_token_embeddings(len(self.left_tokenizer))
model.gradient_checkpointing_enable()
if config.model.peft_config is not None:
for param in model.parameters():
param.requires_grad = False # freeze the model - train adapters later
if param.ndim == 1:
# cast the small parameters (e.g. layernorm) to fp32 for stability
param.data = param.data.to(torch.float32)
model.enable_input_require_grads()
class CastOutputToFloat(nn.Sequential):
def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)
self.model = model
models = []
peft_config = None
for _ in range(K):
if config.model.peft_type == 'prompt_tuning':
peft_config = PromptTuningConfig(
**config.model.peft_config,
task_type=TaskType.CAUSAL_LM,
)
elif config.model.peft_type == 'prefix_tuning':
peft_config = PrefixTuningConfig(
**config.model.peft_config,
task_type=TaskType.CAUSAL_LM,
)
else:
raise NotImplementedError()
_peft_model = get_peft_model(model, peft_config)
models.append(_peft_model)
self.models = nn.ModuleList(models)
self.learning_rate = config.training.learning_rate
self.warmup_steps = config.training.warmup_steps
self.config = config
self.find_batch = False
self.retriever = None
if config.model.retriever.retriever_type == 'transformer_encoder':
encoder_layer = nn.TransformerEncoderLayer(d_model=self.models[0].word_embeddings.weight.shape[1],
nhead=config.model.retriever.n_head)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.model.retriever.num_layers)
self.retriever = transformer_encoder
if config.model.peft_type in ['prompt_tuning'] and config.model.normalizer.__class__ is not str:
class DoNothing(nn.Sequential):
def forward(self, x): return x
self.prompt_normalizer = DoNothing()
elif config.model.normalizer == 'linear':
if config.model.peft_type in ['prompt_tuning', 'prefix_tuning']:
_d_peft = self.models[0].prompt_encoder.default.embedding.weight.shape[1]
else:
raise NotImplementedError('check here!')
self.prompt_normalizer = nn.Linear(_d_peft, _d_peft)
if config.model.score_activation == 'softplus':
self.score_activation = nn.Softplus(threshold=1, beta=10)
elif config.model.score_activation == 'relu':
self.score_activation = nn.ReLU()
elif config.model.score_activation == 'leaky_relu':
self.score_activation = nn.LeakyReLU()
else:
self.score_activation = nn.Softplus(threshold=1, beta=10)
# raise NotImplementedError()
self.retriever_on = ['extra']
if config.model.retriever.retriever_on.__class__ is list:
self.retriever_on = config.model.retriever.retriever_on
if config.training.all_tunable.__class__ is bool and config.training.all_tunable:
for param in self.parameters():
param.requires_grad = True
print_trainable_parameters(self)
self.contrastive_metric = None
if config.training.contrastive_metric.__class__ is str:
self.contrastive_metric = config.training.contrastive_metric
self.contrastive_threshold = 0.0
if config.training.contrastive_threshold.__class__ is float:
self.contrastive_threshold = config.training.contrastive_threshold
self.config = config
self.annealing_nll = False
self.annealing_scalar = 0.0
if self.config.training.annealing_nll.__class__ == bool:
self.annealing_nll = self.config.training.annealing_nll
self.annealing_scalar = self.config.training.annealing_scalar
def print_llm_trainable_parameters(self):
print_trainable_parameters(self.model)
def retrieve_based_on_input_x(self, x, K):
return self.retrieve_prompts(x, K)
@autocast('cuda')
def retrieve_prompts(self, x, K):
batch_size = x['input_ids'].shape[0]
input_ids = x['input_ids']
spawned_x = input_ids.repeat(K, 1)
if self.models[0].base_model.__class__ == transformers.models.llama.modeling_llama.LlamaForCausalLM:
spawned_x_emb = self.models[0].base_model.model.embed_tokens(spawned_x)
else:
spawned_x_emb = self.models[0].base_model.model.decoder.embed_tokens(spawned_x)
if spawned_x_emb.shape[-1] != self.models[0].config.hidden_size:
# need project_in here
spawned_x_emb = self.models[0].base_model.model.decoder.project_in(spawned_x_emb)
prompt_embeddings = torch.stack([_model.prompt_encoder.default.embedding.weight for _model in self.models])
if self.retriever is not None:
if 'extra' in self.retriever_on:
prompt_embeddings = self.retriever(self.prompt_normalizer(prompt_embeddings))
if 'lm' in self.retriever_on:
spawned_x_emb = self.retriever(spawned_x_emb)
spawned_x_emb_mean = spawned_x_emb.mean(dim=1)
prompt_embeddings_mean = prompt_embeddings.mean(dim=1)
if self.retriever is None:
normalizer_on = self.config.model.normalizer_on
if normalizer_on.__class__ is not list:
prompt_embeddings_mean = self.prompt_normalizer(prompt_embeddings_mean)
if 'prompt' in normalizer_on:
prompt_embeddings_mean = self.prompt_normalizer(prompt_embeddings_mean)
if 'lm' in normalizer_on:
spawned_x_emb_mean = self.prompt_normalizer(spawned_x_emb_mean)
prompt_embeddings_mean_spawn = torch.repeat_interleave(prompt_embeddings_mean, batch_size, dim=0)
sim_scores = self.score_activation(
torch.nn.CosineSimilarity()(prompt_embeddings_mean_spawn, spawned_x_emb_mean))
return sim_scores
@autocast('cuda')
def forward(self, x, mode='training'):
for k in x.keys():
x[k] = x[k].cuda(device=deepspeed.comm.get_local_rank())
if self.find_batch:
x['attention_mask'] = x['attention_mask'].new_ones(x['attention_mask'].shape)
if mode == 'training':
if self.config.training.skip_retrieval.__class__ is bool and self.config.training.skip_retrieval:
sim_scores = None
else:
sim_scores = self.retrieve_based_on_input_x(x, self.K)
# get pt embeddings
_outputs = [_model(**x) for _model in self.models]
_logits = torch.stack([_output['logits'] for _output in _outputs])
return {'logits': _logits, 'sim_scores': sim_scores}
else:
raise NotImplementedError('validation and testing not implemented')
def on_train_start(self) -> None:
self.print_llm_trainable_parameters()
deepspeed.zero.Init()
@staticmethod
def training_step(model, batch, left_tokenizer, right_tokenizer, config, mode='normal',
task_type=TASK_TYPE.GENERATE_RESPONSE, training_process=0.0):
assert mode in ['normal', 'causal']
if task_type == TASK_TYPE.GENERATE_PERSONA and mode == 'normal':
lm_input, lm_target = format_generate_persona_input(batch, left_tokenizer, right_tokenizer,
config)
elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'causal':
lm_input, lm_target = format_causal_personachat_input(batch, left_tokenizer, right_tokenizer,
config)
elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'normal':
lm_input, lm_target = format_personachat_input(batch, left_tokenizer, right_tokenizer, config)
else:
raise NotImplementedError('mode and task_type not implemented')
output = model.module(dict(lm_input))
# suppose batch=2, K=3, the logits is presented interleave:
# [0,1]
# [0,1]
# [0,1]
logits = output['logits'] # (K*Batch,SeqLen,VocabSize)
logits = logits.view(-1, logits.shape[2], logits.shape[3])
sim_scores = output['sim_scores']
batch_size = lm_target.size()[0]
if config.model.peft_type == 'prompt_tuning':
virtual_tokens = config.model.peft_config.num_virtual_tokens
_lm_target = torch.cat(
(lm_target.new_ones((batch_size, virtual_tokens)) * left_tokenizer.pad_token_id, lm_target), dim=1)
else:
_lm_target = lm_target
_lm_target_spawn = _lm_target.repeat(config.model.K, 1)
losses = nn.CrossEntropyLoss(ignore_index=left_tokenizer.pad_token_id, reduction='none')(
logits.view(-1, logits.shape[-1]),
_lm_target_spawn.cuda(device=deepspeed.comm.get_local_rank()).view(-1))
if config.training.only_nll.__class__ == bool and config.training.only_nll:
return losses[losses != 0].mean()
reshaped_losses = losses.view(logits.shape[0], logits.shape[1]).detach().clone()
reshaped_losses = torch.stack([_losses[_losses != 0].mean() for _losses in reshaped_losses.detach().clone()])
# reshaped_losses = reshaped_losses.clone().detach().mean(dim=1)
softmaxed_neg_losses = nn.Softmax(dim=0)(
-reshaped_losses.view(config.model.K, batch_size) / config.training.tau_gold).permute(1, 0)
if config.training.adding_noise.__class__ is float:
noise = torch.randn_like(softmaxed_neg_losses, device=softmaxed_neg_losses.device)
softmaxed_neg_losses = softmaxed_neg_losses + config.training.adding_noise * noise
logsoftmaxed_sim_scores = F.log_softmax(sim_scores.view(config.model.K, batch_size) / config.training.tau_sim,
dim=0).permute(1, 0)
kldiv_loss = nn.KLDivLoss(reduction='batchmean')(logsoftmaxed_sim_scores,
softmaxed_neg_losses)
selective_loss_weight = 1.0
if config.training.annealing_nll.__class__ is bool and config.training.annealing_nll:
_ann_scalar = config.training.annealing_scalar * (1 - training_process)
_sim_score = torch.clamp(_ann_scalar * nn.Softmax(-1)(sim_scores),
config.training.annealing_min, config.training.annealing_max).detach()
losses = torch.einsum('ab,a->ab', losses.view(logits.shape[0], logits.shape[1]), _sim_score).view(-1)
if config.training.selective_loss_weight.__class__ == float:
selective_loss_weight = config.training.selective_loss_weight
if config.training.selective_loss.__class__ == bool and (config.training.selective_loss == False):
loss = losses[losses != 0].mean()
elif config.training.disable_nll.__class__ is bool and config.training.disable_nll:
loss = selective_loss_weight * kldiv_loss
else:
loss = losses[losses != 0].mean() + selective_loss_weight * kldiv_loss
if model.module.ensemble_training:
K = config.model.K
enb_losses = []
for data_idx in range(batch_size):
data_indices = [data_idx + (batch_size * inc) for inc in range(K)]
ensemble_preds = logits[data_indices, :, :]
ensemble_sims = sim_scores[data_indices]
normed_preds = ensemble_sims.unsqueeze(-1).unsqueeze(-1).mul(ensemble_preds)
normed_preds = normed_preds.sum(dim=0)
_target = _lm_target_spawn[data_indices, :]
assert _target.unique(dim=0).shape[0] == 1, 'error in resemble the preds'
enb_loss = nn.CrossEntropyLoss(ignore_index=left_tokenizer.pad_token_id)(normed_preds,
_target[0].cuda(
device=deepspeed.comm.get_local_rank()))
enb_losses.append(enb_loss)
loss += torch.stack(enb_losses).mean()
if model.module.contrastive_metric:
ctr_losses = []
from sacrebleu import BLEU
ctr_metrics = BLEU(effective_order=True)
batch_persona = [' '.join(row) for row in batch['persona_list']]
statics = {}
# Dim here
# x1 x2
# p1 s11 s21
# p2 s12 s22
# p3 s13 s23
permuted_sim_scores = sim_scores.unsqueeze(0).view(model.module.K, batch_size)
if model.module.contrastive_metric == 'bleu':
for idx in range(len(batch_persona) - 1):
for jdx in range(idx + 1, len(batch_persona)):
iele = batch_persona[idx]
jele = batch_persona[jdx]
scores = ctr_metrics.sentence_score(iele, [jele]).score
idist = permuted_sim_scores[:, idx]
jdist = permuted_sim_scores[:, jdx]
cosine_emb_loss = nn.CosineEmbeddingLoss()
if scores > model.module.contrastive_threshold:
cosine_target = 1
else:
cosine_target = -1
cos_loss = cosine_emb_loss(idist, jdist, torch.tensor(cosine_target))
ctr_losses.append(cos_loss)
statics[(idx, jdx)] = {'iele': iele, 'jele': jele, 'scores': scores,
'idist': idist,
'jdist': jdist, 'cos_emb_loss': cos_loss}
if len(ctr_losses) != 0:
ctr_losses_pt = torch.stack(ctr_losses).mean()
loss += config.training.contrastive_weight * ctr_losses_pt
else:
print(f'CTR ERROR: {statics}')
return loss
@staticmethod
def validation_step(model, batch, left_tokenizer, right_tokenizer, config, task_type, mode='normal'):
loss = SelectLLMChat.training_step(model, batch, left_tokenizer, right_tokenizer, config, task_type=task_type,
mode=mode, training_process=0.0)
return loss
@staticmethod
@autocast('cuda')
def test_step(model, batch, left_tokenizer, right_tokenizer, config, max_new_tokens=16, tqdm_instance: tqdm = None,
selection_noise=None, **gen_kwargs):
model.eval()
with torch.no_grad():
if config.training.mode == 'causal':
lm_input, lm_target, inference_tokenized = format_causal_personachat_input(batch,
left_tokenizer,
right_tokenizer,
config,
for_test=True)
else:
lm_input, lm_target, inference_tokenized = format_personachat_input(batch, left_tokenizer,
right_tokenizer,
config,
for_test=True)
inference_tokenized.to('cuda')
if 'deepspeed' in str(model.__class__):
batch_size = inference_tokenized['input_ids'].shape[0]
sim_scores = model.module.retrieve_based_on_input_x(inference_tokenized, config.model.K)
sim_scores = sim_scores.reshape(config.model.K, batch_size).permute(1, 0)
if selection_noise:
noise = torch.randn_like(sim_scores, device=sim_scores.device)
sim_scores = sim_scores + selection_noise * noise
selected_prompts = torch.argmax(sim_scores, dim=1)
if tqdm_instance is not None:
tqdm_instance.set_postfix_str(f"selected prompts: {selected_prompts}")
detached_selected_prompts = selected_prompts.detach().cpu().numpy()
selected_prompts_set = set(detached_selected_prompts)
output_dicts = {}
# adding do_sample=False to avoid inf error!
for key in selected_prompts_set:
outputs = model.module.models[key].generate(
input_ids=inference_tokenized['input_ids'],
attention_mask=inference_tokenized['attention_mask'],
max_new_tokens=max_new_tokens,
do_sample=False,
**gen_kwargs
)
output_dicts[key] = outputs.detach().cpu()
raw_output = []
for idx, prompt_idx in enumerate(detached_selected_prompts):
raw_output.append(output_dicts[prompt_idx][idx][inference_tokenized['input_ids'].shape[1]:])
# raw_output = torch.stack(raw_output).squeeze(1)
trunc_output = raw_output
text_output = right_tokenizer.batch_decode(trunc_output, skip_special_tokens=True)
return trunc_output, text_output, selected_prompts
else:
raise NotImplementedError('not implemented')