sander-wood committed
Commit 0e3e69a
1 Parent(s): 3198fcb

Update app.py

Files changed (1)
  1. app.py +71 -314
app.py CHANGED
@@ -10,17 +10,6 @@ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
  from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, PreTrainedModel
 
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- PATCH_LENGTH = 128 # Patch Length
- PATCH_SIZE = 32 # Patch Size
-
- PATCH_NUM_LAYERS = 9 # Number of layers in the encoder
- CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
-
- NUM_EPOCHS = 32 # Number of epochs to train for (if early stopping doesn't intervene)
- LEARNING_RATE = 5e-5 # Learning rate for the optimizer
- PATCH_SAMPLING_BATCH_SIZE = 0 # Batch size for patch during training, 0 for full context
- LOAD_FROM_CHECKPOINT = False # Whether to load weights from a checkpoint
- SHARE_WEIGHTS = False # Whether to share weights between the encoder and decoder
 
  description = """
  <div>
@@ -61,288 +50,72 @@ A general recommendation is to adjust the desired musical structure and form thr
  Please make sure to operate according to the provided formats and guidelines to fully leverage the capabilities of TunesFormer and achieve a satisfying music generation experience.
  """
 
- class Patchilizer:
- """
- A class for converting music bars to patches and vice versa.
- """
+ class ABCTokenizer():
  def __init__(self):
- self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
- self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
  self.pad_token_id = 0
- self.bos_token_id = 1
- self.eos_token_id = 2
+ self.bos_token_id = 2
+ self.eos_token_id = 3
+ self.merged_tokens = []
 
- def split_bars(self, body):
- """
- Split a body of music into individual bars.
- """
- bars = re.split(self.regexPattern, ''.join(body))
- bars = list(filter(None, bars)) # remove empty strings
- if bars[0] in self.delimiters:
- bars[1] = bars[0] + bars[1]
- bars = bars[1:]
- bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
- return bars
+ def __len__(self):
+ return 128+len(self.merged_tokens)
 
- def bar2patch(self, bar, patch_size=PATCH_SIZE):
- """
- Convert a bar into a patch of specified length.
- """
- patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
- patch = patch[:patch_size]
- patch += [self.pad_token_id] * (patch_size - len(patch))
- return patch
-
- def patch2bar(self, patch):
- """
- Convert a patch into a bar.
- """
- return ''.join(chr(idx) if idx > self.eos_token_id else '' for idx in patch if idx != self.eos_token_id)
-
- def encode(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=False):
- """
- Encode music into patches of specified length.
- """
- lines = unidecode(abc_code).split('\n')
- lines = list(filter(None, lines)) # remove empty lines
-
- body = ""
- patches = []
-
- for line in lines:
- if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')):
- if body:
- bars = self.split_bars(body)
- patches.extend(self.bar2patch(bar + '\n' if idx == len(bars) - 1 else bar, patch_size)
- for idx, bar in enumerate(bars))
- body = ""
- patches.append(self.bar2patch(line + '\n', patch_size))
- else:
- body += line + '\n'
-
- if body:
- patches.extend(self.bar2patch(bar, patch_size) for bar in self.split_bars(body))
-
- if add_special_patches:
- bos_patch = [self.bos_token_id] * (patch_size-1) + [self.eos_token_id]
- eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size-1)
- patches = [bos_patch] + patches + [eos_patch]
-
- return patches[:patch_length]
-
- def decode(self, patches):
- """
- Decode patches into music.
- """
- return ''.join(self.patch2bar(patch) for patch in patches)
-
- class PatchLevelDecoder(PreTrainedModel):
- """
- An Patch-level Decoder model for generating patch features in an auto-regressive manner.
- It inherits PreTrainedModel from transformers.
- """
-
- def __init__(self, config):
- super().__init__(config)
- self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
- torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
- self.base = GPT2Model(config)
-
- def forward(self, patches: torch.Tensor) -> torch.Tensor:
- """
- The forward pass of the patch-level decoder model.
- :param patches: the patches to be encoded
- :return: the encoded patches
- """
- patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
- patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
- patches = self.patch_embedding(patches.to(self.device))
-
- return self.base(inputs_embeds=patches)
-
- class CharLevelDecoder(PreTrainedModel):
- """
- A Char-level Decoder model for generating the characters within each bar patch sequentially.
- It inherits PreTrainedModel from transformers.
- """
- def __init__(self, config):
- super().__init__(config)
- self.pad_token_id = 0
- self.bos_token_id = 1
- self.eos_token_id = 2
- self.base = GPT2LMHeadModel(config)
-
- def forward(self, encoded_patches: torch.Tensor, target_patches: torch.Tensor, patch_sampling_batch_size: int):
- """
- The forward pass of the char-level decoder model.
- :param encoded_patches: the encoded patches
- :param target_patches: the target patches
- :return: the decoded patches
- """
- # preparing the labels for model training
- target_masks = target_patches == self.pad_token_id
- labels = target_patches.clone().masked_fill_(target_masks, -100)
-
- # masking the labels for model training
- target_masks = torch.ones_like(labels)
- target_masks = target_masks.masked_fill_(labels == -100, 0)
-
- # select patches
- if patch_sampling_batch_size!=0 and patch_sampling_batch_size<target_patches.shape[0]:
- indices = list(range(len(target_patches)))
- random.shuffle(indices)
- selected_indices = sorted(indices[:patch_sampling_batch_size])
-
- target_patches = target_patches[selected_indices,:]
- target_masks = target_masks[selected_indices,:]
- encoded_patches = encoded_patches[selected_indices,:]
- labels = labels[selected_indices,:]
-
- # get input embeddings
- inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
-
- # concatenate the encoded patches with the input embeddings
- inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)
-
- return self.base(inputs_embeds=inputs_embeds,
- attention_mask=target_masks,
- labels=labels)
-
- def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
- """
- The generate function for generating a patch based on the encoded patch and already generated tokens.
- :param encoded_patch: the encoded patch
- :param tokens: already generated tokens in the patch
- :return: the probability distribution of next token
- """
- encoded_patch = encoded_patch.reshape(1, 1, -1)
- tokens = tokens.reshape(1, -1)
-
- # Get input embeddings
- tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
-
- # Concatenate the encoded patch with the input embeddings
- tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)
-
- # Get output from model
- outputs = self.base(inputs_embeds=tokens)
-
- # Get probabilities of next token
- probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
-
- return probs
-
- class TunesFormer(PreTrainedModel):
- """
- TunesFormer is a hierarchical music generation model based on bar patching.
- It includes a patch-level decoder and a character-level decoder.
- It inherits PreTrainedModel from transformers.
- """
-
- def __init__(self, encoder_config, decoder_config, share_weights=False):
- super().__init__(encoder_config)
- self.pad_token_id = 0
- self.bos_token_id = 1
- self.eos_token_id = 2
- if share_weights:
- max_layers = max(encoder_config.num_hidden_layers, decoder_config.num_hidden_layers)
- max_context_size = max(encoder_config.max_length, decoder_config.max_length)
- max_position_embeddings = max(encoder_config.max_position_embeddings, decoder_config.max_position_embeddings)
-
- encoder_config.num_hidden_layers = max_layers
- encoder_config.max_length = max_context_size
- encoder_config.max_position_embeddings = max_position_embeddings
- decoder_config.num_hidden_layers = max_layers
- decoder_config.max_length = max_context_size
- decoder_config.max_position_embeddings = max_position_embeddings
-
- self.patch_level_decoder = PatchLevelDecoder(encoder_config)
- self.char_level_decoder = CharLevelDecoder(decoder_config)
-
- if share_weights:
- self.patch_level_decoder.base = self.char_level_decoder.base.transformer
-
- def forward(self, patches: torch.Tensor, patch_sampling_batch_size: int=PATCH_SAMPLING_BATCH_SIZE):
- """
- The forward pass of the TunesFormer model.
- :param patches: the patches to be both encoded and decoded
- :return: the decoded patches
- """
- patches = patches.reshape(len(patches), -1, PATCH_SIZE)
- encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
-
- return self.char_level_decoder(encoded_patches.squeeze(0)[:-1, :], patches.squeeze(0)[1:, :], patch_sampling_batch_size)
-
- def generate(self, patches: torch.Tensor,
- tokens: torch.Tensor,
- top_p: float=1,
- top_k: int=0,
- temperature: float=1,
- seed: int=None):
- """
- The generate function for generating patches based on patches.
- :param patches: the patches to be encoded
- :return: the generated patches
- """
- patches = patches.reshape(len(patches), -1, PATCH_SIZE)
- encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
- if tokens==None:
- tokens = torch.tensor([self.bos_token_id], device=self.device)
- generated_patch = []
- random.seed(seed)
-
- while True:
- if seed!=None:
- n_seed = random.randint(0, 1000000)
- random.seed(n_seed)
- else:
- n_seed = None
- prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()
- prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
- prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
- token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
- generated_patch.append(token)
- if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
- break
- else:
- tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)
+ def encode(self, text):
+ encodings = {}
+ encodings['input_ids'] = torch.tensor(self.txt2ids(text, self.merged_tokens))
+ encodings['attention_mask'] = torch.tensor([1]*len(encodings['input_ids']))
+ return encodings
+
+ def decode(self, ids, skip_special_tokens=False):
+ txt = ""
+ for i in ids:
+ if i>=128:
+ if not skip_special_tokens:
+ txt += self.merged_tokens[i-128]
+ elif i!=2 and i!=3:
+ txt += chr(i)
+ return txt
+
+ def txt2ids(self, text, merged_tokens):
+ text = unidecode(text)
+ ids = [str(ord(c)) for c in text]
+ if torch.max(torch.tensor([ord(c) for c in text]))>=128:
+ return [128+len(self.merged_tokens)]
+ txt_ids = ' '.join(ids)
+ for t_idx, token in enumerate(merged_tokens):
+ token_ids = [str(ord(c)) for c in token]
+ token_txt_ids = ' '.join(token_ids)
+ txt_ids = txt_ids.replace(token_txt_ids, str(t_idx+128))
 
- return generated_patch, n_seed
+ txt_ids = txt_ids.split(' ')
+ txt_ids = [int(i) for i in txt_ids]
+ return [self.bos_token_id]+txt_ids+[self.eos_token_id]
 
  def generate_abc(prompt,
  num_tunes,
- max_patch,
+ max_length,
  top_p,
  top_k,
  temperature,
- seed,
- show_control_code=True):
+ seed):
 
  if torch.cuda.is_available():
  device = torch.device("cuda")
  else:
  device = torch.device("cpu")
 
- patchilizer = Patchilizer()
+ tokenizer = ABCTokenizer()
+ config = GPT2Config(vocab_size=len(tokenizer))
+ model = GPT2LMHeadModel(config).to(device)
 
- patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
- max_length=PATCH_LENGTH,
- max_position_embeddings=PATCH_LENGTH,
- vocab_size=1)
- char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
- max_length=PATCH_SIZE,
- max_position_embeddings=PATCH_SIZE,
- vocab_size=128)
- model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
-
- filename = "weights.pth"
+ filename = "pytorch_model.bin"
 
  if os.path.exists(filename):
  print(f"Weights already exist at '{filename}'. Loading...")
  else:
  print(f"Downloading weights to '{filename}' from huggingface.co...")
  try:
- url = 'https://huggingface.co/sander-wood/tunesformer/resolve/main/weights.pth'
+ url = 'https://huggingface.co/sander-wood/tunesformer/resolve/main/pytorch_model.bin'
  response = requests.get(url, stream=True)
 
  total_size = int(response.headers.get('content-length', 0))
@@ -362,59 +135,43 @@ def generate_abc(prompt,
  print(f"Error: {e}")
  exit()
 
- checkpoint = torch.load('weights.pth')
- model.load_state_dict(checkpoint['model'])
- model = model.to(device)
+ model.load_state_dict(torch.load('pytorch_model.bin'))
  model.eval()
 
  tunes = ""
+
+ if prompt:
+ ids = tokenizer.encode(prompt)['input_ids'][:-1]
+ else:
+ ids = torch.tensor([tokenizer.bos_token_id])
 
- print("\n"+" OUTPUT TUNES ".center(60, "#"))
+ random.seed(seed)
 
+ print("\n"+" OUTPUT TUNES ".center(60, "#"))
+
  for i in range(num_tunes):
  tune = "X:"+str(i+1) + "\n" + prompt
- lines = re.split(r'(\n)', tune)
- tune = ""
- skip = False
- for line in lines:
- if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
- if not skip:
- print(line, end="")
- tune += line
- skip = False
+ print(tune, end="")
+ input_ids = ids.unsqueeze(0)
+ for t_idx in range(max_length):
+ if seed!=None:
+ n_seed = random.randint(0, 1000000)
+ random.seed(n_seed)
  else:
- skip = True
-
- input_patches = torch.tensor([patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device)
- if tune=="":
- tokens = None
- else:
- prefix = patchilizer.decode(input_patches[0])
- remaining_tokens = prompt[len(prefix):]
- tokens = torch.tensor([patchilizer.bos_token_id]+[ord(c) for c in remaining_tokens], device=device)
-
- while input_patches.shape[1]<max_patch:
- predicted_patch, seed = model.generate(input_patches,
- tokens,
- top_p=top_p,
- top_k=top_k,
- temperature=temperature,
- seed=seed)
- tokens = None
- if predicted_patch[0]!=patchilizer.eos_token_id:
- next_bar = patchilizer.decode([predicted_patch])
- if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
- print(next_bar, end="")
- tune += next_bar
- if next_bar=="":
- break
- next_bar = remaining_tokens+next_bar
- remaining_tokens = ""
- predicted_patch = torch.tensor(patchilizer.bar2patch(next_bar), device=device).unsqueeze(0)
- input_patches = torch.cat([input_patches, predicted_patch.unsqueeze(0)], dim=1)
+ n_seed = None
+ outputs = model(input_ids=input_ids.to(device))
+ probs = outputs.logits[0][-1]
+ probs = torch.nn.Softmax(dim=-1)(probs).cpu().detach().numpy()
+ probs = top_p_sampling(probs, top_p=top_p, return_probs=True)
+ probs = top_k_sampling(probs, top_k=top_k, return_probs=True)
+ sampled_id = temperature_sampling(probs, temperature=temperature, seed=n_seed)
+ input_ids = torch.cat((input_ids, torch.tensor([[sampled_id]])), 1)
+ if sampled_id!=3:
+ tune += tokenizer.decode([sampled_id], skip_special_tokens=True)
+ print(tune[-1], end="")
+ continue
  else:
  break
-
  tunes += tune+"\n\n"
 
  return tunes
@@ -430,7 +187,7 @@ K:D
 
  input_prompt = gr.inputs.Textbox(lines=5, label="Prompt", default=default_prompt)
  input_num_tunes = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Tunes")
- input_max_patch = gr.inputs.Slider(minimum=10, maximum=128, step=1, default=128, label="Max Patch")
+ input_max_length = gr.inputs.Slider(minimum=64, maximum=1024, step=1, default=1024, label="Max Length")
  input_top_p = gr.inputs.Slider(minimum=0.0, maximum=1.0, step=0.05, default=0.8, label="Top P")
  input_top_k = gr.inputs.Slider(minimum=1, maximum=20, step=1, default=8, label="Top K")
  input_temperature = gr.inputs.Slider(minimum=0.0, maximum=2.0, step=0.05, default=1.2, label="Temperature")
@@ -438,7 +195,7 @@ input_seed = gr.inputs.Textbox(lines=1, label="Seed (int)", default="None")
  output_abc = gr.outputs.Textbox(label="Generated Tunes")
 
  gr.Interface(generate_abc,
- [input_prompt, input_num_tunes, input_max_patch, input_top_p, input_top_k, input_temperature, input_seed],
+ [input_prompt, input_num_tunes, input_max_length, input_top_p, input_top_k, input_temperature, input_seed],
  output_abc,
  title="TunesFormer: Forming Irish Tunes with Control Codes by Bar Patching",
  description=description).launch()
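
Note on the change: the bar-patching stack (Patchilizer, PatchLevelDecoder, CharLevelDecoder, TunesFormer) is replaced by a plain character-level ABCTokenizer feeding a single GPT2LMHeadModel. A minimal sketch of what the new tokenizer does, assuming the class exactly as added above (merged_tokens is empty by default, so each character maps to its ASCII code point, framed by bos_token_id=2 and eos_token_id=3):

# Illustrative only: round-trip a short ABC header fragment through ABCTokenizer.
tokenizer = ABCTokenizer()
enc = tokenizer.encode("X:1\nK:D\n")
print(enc['input_ids'].tolist())
# [2, 88, 58, 49, 10, 75, 58, 68, 10, 3]  -> BOS, 'X', ':', '1', '\n', 'K', ':', 'D', '\n', EOS
print(tokenizer.decode(enc['input_ids'].tolist(), skip_special_tokens=True))
# recovers "X:1\nK:D\n", since no merged tokens are defined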
 
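For completeness, a hedged sketch of how the updated generate_abc could be driven outside the Gradio interface (parameter names follow the new signature in the diff; the prompt value here is illustrative, and the function still downloads pytorch_model.bin on first use):

# Illustrative only: call generate_abc() directly with the same parameters
# the Gradio sliders expose in app.py.
tunes = generate_abc(prompt="L:1/8\nM:6/8\nK:D\n",  # ABC header fragment; "X:" is prepended per tune
                     num_tunes=1,
                     max_length=1024,   # upper bound on sampled characters per tune
                     top_p=0.8,
                     top_k=8,
                     temperature=1.2,
                     seed=None)         # None -> unseeded, non-deterministic sampling
print(tunes)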