# NotaGen / utils.py
import torch
import random
import bisect
import json
import re
from config import *
from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel, BitsAndBytesConfig
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
from tokenizers import Tokenizer
# Optional 8-bit quantization config (bitsandbytes) for loading the model.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["patch_embedding"]  # skip modules that may be incompatible with int8
)
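# Usage sketch (illustrative; the checkpoint path is an assumption, not part of
# this file): transformers applies such a config when loading pretrained
# weights, e.g.
#   model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint",
#                                                quantization_config=quantization_config)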
class Patchilizer:
def __init__(self, stream=PATCH_STREAM):
self.stream = stream
self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
self.bos_token_id = 1
self.eos_token_id = 2
self.special_token_id = 0
def split_bars(self, body_lines):
"""
Split a body of music into individual bars.
"""
new_bars = []
try:
for line in body_lines:
line_bars = re.split(self.regexPattern, line)
line_bars = list(filter(None, line_bars))
new_line_bars = []
if len(line_bars) == 1:
new_line_bars = line_bars
else:
if line_bars[0] in self.delimiters:
new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
else:
new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        new_line_bars[-2] += new_line_bars[-1]  # absorb the trailing barline + '\n' fragment into the previous bar
                        new_line_bars = new_line_bars[:-1]
new_bars += new_line_bars
        except Exception:
            # Skip malformed lines rather than aborting the whole split.
            pass
return new_bars
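    # Example (illustrative): for the line "abc|def|gab|]\n", re.split keeps
    # the delimiters: ['abc', '|', 'def', '|', 'gab', '|]', '\n']. Pairing
    # each delimiter with its following content gives
    # ['abc', '|def', '|gab', '|]\n']; the trailing "|]\n" fragment has no
    # voice marker 'V', so it is absorbed into the previous bar, yielding
    # ['abc', '|def', '|gab|]\n'].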
    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """
        Split text into fixed-size patches; a partial final patch is closed
        with an eos char unless generate_last is set.
        """
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]
        return patches
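    # Example (illustrative, patch_size=4): "abcdefg" becomes
    # ['abcd', 'efg\x02'] (an eos char closes the partial patch), while
    # generate_last=True leaves it open: ['abcd', 'efg'].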
    def patch2chars(self, patch):
        """
        Convert a patch (a sequence of token ids) back into its character
        string, stopping at the first eos token.
        """
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars
def patchilize_metadata(self, metadata_lines):
metadata_patches = []
for line in metadata_lines:
metadata_patches += self.split_patches(line)
return metadata_patches
def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
tunebody_patches = []
bars = self.split_bars(tunebody_lines)
if encode_mode == 'train':
for bar in bars:
tunebody_patches += self.split_patches(bar)
elif encode_mode == 'generate':
for bar in bars[:-1]:
tunebody_patches += self.split_patches(bar)
tunebody_patches += self.split_patches(bars[-1], generate_last=True)
return tunebody_patches
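    # In 'generate' mode the final bar is split with generate_last=True, so its
    # trailing partial patch is left without an eos char and the model can
    # continue writing into it; 'train' mode closes such partial patches.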
def encode(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
lines = abc_text.split('\n')
lines = list(filter(None, lines))
lines = [line + '\n' for line in lines]
tunebody_index = -1
for i, line in enumerate(lines):
if line.startswith('[r:'):
tunebody_index = i
break
metadata_lines = lines[: tunebody_index]
tunebody_lines = lines[tunebody_index:]
metadata_patches = self.patchilize_metadata(metadata_lines)
tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')
if add_special_patches:
bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)
metadata_patches = [bos_patch] + metadata_patches
tunebody_patches = tunebody_patches + [eos_patch]
if self.stream:
if len(metadata_patches) + len(tunebody_patches) > patch_length:
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))  # which tunebody line each cut index corresponds to
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)  # biggest_index sits one position to the right of end_index
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]
if len(available_cut_indexes) == 1:
choices = ['head']
elif len(available_cut_indexes) == 2:
choices = ['head', 'tail']
else:
choices = ['head', 'tail', 'middle']
choice = random.choice(choices)
                if choice == 'head':
                    patches = metadata_patches + tunebody_patches
else:
if choice == 'tail':
cut_index = len(available_cut_indexes) - 1
else:
cut_index = random.choice(range(1, len(available_cut_indexes) - 1))
line_index = line_index_for_cut_index[cut_index]
stream_tunebody_lines = tunebody_lines[line_index:]
stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
if add_special_patches:
stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
patches = metadata_patches + stream_tunebody_patches
else:
patches = metadata_patches + tunebody_patches
else:
patches = metadata_patches + tunebody_patches
patches = patches[: patch_length]
# encode to ids
id_patches = []
for patch in patches:
id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
id_patches.append(id_patch)
return id_patches
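    # Usage sketch (illustrative): each returned entry is one patch of
    # patch_size character ids, right-padded with the special id 0, e.g.
    #   patchilizer = Patchilizer()
    #   ids = patchilizer.encode(abc_text)    # [n_patches][PATCH_SIZE]
    #   batch = torch.tensor([ids])           # [1, n_patches, PATCH_SIZE]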
def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
lines = abc_code.split('\n')
lines = list(filter(None, lines))
tunebody_index = None
for i, line in enumerate(lines):
if line.startswith('[V:') or line.startswith('[r:'):
tunebody_index = i
break
        metadata_lines = lines[:tunebody_index]
        tunebody_lines = lines[tunebody_index:]  # keep the tunebody lines before any truncation
metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            if not abc_code.endswith('\n'):  # the last generated line is still unfinished
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
else:
tunebody_lines = [line + '\n' for line in tunebody_lines]
metadata_patches = self.patchilize_metadata(metadata_lines)
tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')
if add_special_patches:
bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
metadata_patches = [bos_patch] + metadata_patches
patches = metadata_patches + tunebody_patches
patches = patches[ : patch_length]
# encode to ids
id_patches = []
        for patch in patches:
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                # an unfinished trailing patch is left unpadded so generation can continue it
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)
return id_patches
def decode(self, patches):
"""
Decode patches into music.
"""
return ''.join(self.patch2chars(patch) for patch in patches)
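    # Example (illustrative): decoding stops at the eos id (2) inside each
    # patch, so patchilizer.decode([[104, 105, 2, 0]]) returns 'hi'.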
class PatchLevelDecoder(PreTrainedModel):
"""
A Patch-level Decoder model for generating patch features in an auto-regressive manner.
It inherits PreTrainedModel from transformers.
"""
def __init__(self, config):
super().__init__(config)
self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd).to(torch.float16)
torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
self.base = GPT2Model(config)
def forward(self,
patches: torch.Tensor,
masks=None) -> torch.Tensor:
"""
The forward pass of the patch-level decoder model.
:param patches: the patches to be encoded
:param masks: the masks for the patches
:return: the encoded patches
"""
        # One-hot over the 128-char vocabulary, then flatten each patch into a
        # single vector before projecting it to the hidden size.
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))
        if masks is None:
            return self.base(inputs_embeds=patches)
        else:
            return self.base(inputs_embeds=patches,
                             attention_mask=masks)
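    # Shape sketch (illustrative): input patches [B, S, PATCH_SIZE] are one-hot
    # encoded and flattened to [B, S, PATCH_SIZE * 128], projected by
    # patch_embedding to [B, S, n_embd], and the GPT-2 base returns a
    # last_hidden_state of the same [B, S, n_embd] shape.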
class CharLevelDecoder(PreTrainedModel):
"""
A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
based on the encoded patch features. It inherits PreTrainedModel from transformers.
"""
def __init__(self, config):
super().__init__(config)
self.special_token_id = 0
self.bos_token_id = 1
self.base = GPT2LMHeadModel(config)
def forward(self,
encoded_patches: torch.Tensor,
target_patches: torch.Tensor):
"""
The forward pass of the char-level decoder model.
:param encoded_patches: the encoded patches
:param target_patches: the target patches
:return: the output of the model
"""
target_patches = torch.cat((torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id,
target_patches), dim=1) # [patch_len, patch_size + 1]
target_masks = target_patches == self.special_token_id # [patch_len, patch_size + 1]
labels = target_patches.clone().masked_fill_(target_masks, -100)
target_masks = torch.ones_like(labels)
target_masks = target_masks.masked_fill_(labels == -100, 0)
input_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
input_embeds = torch.cat((encoded_patches.unsqueeze(1), input_embeds[:, 1:, :]), dim=1)
logits = self.base(inputs_embeds=input_embeds,
attention_mask=target_masks).logits # [patch_len, patch_size + 1, vocab_size]
logits = logits[:, :-1, :]
token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=target_patches[:, 1:].unsqueeze(-1)).squeeze(-1) # [patch_len, patch_size]
token_logps = token_logps[target_masks[:, 1:] == 1]
all_logps = token_logps.sum()
return all_logps
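    # Note (on forward above): it returns the *sum* of per-token
    # log-probabilities over all non-padding positions, i.e. a sequence-level
    # log-likelihood rather than a mean cross-entropy loss.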
def generate(self,
encoded_patch: torch.Tensor, # [hidden_size]
tokens: torch.Tensor): # [1]
"""
The generate function for generating a patch based on the encoded patch and already generated tokens.
:param encoded_patch: the encoded patch
:param tokens: already generated tokens in the patch
:return: the probability distribution of next token
"""
encoded_patch = encoded_patch.reshape(1, 1, -1) # [1, 1, hidden_size]
tokens = tokens.reshape(1, -1)
# Get input embeddings
tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
        # Replace the bos embedding with the encoded patch feature
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
# Get output from model
outputs = self.base(inputs_embeds=tokens)
# Get probabilities of next token
probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
return probs
class NotaGenLMHeadModel(PreTrainedModel):
"""
NotaGen is a language model with a hierarchical structure.
It includes a patch-level decoder and a char-level decoder.
The patch-level decoder is used to generate patch features in an auto-regressive manner.
The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
It inherits PreTrainedModel from transformers.
"""
def __init__(self, encoder_config, decoder_config):
super().__init__(encoder_config)
self.special_token_id = 0
self.bos_token_id = 1
self.eos_token_id = 2
self.patch_level_decoder = PatchLevelDecoder(encoder_config)
self.char_level_decoder = CharLevelDecoder(decoder_config)
def forward(self,
patches: torch.Tensor,
masks: torch.Tensor):
"""
The forward pass of the bGPT model.
:param patches: the patches to be encoded
:param masks: the masks for the patches
:return: the decoded patches
"""
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]
        # left_shift_masks marks every valid patch that still has a valid
        # successor (these positions provide the context features); zeroing the
        # first-position mask leaves the patches that serve as targets.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0
        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]
return self.char_level_decoder(encoded_patches, patches)
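    # Flow sketch (illustrative): patches [B, S * PATCH_SIZE] are reshaped to
    # [B, S, PATCH_SIZE], encoded to patch features [B, S, hidden], and the
    # char-level decoder scores each target patch conditioned on the feature
    # of the patch before it.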
def generate(self,
patches: torch.Tensor,
top_k=0,
top_p=1,
temperature=1.0):
"""
The generate function for generating patches based on patches.
:param patches: the patches to be encoded
:param top_k: the top k for sampling
:param top_p: the top p for sampling
:param temperature: the temperature for sampling
:return: the generated patches
"""
        if patches.shape[-1] % PATCH_SIZE != 0:
            # A trailing partial patch becomes the char-level prompt for generation
            tokens = patches[:, :, -(patches.shape[-1] % PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-(patches.shape[-1] % PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)
patches = patches.reshape(len(patches), -1, PATCH_SIZE) # [bs, seq, patch_size]
encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"] # [bs, seq, hidden_size]
generated_patch = []
while True:
prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy() # [128]
prob = top_k_sampling(prob, top_k=top_k, return_probs=True) # [128]
prob = top_p_sampling(prob, top_p=top_p, return_probs=True) # [128]
            token = temperature_sampling(prob, temperature=temperature)  # int
            generated_patch.append(token)
            if len(tokens) >= PATCH_SIZE:  # or token == self.eos_token_id:
                break
            else:
                tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)
return generated_patch
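# End-to-end sketch (illustrative; the prompt variable, sampling values, and
# the constructed `model` are assumptions, not part of this file):
#   patchilizer = Patchilizer()
#   ids = patchilizer.encode_generate(prompt_abc)      # [n_patches][<=PATCH_SIZE]
#   flat = [i for patch in ids for i in patch]
#   input_patches = torch.tensor([[flat]], device=model.device)  # [1, 1, N]
#   next_patch = model.generate(input_patches, top_p=0.9, temperature=1.2)
#   text = ''.join(chr(t) for t in next_patch if t > 2)          # drop special ids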