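# NOTE: the text "Spaces: Runtime error" that preceded this file is residue
# from the hosting page it was copied from; it is not part of the program.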
import os | |
import time | |
import torch | |
import re | |
import difflib | |
from utils import * | |
from config import * | |
from transformers import GPT2Config | |
from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern | |
from abctoolkit.transpose import Note_list, Pitch_sign_list | |
from abctoolkit.duration import calculate_bartext_duration | |
import requests | |
import torch | |
from huggingface_hub import hf_hub_download | |
import logging | |
# Configure root logging once and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Extend the imported note list with the two rest symbols used in this file.
Note_list = [*Note_list, 'z', 'x']

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

patchilizer = Patchilizer()

# Patch-level (encoder) configuration.
patch_config = GPT2Config(
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
    n_embd=HIDDEN_SIZE,
    num_attention_heads=HIDDEN_SIZE // 64,
    vocab_size=1,
)

# Character-level (decoder) configuration.
byte_config = GPT2Config(
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE + 1,
    max_position_embeddings=PATCH_SIZE + 1,
    hidden_size=HIDDEN_SIZE,
    num_attention_heads=HIDDEN_SIZE // 64,
    vocab_size=128,
)

# Build the model from both configurations and move it to the chosen device.
model = NotaGenLMHeadModel(
    encoder_config=patch_config,
    decoder_config=byte_config,
).to(device)
def download_model_weights():
    """Return a local path to the model checkpoint, downloading it if needed.

    The checkpoint is looked up in the current working directory first; when
    it is not there it is fetched from the ``ElectricAlexis/NotaGen``
    repository on the Hugging Face Hub into the current working directory.

    Returns:
        Path of the checkpoint file on the local file system.

    Raises:
        RuntimeError: if the download fails. The underlying exception is
            chained as ``__cause__`` so the original traceback is preserved.
    """
    weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
    local_weights_path = os.path.join(os.getcwd(), weights_path)

    # Reuse a previously downloaded checkpoint instead of fetching it again.
    if os.path.exists(local_weights_path):
        logger.info("Model weights already exist at %s", local_weights_path)
        return local_weights_path

    logger.info("Downloading model weights from HuggingFace Hub...")
    try:
        # Only the network call lives inside the try block.
        downloaded_path = hf_hub_download(
            repo_id="ElectricAlexis/NotaGen",
            filename=weights_path,
            local_dir=os.getcwd(),
            local_dir_use_symlinks=False,
        )
    except Exception as e:
        logger.error("Error downloading model weights: %s", e)
        raise RuntimeError(f"Failed to download model weights: {str(e)}") from e

    logger.info("Model weights downloaded successfully to %s", downloaded_path)
    return downloaded_path
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    """
    Cast a model to FP16 and optionally switch on gradient checkpointing.

    Steps, in order:
    1. Cast the model to ``torch.float16``.
    2. Turn off ``requires_grad`` for every parameter that is still
       ``torch.float32`` after the cast.
    3. Call ``model.gradient_checkpointing_enable()`` when
       ``use_gradient_checkpointing`` is true.

    Returns the model object produced by the cast.
    """
    half_model = model.to(dtype=torch.float16)

    # Freeze whatever the cast left in full precision.
    for weight in half_model.parameters():
        if weight.dtype == torch.float32:
            weight.requires_grad_(False)

    if use_gradient_checkpointing:
        half_model.gradient_checkpointing_enable()

    return half_model
# Cast the freshly built model to FP16; gradient checkpointing stays off.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

print("Parameter Number: " + str(sum(
    p.numel() for p in model.parameters() if p.requires_grad
)))

# Fetch the checkpoint at import time (or reuse the local copy).
model_weights_path = download_model_weights()

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load(model_weights_path, map_location=device)

# strict=False: keys that do not match between checkpoint and model are
# tolerated rather than raising.
model.load_state_dict(checkpoint['model'], strict=False)

model = model.to(device)
model.eval()
def postprocess_inst_names(abc_text):
    """Replace non-standard instrument names in ``V:`` lines of an ABC text.

    Reads two files from the current working directory on every call:
    ``standard_inst_names.txt`` (one accepted name per line) and
    ``instrument_mapping.json`` (a JSON object mapping known variants to the
    name that should replace them).

    For each line that starts with ``V:`` and carries an ``nm="..."`` field,
    a name that is not already in the standard list is replaced by the
    mapping value of its closest mapping key (``difflib`` similarity of at
    least 0.6). Names with no sufficiently close key are left untouched.

    Args:
        abc_text: ABC notation as a single string.

    Returns:
        The processed text. Empty lines are dropped and every remaining line
        ends with a newline.
    """
    # Imported here because the file does not import json at module level;
    # relying on a wildcard import to provide it is fragile.
    import json

    with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
        standard_instruments = {line.strip() for line in f if line.strip()}
    with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
        instrument_mapping = json.load(f)
    mapping_keys = list(instrument_mapping)

    abc_lines = [line + '\n' for line in abc_text.split('\n') if line]

    for i, line in enumerate(abc_lines):
        if not (line.startswith('V:') and 'nm=' in line):
            continue
        match = re.search(r'nm="([^"]*)"', line)
        if match is None:
            continue
        inst_name = match.group(1)
        # Already a standard name: nothing to do.
        if inst_name in standard_instruments:
            continue
        # Pick the single most similar known variant, if any is close enough.
        matching_key = difflib.get_close_matches(inst_name, mapping_keys, n=1, cutoff=0.6)
        if matching_key:
            replacement = instrument_mapping[matching_key[0]]
            abc_lines[i] = line.replace(f'nm="{inst_name}"', f'nm="{replacement}"')

    return ''.join(abc_lines)
def complete_brackets(s):
    """Append the closing brackets that ``s`` is missing.

    Tracks ``{``, ``[`` and ``(`` on a stack. A closing bracket pops the
    stack only when it matches the most recent opener; any other closing
    bracket is ignored. Whatever is still open at the end is closed in
    reverse order and appended to ``s``.
    """
    closer_of = {'{': '}', '[': ']', '(': ')'}
    opener_of = {closer: opener for opener, closer in closer_of.items()}

    pending = []
    for ch in s:
        if ch in closer_of:
            pending.append(ch)
        elif ch in opener_of and pending and pending[-1] == opener_of[ch]:
            pending.pop()

    # Innermost opener is closed first.
    return s + ''.join(closer_of[opener] for opener in reversed(pending))
def rest_unreduce(abc_lines):
    """Fill in the voices that are missing from each tune-body line with rests.

    The input is split at the first line containing an inline ``[V:`` marker:
    lines before it are metadata, that line and everything after it are the
    tune body. ``%%score`` lines seen before the split get their brackets
    completed in place (``abc_lines`` is mutated).

    Voices grouped inside parentheses in a ``%%score`` line form a voice
    group; a ``V:`` voice not listed in any group forms its own group. The
    first voice of each group is padded with ``z`` rests, the remaining
    voices with ``x`` rests.

    For every tune-body line, a leading ``[r:...]`` tag is removed, the line
    is split into per-voice bar texts, and the most frequent bar duration
    among the voices present becomes the reference duration (the previous
    line's reference is reused when none can be computed). Each voice that is
    absent from the line receives ``<rest><reference duration><right barline>``.

    Args:
        abc_lines: list of ABC lines (metadata followed by tune body).

    Returns:
        The metadata lines followed by the expanded tune-body lines, each
        tune-body line terminated by a newline.

    Raises:
        ValueError: if no line contains ``[V:``, if a voice id is not an
            integer, or if a rest has to be written before any reference
            duration / barline is known.
    """
    tunebody_index = None
    for i in range(len(abc_lines)):
        if abc_lines[i].startswith('%%score'):
            abc_lines[i] = complete_brackets(abc_lines[i])
        if '[V:' in abc_lines[i]:
            tunebody_index = i
            break

    # Without this guard both slices below would silently cover the whole
    # input (slicing with None), duplicating every line in the result.
    if tunebody_index is None:
        raise ValueError('no tune body found: no line contains a "[V:" voice marker')

    metadata_lines = abc_lines[:tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    part_symbol_list = []
    voice_group_list = []
    # Initialised so a V: line that is not preceded by a %%score line does
    # not hit an unbound name.
    existed_voices = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voice_group_list.append(round_bracket_match.split())
            existed_voices = [item for sublist in voice_group_list for item in sublist]
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            if symbol[2:] not in existed_voices:
                voice_group_list.append([symbol[2:]])

    z_symbol_list = []  # voices that use z as rest
    x_symbol_list = []  # voices that use x as rest
    for voice_group in voice_group_list:
        z_symbol_list.append('V:' + voice_group[0])
        for other_voice in voice_group[1:]:
            x_symbol_list.append('V:' + other_voice)

    # Emit voices in numeric order of their id.
    part_symbol_list.sort(key=lambda x: int(x[2:]))

    unreduced_tunebody_lines = []
    ref_dur = None        # carried over from line to line when a line yields none
    right_barline = None  # barline of the last voice parsed so far

    for i, line in enumerate(tunebody_lines):
        unreduced_line = ''

        # Drop the leading [r:...] tag, if present.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
        matches = re.findall(pattern, line)

        line_bar_dict = {}
        for match in matches:
            line_bar_dict[f'V:{match[0]}'] = match[1]

        # calculate duration and collect barline
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except Exception:
                # Best effort: a bar whose duration cannot be computed simply
                # does not vote for the reference duration.
                bar_dur = None
            if bar_dur is not None:
                dur_dict[bar_dur] = dur_dict.get(bar_dur, 0) + 1

        # Most common duration on this line; otherwise keep the last ref_dur.
        if dur_dict:
            ref_dur = max(dur_dict, key=dur_dict.get)

        if i == 0:
            prefix_left_barline = line.split('[V:')[0]
        else:
            prefix_left_barline = ''

        for symbol in part_symbol_list:
            if symbol in line_bar_dict:
                symbol_bartext = line_bar_dict[symbol]
            else:
                if ref_dur is None or right_barline is None:
                    raise ValueError(
                        'cannot write a rest for ' + symbol
                        + ': no reference duration or barline is known yet'
                    )
                if symbol in z_symbol_list:
                    symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline
                elif symbol in x_symbol_list:
                    symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    unreduced_lines = metadata_lines + unreduced_tunebody_lines
    return unreduced_lines
def inference_patch(period, composer, instrumentation):
    """Generate one ABC score conditioned on period, composer and instrumentation.

    The three arguments become three ``%``-prefixed prompt lines. Patches are
    sampled from the module-level ``model`` one at a time and echoed to
    stdout as they are produced. When the accumulated input reaches
    ``PATCH_LENGTH * PATCH_SIZE`` tokens, the context is rebuilt from the
    generated metadata plus the second half of the tune-body lines so that
    generation can continue.

    An attempt is discarded and a new one started when the output exceeds
    102400 characters, when it runs for more than ten minutes, or when
    ``rest_unreduce`` raises. The function loops until an attempt succeeds.

    Args:
        period: text for the first prompt line.
        composer: text for the second prompt line.
        instrumentation: text for the third prompt line.

    Returns:
        The generated score as a string starting with ``X:1``, with
        single-``%`` comment lines (including the prompt) removed and missing
        voices filled with rests.
    """
    prompt_lines = [
        '%' + period + '\n',
        '%' + composer + '\n',
        '%' + instrumentation + '\n']

    while True:

        failure_flag = False

        bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

        start_time = time.time()

        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        byte_list = list(''.join(prompt_lines))
        context_tunebody_byte_list = []
        metadata_byte_list = []

        print(''.join(byte_list), end='')

        # Convert each prompt patch to character codes, padded to PATCH_SIZE.
        prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        end_flag = False
        cut_index = None

        tunebody_flag = False

        with torch.inference_mode():

            while True:
                # NOTE(review): device_type is hard-coded to 'cuda' even when
                # the module-level device is the CPU — confirm this is intended.
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    predicted_patch = model.generate(input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                # On first entering the tune body the patch must start with
                # "[r:0/", so regenerate with that prefix forced.
                if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):
                    tunebody_flag = True
                    r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                    temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                    predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                    predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch

                # A patch that opens with BOS followed by EOS ends the piece.
                if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                    end_flag = True
                    break

                next_patch = patchilizer.decode([predicted_patch])

                for char in next_patch:
                    byte_list.append(char)
                    if tunebody_flag:
                        context_tunebody_byte_list.append(char)
                    else:
                        metadata_byte_list.append(char)
                    print(char, end='')

                # Everything after the first EOS inside the patch becomes padding.
                patch_end_flag = False
                for j in range(len(predicted_patch)):
                    if patch_end_flag:
                        predicted_patch[j] = patchilizer.special_token_id
                    if predicted_patch[j] == patchilizer.eos_token_id:
                        patch_end_flag = True

                predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, 16)
                input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, 16 * patch_len)

                # Give up on this attempt when it grows too long or too slow.
                if len(byte_list) > 102400:
                    failure_flag = True
                    break
                if time.time() - start_time > 10 * 60:
                    failure_flag = True
                    break

                # Context is full: rebuild it from the metadata plus the
                # second half of the tune-body lines generated so far.
                if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
                    print('Stream generating...')

                    metadata = ''.join(metadata_byte_list)
                    context_tunebody = ''.join(context_tunebody_byte_list)

                    if '\n' not in context_tunebody:
                        break  # Generated content is all metadata, abandon

                    context_tunebody_lines = context_tunebody.strip().split('\n')

                    # Restore line terminators; the last line keeps none when
                    # the generated text stopped mid-line.
                    if not context_tunebody.endswith('\n'):
                        context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
                                                  range(len(context_tunebody_lines) - 1)] + [context_tunebody_lines[-1]]
                    else:
                        context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
                                                  range(len(context_tunebody_lines))]

                    cut_index = len(context_tunebody_lines) // 2
                    abc_code_slice = metadata + ''.join(context_tunebody_lines[-cut_index:])

                    input_patches = patchilizer.encode_generate(abc_code_slice)

                    input_patches = [item for sublist in input_patches for item in sublist]
                    input_patches = torch.tensor([input_patches], device=device)
                    input_patches = input_patches.reshape(1, -1)

                    context_tunebody_byte_list = list(''.join(context_tunebody_lines[-cut_index:]))

        if not failure_flag:
            abc_text = ''.join(byte_list)

            # unreduce
            abc_lines = abc_text.split('\n')
            abc_lines = list(filter(None, abc_lines))
            abc_lines = [line + '\n' for line in abc_lines]
            try:
                unreduced_abc_lines = rest_unreduce(abc_lines)
            except Exception:
                # Malformed output: discard this attempt and generate again.
                # KeyboardInterrupt / SystemExit are deliberately not caught,
                # so the retry loop can still be interrupted.
                failure_flag = True
            else:
                # Drop single-% comment lines but keep %% directives.
                unreduced_abc_lines = [line for line in unreduced_abc_lines if
                                       not (line.startswith('%') and not line.startswith('%%'))]
                unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
                unreduced_abc_text = ''.join(unreduced_abc_lines)
                return unreduced_abc_text
if __name__ == '__main__':
    # Script entry point: run a single generation with a fixed example prompt.
    inference_patch(
        period='Classical',
        composer='Beethoven, Ludwig van',
        instrumentation='Orchestral',
    )