import os
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, GPT2Config
import argparse

# Parse command-line arguments for input_dir and output_dir
parser = argparse.ArgumentParser(description="Process files to extract features.")
parser.add_argument("input_dir", type=str, help="Directory with input files")
parser.add_argument("output_dir", type=str, help="Directory to save extracted features")
args = parser.parse_args()
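
# Typical invocation (script name assumed; run under `accelerate launch` for multi-GPU use):
#   accelerate launch extract_m3.py /path/to/input_dir /path/to/output_dir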

# Use args for input and output directories
input_dir = args.input_dir
output_dir = args.output_dir

# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)

# Remove existing log files if present
for file in [
    "logs/files_extract_m3.json",
    "logs/files_shuffle_extract_m3.json",
    "logs/log_extract_m3.txt",
    "logs/pass_extract_m3.txt",
    "logs/skip_extract_m3.txt",
]:
    if os.path.exists(file):
        os.remove(file)

# Collect input files
files = []
for root, dirs, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith(".abc") or f.endswith(".mtf"):
            files.append(os.path.join(root, f))

print(f"Found {len(files)} files in total")
with open("logs/files_extract_m3.json", "w", encoding="utf-8") as f:
    json.dump(files, f)

# Shuffle files and save the shuffled order
random.shuffle(files)
with open("logs/files_shuffle_extract_m3.json", "w", encoding="utf-8") as f:
    json.dump(files, f)

# Initialize accelerator and device
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)
with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
    f.write("Using device: " + str(device) + "\n")

# Model and configuration setup
patchilizer = M3Patchilizer()
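# Patch-level encoder config: up to PATCH_LENGTH positions (one per patch) with 64-dim
# attention heads; vocab_size is set to 1, suggesting the encoder is fed patch embeddings
# rather than token ids. The character-level decoder below (128-symbol vocab, PATCH_SIZE
# positions) belongs to the M3 model but is not used for feature extraction in this script.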
encoder_config = BertConfig(
    vocab_size=1,
    hidden_size=M3_HIDDEN_SIZE,
    num_hidden_layers=PATCH_NUM_LAYERS,
    num_attention_heads=M3_HIDDEN_SIZE // 64,
    intermediate_size=M3_HIDDEN_SIZE * 4,
    max_position_embeddings=PATCH_LENGTH,
)
decoder_config = GPT2Config(
    vocab_size=128,
    n_positions=PATCH_SIZE,
    n_embd=M3_HIDDEN_SIZE,
    n_layer=TOKEN_NUM_LAYERS,
    n_head=M3_HIDDEN_SIZE // 64,
    n_inner=M3_HIDDEN_SIZE * 4,
)
model = M3Model(encoder_config, decoder_config).to(device)

# Print parameter count
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Load model weights
model.eval()
checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])

def extract_feature(item):
    """Extracts features from input data."""
    target_patches = patchilizer.encode(item, add_special_patches=True)
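    # Split the patch sequence into segments of PATCH_LENGTH; the final chunk is replaced by
    # the last PATCH_LENGTH patches so it is full length whenever the piece is long enough,
    # and the resulting overlap is trimmed again after encoding.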
    target_patches_list = [target_patches[i:i + PATCH_LENGTH] for i in range(0, len(target_patches), PATCH_LENGTH)]
    target_patches_list[-1] = target_patches[-PATCH_LENGTH:]

    last_hidden_states_list = []
    for input_patches in target_patches_list:
        input_masks = torch.tensor([1] * len(input_patches))
        input_patches = torch.tensor(input_patches)
        last_hidden_states = model.encoder(
            input_patches.unsqueeze(0).to(device), input_masks.unsqueeze(0).to(device)
        )["last_hidden_state"][0]
        last_hidden_states_list.append(last_hidden_states)

    # Segments may overlap at the end: keep only the hidden states for the final
    # (len % PATCH_LENGTH) patches so no patch is duplicated (a remainder of 0 keeps the full segment).
    last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(target_patches) % PATCH_LENGTH):]
    return torch.concat(last_hidden_states_list, 0)

def process_directory(input_dir, output_dir, files):
    """Processes files in the input directory and saves features to the output directory."""
    print(f"Found {len(files)} files in total")
    with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
        f.write("Found " + str(len(files)) + " files in total\n")

    # Distribute files across processes for parallel processing
    num_files_per_gpu = len(files) // accelerator.num_processes
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu if accelerator.process_index < accelerator.num_processes - 1 else len(files)
    files_to_process = files[start_idx:end_idx]

    # Process each file
    for file in tqdm(files_to_process):
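        # Mirror the input directory structure under output_dir (simple string slicing;
        # assumes `file` starts with `input_dir`, as guaranteed by the os.walk above).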
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(f"{output_subdir} cannot be created\n{e}")
            with open("logs/log_extract_m3.txt", "a") as f:
                f.write(f"{output_subdir} cannot be created\n{e}\n")

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            with open("logs/skip_extract_m3.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
            continue

        try:
            with open(file, "r", encoding="utf-8") as f:
                item = f.read()
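                # Files not starting with "ticks_per_beat" (i.e. ABC notation rather than the
                # .mtf MIDI text format) have the default note-length header "L:1/8" stripped.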
                if not item.startswith("ticks_per_beat"):
                    item = item.replace("L:1/8\n", "")
                with torch.no_grad():
                    features = extract_feature(item).unsqueeze(0)
                np.save(output_file, features.detach().cpu().numpy())
                with open("logs/pass_extract_m3.txt", "a", encoding="utf-8") as f:
                    f.write(file + "\n")
        except Exception as e:
            print(f"Failed to process {file}: {e}")
            with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
                f.write(f"Failed to process {file}: {e}\n")

# Load shuffled files list and start processing
with open("logs/files_shuffle_extract_m3.json", "r", encoding="utf-8") as f:
    files = json.load(f)

# Process the directory
process_directory(input_dir, output_dir, files)

with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
    f.write("GPU ID: " + str(device) + "\n")