import os
import argparse

import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Feature extraction for CLaMP3.")
parser.add_argument("--epoch", type=str, default=None, help="Epoch of the checkpoint to load.")
parser.add_argument("input_dir", type=str, help="Directory containing input data files.")
parser.add_argument("output_dir", type=str, help="Directory to save the output features.")
parser.add_argument("--get_global", action="store_true", help="Get global feature.")

args = parser.parse_args()

# Retrieve arguments
epoch = args.epoch
input_dir = args.input_dir
output_dir = args.output_dir
get_global = args.get_global

# Collect every supported input file (.txt text, .abc/.mtf symbolic music, .npy audio features)
files = []
for root, _, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith((".txt", ".abc", ".mtf", ".npy")):
            files.append(os.path.join(root, f))

print(f"Found {len(files)} files in total")

# Initialize accelerator and device
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)

# Model and configuration setup
audio_config = BertConfig(vocab_size=1,
                          hidden_size=AUDIO_HIDDEN_SIZE,
                          num_hidden_layers=AUDIO_NUM_LAYERS,
                          num_attention_heads=AUDIO_HIDDEN_SIZE//64,
                          intermediate_size=AUDIO_HIDDEN_SIZE*4,
                          max_position_embeddings=MAX_AUDIO_LENGTH)
symbolic_config = BertConfig(vocab_size=1,
                             hidden_size=M3_HIDDEN_SIZE,
                             num_hidden_layers=PATCH_NUM_LAYERS,
                             num_attention_heads=M3_HIDDEN_SIZE//64,
                             intermediate_size=M3_HIDDEN_SIZE*4,
                             max_position_embeddings=PATCH_LENGTH)
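# CLaMP 3 combines a text encoder (TEXT_MODEL_NAME), a symbolic-music encoder (M3),
# and an audio encoder, all projecting into a shared CLAMP3_HIDDEN_SIZE space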
model = CLaMP3Model(audio_config=audio_config,
                    symbolic_config=symbolic_config,
                    text_model_name=TEXT_MODEL_NAME,
                    hidden_size=CLAMP3_HIDDEN_SIZE,
                    load_m3=CLAMP3_LOAD_M3)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# Print the total number of model parameters
print("Total Parameter Number: " + str(sum(p.numel() for p in model.parameters())))

# Load model weights
checkpoint_path = CLAMP3_WEIGHTS_PATH
if epoch is not None:
    checkpoint_path = CLAMP3_WEIGHTS_PATH.replace(".pth", f"_{epoch}.pth")
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint['model'])
model.eval()
print(f"Successfully Loaded CLaMP 3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")

def extract_feature(filename, get_global=get_global):
    """Extract CLaMP 3 features from a .txt, .abc, .mtf, or .npy file."""
    if not filename.endswith(".npy"):
        with open(filename, "r", encoding="utf-8") as f:
            item = f.read()

    if filename.endswith(".txt"):
        # Text: deduplicate lines, drop empty ones, and join them with the separator token
        item = list(set(item.split("\n")))
        item = [c for c in item if len(c) > 0]
        item = tokenizer.sep_token.join(item)
        input_data = tokenizer(item, return_tensors="pt")
        input_data = input_data['input_ids'].squeeze(0)
        max_input_length = MAX_TEXT_LENGTH
    elif filename.endswith((".abc", ".mtf")):
        # Symbolic music: encode into M3 patches
        input_data = patchilizer.encode(item, add_special_patches=True)
        input_data = torch.tensor(input_data)
        max_input_length = PATCH_LENGTH
    elif filename.endswith(".npy"):
        # Audio: load precomputed features and add a zero vector at both ends
        input_data = np.load(filename)
        input_data = torch.tensor(input_data)
        input_data = input_data.reshape(-1, input_data.size(-1))
        zero_vec = torch.zeros((1, input_data.size(-1)))
        input_data = torch.cat((zero_vec, input_data, zero_vec), 0)
        max_input_length = MAX_AUDIO_LENGTH
    else:
        raise ValueError(f"Unsupported file type: {filename}; only .txt, .abc, .mtf, and .npy files are supported")

    # Split the input into segments of max_input_length; the last segment is
    # re-cut from the end of the sequence so it is always full length
    segment_list = []
    for i in range(0, len(input_data), max_input_length):
        segment_list.append(input_data[i:i+max_input_length])
    segment_list[-1] = input_data[-max_input_length:]

    last_hidden_states_list = []

    for input_segment in segment_list:
        # Build the attention mask (1 = real positions) and pad the segment to max_input_length
        input_masks = torch.tensor([1]*input_segment.size(0))
        if filename.endswith(".txt"):
            pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
        elif filename.endswith((".abc", ".mtf")):
            pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
        else:
            pad_indices = torch.zeros((MAX_AUDIO_LENGTH - input_segment.size(0), AUDIO_HIDDEN_SIZE))
        input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)

        if filename.endswith(".txt"):
            last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
                                                         text_masks=input_masks.unsqueeze(0).to(device),
                                                         get_global=get_global)
        elif filename.endswith(".abc") or filename.endswith(".mtf"):
            last_hidden_states = model.get_symbolic_features(symbolic_inputs=input_segment.unsqueeze(0).to(device),
                                                          symbolic_masks=input_masks.unsqueeze(0).to(device),
                                                          get_global=get_global)
        else:
            last_hidden_states = model.get_audio_features(audio_inputs=input_segment.unsqueeze(0).to(device),
                                                          audio_masks=input_masks.unsqueeze(0).to(device),
                                                          get_global=get_global)
        if not get_global:
            last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
        last_hidden_states_list.append(last_hidden_states)

    if not get_global:
        # Concatenate per-position hidden states; the last segment was re-cut from the
        # end, so drop any positions it shares with the previous segment
        last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
    else:
        # Average the per-segment global vectors, weighting each segment by its unpadded length
        full_chunk_cnt = len(input_data) // max_input_length
        remain_chunk_len = len(input_data) % max_input_length
        if remain_chunk_len == 0:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
        else:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)

        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
        last_hidden_states_list = last_hidden_states_list * feature_weights
        last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()

    return last_hidden_states_list

def process_directory(input_dir, output_dir, files):
    """Extract features for this process's shard of `files`, mirroring the input tree under `output_dir`."""
    # calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # calculate the start and end index for the current GPU
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    if accelerator.process_index == accelerator.num_processes - 1:
        end_idx = len(files)

    files_to_process = files[start_idx:end_idx]

    # process the files
    for file in tqdm(files_to_process):
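        # Mirror the input subdirectory layout under output_dir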
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(output_subdir + " cannot be created\n" + str(e))

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            continue

        try:
            with torch.no_grad():
                features = extract_feature(file).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
        except Exception as e:
            print(f"Failed to process {file}: {e}")

# Run feature extraction over this process's shard of the input files
process_directory(input_dir, output_dir, files)
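
# Downstream usage sketch (path is illustrative): a saved feature can be loaded with
#   feat = np.load("output_dir/song.npy")
# giving shape (1, hidden_size) with --get_global, or (1, seq_len, hidden_size) otherwise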