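"""Autoregressive (AR) stage of a VALL-E-style TTS model.

A causal LLaMA language model (see modeling_llama) is trained on a single
token stream of [phone tokens ; acoustic target tokens] and, at inference
time, continues an acoustic prompt conditioned on the phone sequence.
"""
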
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel
|
|
class ValleAR(nn.Module):
    def __init__(
        self,
        phone_vocab_size=256,
        target_vocab_size=1024,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=12,
        num_attention_heads=16,
        pad_token_id=1281,
        bos_target_id=1282,
        eos_target_id=1283,
        bos_phone_id=1284,
        eos_phone_id=1285,
        use_input_embeds=False,
        emb_dim=256,
        **kwargs,
    ):
        super(ValleAR, self).__init__()
        # Acoustic target tokens occupy [0, target_vocab_size); phone tokens are
        # shifted up by target_vocab_size (see add_phone_eos_bos_label). The
        # extra 10 ids leave room for the special tokens (pad, bos/eos) above.
        self.config = LlamaConfig(
            vocab_size=phone_vocab_size + target_vocab_size + 10,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            pad_token_id=pad_token_id,
            bos_token_id=bos_target_id,
            eos_token_id=eos_target_id,
        )
        self.phone_vocab_size = phone_vocab_size
        self.target_vocab_size = target_vocab_size
        self.pad_token_id = pad_token_id
        self.bos_target_id = bos_target_id
        self.eos_target_id = eos_target_id
        self.bos_phone_id = bos_phone_id
        self.eos_phone_id = eos_phone_id
        self.model = LlamaForCausalLM(self.config)

        self.use_input_embeds = use_input_embeds

        # Optional projection for externally supplied continuous embeddings
        # that can be prepended to the token sequence.
        if self.use_input_embeds:
            self.emb_linear = nn.Linear(emb_dim, hidden_size)
            self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
            self.emb_linear.bias.data.zero_()
|
    def forward(
        self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None
    ):
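        """Teacher-forced training step.

        phone_ids/phone_mask and target_ids/target_mask are [B, T] padded
        batches. Both streams are wrapped with bos/eos, concatenated into one
        causal sequence, and fed to the LM; the returned output also carries
        top1/top5/top10 next-token accuracy over the target span.
        """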
|
        if input_embeds is not None:
            input_embeds = self.emb_linear(input_embeds)
        phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
            phone_ids,
            phone_mask,
            self.eos_phone_id,
            self.bos_phone_id,
            self.pad_token_id,
        )
        target_ids, target_mask, target_label = self.add_target_eos_bos_label(
            target_ids,
            target_mask,
            self.eos_target_id,
            self.bos_target_id,
            self.pad_token_id,
        )
        # Concatenate phones and acoustic targets into one causal sequence.
        input_token_ids = torch.cat([phone_ids, target_ids], dim=-1)
        attention_mask = torch.cat([phone_mask, target_mask], dim=-1)

        if input_embeds is not None:
            # The prepended-embedding path is not supported yet; the code
            # below each raise only sketches the intended behaviour.
            raise NotImplementedError
            attention_mask = torch.cat(
                [
                    torch.ones(
                        (input_embeds.shape[0], input_embeds.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ],
                dim=-1,
            )
        # Loss is computed only on the target span; phone_label is all -100.
        labels = torch.cat([phone_label, target_label], dim=-1)
        if input_embeds is not None:
            raise NotImplementedError
            labels = torch.cat(
                [
                    -100
                    * torch.ones(
                        (input_embeds.shape[0], input_embeds.shape[1]),
                        dtype=labels.dtype,
                        device=labels.device,
                    ),
                    labels,
                ],
                dim=-1,
            )

        if input_embeds is not None:
            raise NotImplementedError
            inputs_embeds = torch.cat(
                [input_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
            )
            out = self.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
            )
            return out

        out = self.model(
            input_token_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

        # Token-level top-k accuracy over the (shifted) target positions.
        logits = out.logits
        logits = logits[:, -target_ids.shape[1] :]
        top1_acc = logits.argmax(-1)[..., :-1] == target_ids[:, 1:]
        top1_acc = (top1_acc * target_mask[..., :-1]).sum() / target_mask.sum()

        top5_acc = torch.topk(logits[..., :-1, :], 5, dim=-1)[1]
        top5_acc = top5_acc == target_ids[:, 1:].unsqueeze(-1)
        top5_acc = (
            top5_acc * target_mask[..., :-1].unsqueeze(-1)
        ).sum() / target_mask.sum()

        top10_acc = torch.topk(logits[..., :-1, :], 10, dim=-1)[1]
        top10_acc = top10_acc == target_ids[:, 1:].unsqueeze(-1)
        top10_acc = (
            top10_acc * target_mask[..., :-1].unsqueeze(-1)
        ).sum() / target_mask.sum()

        out.top1_acc = top1_acc
        out.top5_acc = top5_acc
        out.top10_acc = top10_acc

        return out
|
    def add_phone_eos_bos_label(
        self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
    ):
        """Shift phone ids into their vocab range, append eos, prepend bos.

        phone_ids, phone_mask: [B, T], left-aligned padding. Returns ids and
        mask of shape [B, T + 2] plus an all -100 label, so no loss is
        computed on the phone prefix.
        """
        # Offset phone ids so they do not collide with acoustic token ids.
        phone_ids = phone_ids + self.target_vocab_size * phone_mask

        phone_ids = phone_ids * phone_mask
        # Mark every padded position (plus one appended slot) as eos; the mask
        # step below keeps only the first of them.
        phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
            1 - phone_mask, (0, 1), value=1
        )
        # Left-pad the mask with a 1: the valid length grows by one (the eos)
        # and the mask stays left-aligned.
        phone_mask = F.pad(phone_mask, (1, 0), value=1)
        phone_ids = phone_ids * phone_mask + pad_token_id * (
            1 - phone_mask
        )  # restore pad ids on the masked-out tail
        phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id)  # prepend bos
        phone_mask = F.pad(phone_mask, (1, 0), value=1)
        phone_label = -100 * torch.ones_like(
            phone_ids
        )  # no loss on the phone prefix
        return phone_ids, phone_mask, phone_label
|
    def add_target_eos_bos_label(
        self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
    ):
        """Append eos and prepend bos to the acoustic target sequence.

        target_ids, target_mask: [B, T], left-aligned padding. Returns ids and
        mask of shape [B, T + 2] plus labels that equal the token ids on valid
        positions and -100 on padding.
        """
        target_ids = target_ids * target_mask
        # Mark every padded position (plus one appended slot) as eos; the mask
        # step below keeps only the first of them.
        target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad(
            1 - target_mask, (0, 1), value=1
        )
        target_mask = F.pad(target_mask, (1, 0), value=1)
        target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask)
        target_ids = F.pad(target_ids, (1, 0), value=target_bos_id)  # prepend bos
        target_mask = F.pad(target_mask, (1, 0), value=1)
        target_label = target_ids * target_mask + (-100) * (
            1 - target_mask
        )  # compute loss only on valid target positions
        return target_ids, target_mask, target_label
|
    def sample_hf(
        self,
        phone_ids,
        prompt_ids,
        inputs_embeds=None,
        max_length=2000,
        temperature=1.0,
        top_k=100,
        top_p=0.9,
        repeat_penalty=1.0,
        num_beams=1,
    ):
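        """Autoregressively generate acoustic tokens for `phone_ids`,
        continuing the acoustic prompt `prompt_ids`, using the backbone's
        sampling-based `generate()` (temperature / top-k / top-p / repetition
        penalty). Returns only the newly generated tokens, with the trailing
        eos removed.
        """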
|
        if inputs_embeds is not None:
            inputs_embeds = self.emb_linear(inputs_embeds)
        phone_mask = torch.ones_like(phone_ids)
        prompt_mask = torch.ones_like(prompt_ids)
        phone_ids, _, _ = self.add_phone_eos_bos_label(
            phone_ids,
            phone_mask,
            self.eos_phone_id,
            self.bos_phone_id,
            self.pad_token_id,
        )
        prompt_ids, _, _ = self.add_target_eos_bos_label(
            prompt_ids,
            prompt_mask,
            self.eos_target_id,
            self.bos_target_id,
            self.pad_token_id,
        )
        prompt_ids = prompt_ids[:, :-1]  # drop the eos so generation continues

        input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)

        if inputs_embeds is not None:
            # Prepended-embedding path is not supported yet; the code below
            # the raise only sketches the intended behaviour.
            raise NotImplementedError
            inputs_embeds = torch.cat(
                [inputs_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
            )
            generated_ids = self.model.generate(
                inputs_embeds=inputs_embeds,
                do_sample=True,
                max_length=max_length,
                pad_token_id=self.pad_token_id,
                eos_token_id=self.eos_target_id,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repeat_penalty,
            )
            gen_tokens = generated_ids[:, :-1]
            return gen_tokens

        input_length = input_token_ids.shape[1]
        generated_ids = self.model.generate(
            input_token_ids,
            do_sample=True,
            max_length=max_length,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_target_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repeat_penalty,
            num_beams=num_beams,
        )

        # Strip the conditioning tokens (phones + prompt) and the trailing eos.
        gen_tokens = generated_ids[:, input_length:-1]

        return gen_tokens
|
|
def test(): |
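    """Overfit two toy sequences for a few steps, then sample from the model."""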
|
    model = ValleAR()

    phone_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6]])
    phone_mask = torch.LongTensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
    target_ids = torch.LongTensor([765, 234, 123, 234, 123, 599]).expand(2, -1)
    target_mask = torch.LongTensor([1, 1, 1, 1, 0, 0]).expand(2, -1)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    for i in range(15):
        optimizer.zero_grad()
        out = model(
            phone_ids=phone_ids,
            phone_mask=phone_mask,
            target_ids=target_ids,
            target_mask=target_mask,
        )
        loss = out.loss

        loss.backward()

        optimizer.step()

        print(f"iter={i}, loss={loss.item():.4f}")

    phone_ids = torch.LongTensor([1, 2, 3]).reshape(1, -1)
    target_ids = torch.LongTensor([765, 234]).reshape(1, -1)
    sampled = model.sample_hf(phone_ids, target_ids)
    print(f"sampled tokens: {sampled}")
|
|
if __name__ == "__main__":
    test()