"""Model, loss, and data utilities for CLaMP 2 and M3 training."""
import re
import os
import math
import torch
import random
from config import *
from unidecode import unidecode
from torch.nn import functional as F
from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config

try:
    import torch.distributed.nn
    from torch import distributed as dist
    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


class ClipLoss(torch.nn.Module):
    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def gather_features(
        self,
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
    ):
        assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
        if use_horovod:
            assert hvd is not None, 'Please install horovod'
            if gather_with_grad:
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            else:
                with torch.no_grad():
                    all_image_features = hvd.allgather(image_features)
                    all_text_features = hvd.allgather(text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                    gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                    all_image_features = torch.cat(gathered_image_features, dim=0)
                    all_text_features = torch.cat(gathered_text_features, dim=0)
        else:
            # We gather tensors from all gpus
            if gather_with_grad:
                all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
                all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            else:
                gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
                gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
                dist.all_gather(gathered_image_features, image_features)
                dist.all_gather(gathered_text_features, text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)

        return all_image_features, all_text_features

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = self.gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
        labels = self.get_ground_truth(device, logits_per_image.shape[0])
        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2
        return {"contrastive_loss": total_loss} if output_dict else total_loss
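
# Usage sketch (illustrative, not part of the training code): ClipLoss computes
# a symmetric InfoNCE loss between two sets of features. In this codebase the
# two modalities are text and music rather than image and text; the random
# features and batch size below are placeholders.
#
#   loss_fn = ClipLoss(rank=0, world_size=1)
#   text_feats = torch.randn(8, CLAMP2_HIDDEN_SIZE)   # [batch, hidden]
#   music_feats = torch.randn(8, CLAMP2_HIDDEN_SIZE)  # [batch, hidden]
#   loss = loss_fn(text_feats, music_feats, LOGIT_SCALE)
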
class M3Patchilizer:
    def __init__(self):
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def split_bars(self, body):
        bars = re.split(self.regexPattern, ''.join(body))
        bars = list(filter(None, bars))  # remove empty strings
        if bars[0] in self.delimiters:
            # a leading barline belongs to the first bar
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)

    def encode(self, item, patch_size=PATCH_SIZE, add_special_patches=False, truncate=False, random_truncate=False):
        item = unidecode(item)
        lines = re.findall(r'.*?\n|.*$', item)
        lines = list(filter(None, lines))  # remove empty lines

        patches = []

        if lines[0].split(" ")[0] == "ticks_per_beat":
            # MTF input: merge consecutive lines of the same message type into one patch
            patch = ""
            for line in lines:
                if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size - 2):
                    patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
                else:
                    if patch:
                        patches.append(patch)
                    patch = line
            if patch != "":
                patches.append(patch)
        else:
            # ABC input: information fields become whole patches, tune bodies are split into bars
            for line in lines:
                if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
                    patches.append(line)
                else:
                    bars = self.split_bars(line)
                    if bars:
                        bars[-1] += '\n'
                        patches.extend(bars)

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * patch_size
            eos_patch = chr(self.eos_token_id) * patch_size
            patches = [bos_patch] + patches + [eos_patch]

        if len(patches) > PATCH_LENGTH and truncate:
            choices = ["head", "tail", "middle"]
            choice = random.choice(choices)
            if choice == "head" or not random_truncate:
                patches = patches[:PATCH_LENGTH]
            elif choice == "tail":
                patches = patches[-PATCH_LENGTH:]
            else:
                start = random.randint(1, len(patches) - PATCH_LENGTH)
                patches = patches[start:start + PATCH_LENGTH]

        patches = [self.bar2patch(patch) for patch in patches]

        return patches

    def decode(self, patches):
        return ''.join(self.patch2bar(patch) for patch in patches)
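
# Usage sketch (illustrative): round-tripping a short ABC fragment. The tune is
# a made-up example; PATCH_SIZE and PATCH_LENGTH come from config.py. Header
# lines (e.g. "K:C") become whole patches, the body is split into bars, and
# each patch is a fixed-length list of character token ids.
#
#   patchilizer = M3Patchilizer()
#   abc = "X:1\nL:1/8\nK:C\nCDEF GABc | cBAG FEDC |\n"
#   patches = patchilizer.encode(abc, add_special_patches=True)
#   print(len(patches), len(patches[0]))  # num_patches, PATCH_SIZE
#   print(patchilizer.decode(patches))    # recovers the text (special ids stripped)
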
class M3PatchEncoder(PreTrainedModel):
    def __init__(self, config):
        super(M3PatchEncoder, self).__init__(config)
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, M3_HIDDEN_SIZE)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = BertModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches,  # [batch_size, seq_length, patch_size]
                input_masks):   # [batch_size, seq_length]
        # Transform input_patches into embeddings
        input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE * 128).type(torch.FloatTensor)
        input_patches = self.patch_embedding(input_patches.to(self.device))

        # Apply BERT model to input_patches and input_masks
        return self.base(inputs_embeds=input_patches, attention_mask=input_masks)


class M3TokenDecoder(PreTrainedModel):
    def __init__(self, config):
        super(M3TokenDecoder, self).__init__(config)
        self.base = GPT2LMHeadModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                patch_features,   # [batch_size, hidden_size]
                target_patches):  # [batch_size, seq_length]
        # get input embeddings
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings
        inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1)

        # prepare the labels for model training
        target_masks = target_patches == self.pad_token_id
        target_patches = target_patches.clone().masked_fill_(target_masks, -100)

        # get the attention mask
        target_masks = ~target_masks
        target_masks = target_masks.type(torch.int)

        return self.base(inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=target_patches)

    def generate(self, patch_feature, tokens):
        # reshape the patch_feature and tokens
        patch_feature = patch_feature.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)

        # get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings
        tokens = torch.cat((patch_feature, tokens[:, 1:, :]), dim=1)

        # get the outputs from the model
        outputs = self.base(inputs_embeds=tokens)

        # get the probabilities of the next token
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs.detach().cpu().numpy()


class M3Model(PreTrainedModel):
    def __init__(self, encoder_config, decoder_config):
        super(M3Model, self).__init__(encoder_config)
        self.encoder = M3PatchEncoder(encoder_config)
        self.decoder = M3TokenDecoder(decoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches,     # [batch_size, seq_length, patch_size]
                input_masks,       # [batch_size, seq_length]
                selected_indices,  # [batch_size, seq_length]
                target_patches):   # [batch_size, seq_length, patch_size]
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
        input_masks = input_masks.to(self.device)
        selected_indices = selected_indices.to(self.device)
        target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)

        # Pass the input_patches and input_masks through the encoder
        outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]

        # Use selected_indices to pick the corrupted positions and their targets
        target_patches = target_patches[selected_indices.bool()]
        patch_features = outputs[selected_indices.bool()]

        # Pass patch_features and target_patches through the decoder
        return self.decoder(patch_features, target_patches)
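
# Shape walkthrough (illustrative): wiring the encoder/decoder pair for one M3
# masked-modeling step. The BertConfig values, in particular PATCH_NUM_LAYERS,
# are assumptions based on the config.py names used elsewhere in this file.
#
#   from transformers import BertConfig
#   encoder_config = BertConfig(vocab_size=1,
#                               hidden_size=M3_HIDDEN_SIZE,
#                               num_hidden_layers=PATCH_NUM_LAYERS,
#                               num_attention_heads=M3_HIDDEN_SIZE // 64,
#                               intermediate_size=M3_HIDDEN_SIZE * 4,
#                               max_position_embeddings=PATCH_LENGTH)
#   decoder_config = GPT2Config(vocab_size=128,
#                               n_positions=PATCH_SIZE,
#                               n_embd=M3_HIDDEN_SIZE,
#                               n_layer=TOKEN_NUM_LAYERS,
#                               n_head=M3_HIDDEN_SIZE // 64,
#                               n_inner=M3_HIDDEN_SIZE * 4)
#   model = M3Model(encoder_config, decoder_config)
#   # input_patches, target_patches: [batch, PATCH_LENGTH, PATCH_SIZE] token ids
#   # input_masks, selected_indices: [batch, PATCH_LENGTH]
#   # the decoder's cross-entropy loss is available as model(...).loss
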
class CLaMP2Model(PreTrainedModel):
    def __init__(self,
                 music_config,
                 global_rank=None,
                 world_size=None,
                 text_model_name=TEXT_MODEL_NAME,
                 hidden_size=CLAMP2_HIDDEN_SIZE,
                 load_m3=CLAMP2_LOAD_M3):
        super(CLaMP2Model, self).__init__(music_config)

        self.text_model = AutoModel.from_pretrained(text_model_name)  # load the text model
        self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size)  # linear layer for text projections
        torch.nn.init.normal_(self.text_proj.weight, std=0.02)  # initialize weights with normal distribution

        self.music_model = M3PatchEncoder(music_config)  # initialize the music model
        self.music_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size)  # linear layer for music projections
        torch.nn.init.normal_(self.music_proj.weight, std=0.02)  # initialize weights with normal distribution

        if global_rank is None or world_size is None:
            global_rank = 0
            world_size = 1

        self.loss_fn = ClipLoss(local_loss=False,
                                gather_with_grad=True,
                                cache_labels=False,
                                rank=global_rank,
                                world_size=world_size,
                                use_horovod=False)

        if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
            checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
            decoder_config = GPT2Config(vocab_size=128,
                                        n_positions=PATCH_SIZE,
                                        n_embd=M3_HIDDEN_SIZE,
                                        n_layer=TOKEN_NUM_LAYERS,
                                        n_head=M3_HIDDEN_SIZE // 64,
                                        n_inner=M3_HIDDEN_SIZE * 4)
            model = M3Model(music_config, decoder_config)
            model.load_state_dict(checkpoint['model'])
            self.music_model = model.encoder
            model = None
            print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")

    def avg_pooling(self, input_features, input_masks):
        input_masks = input_masks.unsqueeze(-1).to(self.device)  # add a dimension to match the feature dimension
        input_features = input_features * input_masks  # apply mask to input_features
        avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1)  # calculate average pooling
        return avg_pool

    def get_text_features(self, text_inputs, text_masks, get_normalized=False):
        text_features = self.text_model(text_inputs.to(self.device),
                                        attention_mask=text_masks.to(self.device))['last_hidden_state']
        if get_normalized:
            text_features = self.avg_pooling(text_features, text_masks)
            text_features = self.text_proj(text_features)
        return text_features

    def get_music_features(self, music_inputs, music_masks, get_normalized=False):
        music_features = self.music_model(music_inputs.to(self.device),
                                          music_masks.to(self.device))['last_hidden_state']
        if get_normalized:
            music_features = self.avg_pooling(music_features, music_masks)
            music_features = self.music_proj(music_features)
        return music_features

    def forward(self,
                text_inputs,   # [batch_size, seq_length]
                text_masks,    # [batch_size, seq_length]
                music_inputs,  # [batch_size, seq_length, patch_size]
                music_masks):  # [batch_size, seq_length]
        # Compute the text features
        text_features = self.get_text_features(text_inputs, text_masks, get_normalized=True)

        # Compute the music features
        music_features = self.get_music_features(music_inputs, music_masks, get_normalized=True)

        return self.loss_fn(text_features, music_features, LOGIT_SCALE, output_dict=False)
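
# Usage sketch (illustrative): extracting a global music embedding for
# text-music retrieval. Tokenization is simplified; in the real pipeline the
# inputs come from M3Patchilizer.encode plus a text tokenizer matching
# TEXT_MODEL_NAME, and the model would be loaded from a trained checkpoint.
#
#   model = CLaMP2Model(music_config).eval()
#   patches = torch.tensor(patchilizer.encode(abc, add_special_patches=True)).unsqueeze(0)
#   masks = torch.ones(patches.shape[:2])
#   with torch.no_grad():
#       music_emb = model.get_music_features(patches, masks, get_normalized=True)
#   # music_emb: [1, CLAMP2_HIDDEN_SIZE]; cosine similarity against embeddings
#   # from get_text_features ranks text-music matches
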
def split_data(data, eval_ratio=EVAL_SPLIT):
    random.shuffle(data)
    split_idx = int(len(data) * eval_ratio)
    eval_set = data[:split_idx]
    train_set = data[split_idx:]
    return train_set, eval_set


def mask_patches(target_patches, patchilizer, mode):
    indices = list(range(len(target_patches)))
    random.shuffle(indices)
    selected_indices = indices[:math.ceil(M3_MASK_RATIO * len(indices))]
    sorted_indices = sorted(selected_indices)

    input_patches = torch.tensor(target_patches)

    if mode == "eval":
        choice = "original"
    else:
        choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]

    if choice == "mask":
        input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id] * PATCH_SIZE)
    elif choice == "shuffle":
        for idx in sorted_indices:
            patch = input_patches[idx]
            try:
                index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
            except Exception:
                # no unique EOS token in this patch; shuffle everything after BOS
                index_eos = len(patch)
            indices = list(range(1, index_eos))
            random.shuffle(indices)
            indices = [0] + indices + list(range(index_eos, len(patch)))
            input_patches[idx] = patch[indices]

    selected_indices = torch.zeros(len(target_patches))
    selected_indices[sorted_indices] = 1.

    return input_patches, selected_indices


def remove_instrument_info(item):
    # remove instrument information from symbolic music
    lines = re.findall(r'.*?\n|.*$', item)
    lines = list(filter(None, lines))

    if lines[0].split(" ")[0] == "ticks_per_beat":
        type = "mtf"
    else:
        type = "abc"

    cleaned_lines = []
    for line in lines:
        if type == "abc" and line.startswith("V:"):
            # find the position of " nm=" or " snm="
            nm_pos = line.find(" nm=")
            snm_pos = line.find(" snm=")
            # keep the part before " nm=" or " snm="
            if nm_pos != -1:
                line = line[:nm_pos]
            elif snm_pos != -1:
                line = line[:snm_pos]
            if nm_pos != -1 or snm_pos != -1:
                line += "\n"
        elif type == "mtf" and line.startswith("program_change"):
            # reset the program number (instrument) to 0
            line = " ".join(line.split(" ")[:-1]) + " 0\n"
        cleaned_lines.append(line)

    return ''.join(cleaned_lines)
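
# Usage sketch (illustrative): preparing one M3 training example. mask_patches
# picks a M3_MASK_RATIO fraction of patches and, per example, applies one
# corruption mode (mask / shuffle-within-bar / keep, with probabilities
# 0.8 / 0.1 / 0.1); `m3` stands for a hypothetical M3Model instance.
#
#   patches = patchilizer.encode(abc, add_special_patches=True, truncate=True)
#   input_patches, selected = mask_patches(patches, patchilizer, mode="train")
#   outputs = m3(input_patches.unsqueeze(0),
#                torch.ones(1, len(patches)),
#                selected.unsqueeze(0),
#                torch.tensor(patches).unsqueeze(0))
#   loss = outputs.loss  # reconstruction loss over the selected patches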