updated generation process - epsilon

- familizer.py +0 -1
- generate.py +62 -29
- generation_utils.py +35 -6
familizer.py CHANGED

@@ -115,7 +115,6 @@ class Familizer:
 
 
 if __name__ == "__main__":
-
     # Choose number of jobs for parallel processing
     n_jobs = -1
 
generate.py CHANGED

@@ -1,8 +1,5 @@
 from generation_utils import *
-
-from load import LoadModel
-from decoder import TextDecoder
-from playback import get_music
+import random
 
 
 class GenerateMidiText:

@@ -100,15 +97,26 @@ class GenerateMidiText:
         text = text.rstrip(" ").rstrip("TRACK_END")
         return text
 
-    def get_last_generated_track(self,
-        track
-
-            + self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
-            + "TRACK_END "
-        )  # forcing the space after track and
+    def get_last_generated_track(self, piece):
+        """Get the last track from a piece written as a single long string"""
+        track = self.get_tracks_from_a_piece(piece)[-1]
         return track
 
-    def
+    def get_tracks_from_a_piece(self, piece):
+        """Get all the tracks from a piece written as a single long string"""
+        all_tracks = [
+            "TRACK_START " + the_track + "TRACK_END "
+            for the_track in self.striping_track_ends(piece.split("TRACK_START ")[1::])
+        ]
+        return all_tracks
+
+    def get_piece_from_track_list(self, track_list):
+        piece = "PIECE_START "
+        for track in track_list:
+            piece += track
+        return piece
+
+    def get_whole_track_from_bar_dict(self, track_id):
         text = ""
         for bar in self.piece_by_track[track_id]["bars"]:
             text += bar
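The new helpers above treat a whole piece as one flat token string and recover per-track strings by splitting on the "TRACK_START " marker, then re-adding the markers around each chunk. Below is a minimal standalone sketch of that split-and-rebuild logic; the function names are local stand-ins for the class methods, and striping_track_ends (whose definition is not part of this diff) is approximated by the rstrip pattern visible in the hunk's context lines.

def strip_track_ends(text):
    # Approximation of striping_track_ends, mirroring the rstrip pattern in the diff.
    return text.rstrip(" ").rstrip("TRACK_END")

def get_tracks_from_a_piece(piece):
    """Split a piece string into per-track strings, re-adding the markers."""
    return [
        "TRACK_START " + strip_track_ends(chunk) + "TRACK_END "
        for chunk in piece.split("TRACK_START ")[1:]
    ]

def get_piece_from_track_list(tracks):
    """Rebuild a piece string from a list of track strings."""
    return "PIECE_START " + "".join(tracks)

piece = (
    "PIECE_START "
    "TRACK_START INST=DRUMS BAR_START NOTE_ON=36 BAR_END TRACK_END "
    "TRACK_START INST=BASS BAR_START NOTE_ON=40 BAR_END TRACK_END "
)
tracks = get_tracks_from_a_piece(piece)
print(tracks[-1])                                   # the "last generated track"
print(get_piece_from_track_list(tracks) == piece)   # round-trips on this example: True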
@@ -122,18 +130,12 @@ class GenerateMidiText:
     def get_whole_piece_from_bar_dict(self):
         text = "PIECE_START "
         for track_id, _ in enumerate(self.piece_by_track):
-            text += self.
+            text += self.get_whole_track_from_bar_dict(track_id)
         return text
 
-    def delete_one_track(self, track):
+    def delete_one_track(self, track):
         self.piece_by_track.pop(track)
 
-    # def update_piece_dict__add_track(self, track_id, track):
-    #     self.piece_dict[track_id] = track
-
-    # def update_all_dictionnaries__add_track(self, track):
-    #     self.update_piece_dict__add_track(track_id, track)
-
     """Basic generation tools"""
 
     def tokenize_input_prompt(self, input_prompt, verbose=True):

@@ -238,10 +240,12 @@ class GenerateMidiText:
                 )
             else:
                 print('"--- Wrong length - Regenerating ---')
+
             if not bar_count_checks:
                 failed += 1
-
-
+
+                if failed > 2:
+                    bar_count_checks = True  # exit the while loop if failed too much
 
         return full_piece
 
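The second hunk above adds an escape hatch to the regeneration loop: after more than two wrong-length attempts the check is forced to pass so generation cannot loop forever. A small self-contained sketch of that retry pattern, with the model call and the bar-count check stubbed out since they are defined elsewhere in the file:

import random

def generate_candidate():
    # Stand-in for the real model call; returns a random bar count.
    return random.randint(6, 10)

def generate_with_retry(expected_bars, max_failures=2):
    """Regenerate until the bar count matches, but stop retrying after max_failures."""
    failed = 0
    bar_count_checks = False
    candidate = None
    while not bar_count_checks:
        candidate = generate_candidate()
        bar_count_checks = candidate == expected_bars
        if not bar_count_checks:
            print("--- Wrong length - Regenerating ---")
            failed += 1
            if failed > max_failures:
                bar_count_checks = True  # exit the while loop if failed too much
    return candidate

print(generate_with_retry(expected_bars=8))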
@@ -298,8 +302,7 @@ class GenerateMidiText:
 
     """ Piece generation - Extra Bars """
 
-
-    def process_prompt_for_next_bar(self, track_idx):
+    def process_prompt_for_next_bar(self, track_idx, verbose=True):
         """Processing the prompt for the model to generate one more bar only.
         The prompt containts:
         if not the first bar: the previous, already processed, bars of the track

@@ -318,6 +321,10 @@ class GenerateMidiText:
             if i != track_idx:
                 len_diff = len(othertrack["bars"]) - len(track["bars"])
                 if len_diff > 0:
+                    if verbose:
+                        print(
+                            f"Adding bars - {len(track['bars'][-self.model_n_bar :])} selected from SIDE track: {i} for prompt"
+                        )
                     # if other bars are longer, it mean that this one should catch up
                     pre_promt += othertrack["bars"][0]
                     for bar in track["bars"][-self.model_n_bar :]:
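Taken together with the docstring above, the prompt handed to the model for "one more bar" is: the headers and most recent bars of any longer side tracks, then the header and last bars of the track being extended, ended by an open BAR_START for the model to complete. An illustrative example of the resulting token string follows; the track header "TRACK_START INST=DRUMS DENSITY=2" is the example used in the diff's own comment, while the note tokens are made-up placeholders:

# Illustrative shape of the assembled prompt (note tokens are placeholders).
prompt = (
    "PIECE_START "                                             # start of pre_promt
    "TRACK_START INST=BASS DENSITY=1 "                         # a longer side track...
    "BAR_START NOTE_ON=40 TIME_DELTA=4 NOTE_OFF=40 BAR_END "   # ...with its last bars
    "TRACK_END "
    "TRACK_START INST=DRUMS DENSITY=2 "                        # the track being extended
    "BAR_START NOTE_ON=36 TIME_DELTA=2 NOTE_OFF=36 BAR_END "   # its last (model_n_bar - 1) bars
    "BAR_START "                                               # open bar for the model to fill
)
print(len(prompt.split(" ")))  # the figure reported by the "--- prompt length ---" print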
@@ -325,7 +332,7 @@ class GenerateMidiText:
                     pre_promt += "TRACK_END "
                 elif (
                     False
-                ): # len_diff <= 0: # THIS DOES NOT WORK - It just
+                ): # len_diff <= 0: # THIS DOES NOT WORK - It just adds empty bars
                     # adding an empty bars at the end of the other tracks if they have not been processed yet
                     pre_promt += othertracks["bars"][0]
                     for bar in track["bars"][-(self.model_n_bar - 1) :]:

@@ -337,27 +344,54 @@ class GenerateMidiText:
         # for the bar to prolong
         # initialization e.g TRACK_START INST=DRUMS DENSITY=2
         processed_prompt = track["bars"][0]
+        if verbose:
+            print(
+                f"Adding bars - {len(track['bars'][-(self.model_n_bar - 1) :])} selected from MAIN track: {track_idx} for prompt"
+            )
         for bar in track["bars"][-(self.model_n_bar - 1) :]:
             # adding the "last" bars of the track
             processed_prompt += bar
 
         processed_prompt += "BAR_START "
+
+        # making the preprompt short enought to avoid bug due to length of the prompt (model limitation)
+        pre_promt = self.force_prompt_length(pre_promt, 1500)
+
         print(
             f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
         )
+
         return pre_promt + processed_prompt
 
-    def
+    def force_prompt_length(self, prompt, expected_length):
+        """remove one instrument/track from the prompt it too long
+        Args:
+            prompt (str): the prompt to be processed
+            expected_length (int): the expected length of the prompt
+        Returns:
+            the truncated prompt"""
+        if len(prompt.split(" ")) < expected_length:
+            truncated_prompt = prompt
+        else:
+            tracks = self.get_tracks_from_a_piece(prompt)
+            selected_tracks = random.sample(tracks, len(tracks) - 1)
+            truncated_prompt = self.get_piece_from_track_list(selected_tracks)
+            print(f"Prompt too long - deleting one track")
+
+        return truncated_prompt
+
+    def generate_one_more_bar(self, track_index):
         """Generate one more bar from the input_prompt"""
-        processed_prompt = self.process_prompt_for_next_bar(
+        processed_prompt = self.process_prompt_for_next_bar(track_index)
+
         prompt_plus_bar = self.generate_until_track_end(
             input_prompt=processed_prompt,
-            temperature=self.piece_by_track[
+            temperature=self.piece_by_track[track_index]["temperature"],
             expected_length=1,
             verbose=False,
         )
         added_bar = self.get_newly_generated_bar(prompt_plus_bar)
-        self.update_track_dict__add_bars(added_bar,
+        self.update_track_dict__add_bars(added_bar, track_index)
 
     def get_newly_generated_bar(self, prompt_plus_bar):
         return "BAR_START " + self.striping_track_ends(
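The force_prompt_length method added above keeps the context prompt under a fixed token budget (1,500 whitespace-separated tokens here): if the prompt is too long, one randomly chosen track is dropped and the piece is rebuilt from the survivors. A self-contained sketch of that behaviour on plain strings; the splitting is inlined rather than calling the class helpers, so treat the details as an approximation:

import random

def force_prompt_length(prompt, expected_length=1500):
    """Drop one randomly chosen track if the prompt exceeds expected_length tokens."""
    if len(prompt.split(" ")) < expected_length:
        return prompt
    chunks = prompt.split("TRACK_START ")[1:]        # per-track tails
    tracks = ["TRACK_START " + c for c in chunks]    # re-add the marker
    kept = random.sample(tracks, len(tracks) - 1)    # drop one track at random
    print("Prompt too long - deleting one track")
    return "PIECE_START " + "".join(kept)

drums = "TRACK_START INST=DRUMS " + "NOTE_ON=36 " * 900 + "TRACK_END "
bass = "TRACK_START INST=BASS " + "NOTE_ON=40 " * 900 + "TRACK_END "
long_prompt = "PIECE_START " + drums + bass
short_prompt = force_prompt_length(long_prompt)
print(len(long_prompt.split(" ")), "->", len(short_prompt.split(" ")))

Note that random.sample both removes a track and may reorder the remaining ones, and that in the diff this truncation is applied only to pre_promt (the side-track context) before the prompt for the track being extended is appended.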
@@ -380,7 +414,6 @@ class GenerateMidiText:
         self.check_the_piece_for_errors()
 
     def check_the_piece_for_errors(self, piece: str = None):
-
         if piece is None:
             piece = self.get_whole_piece_from_bar_dict()
         errors = []
generation_utils.py CHANGED

@@ -2,6 +2,7 @@ import os
 import numpy as np
 import matplotlib.pyplot as plt
 import matplotlib
+from utils import writeToFile, get_datetime
 
 from constants import INSTRUMENT_CLASSES
 from playback import get_music, show_piano_roll

@@ -14,11 +15,38 @@ matplotlib.rcParams["axes.facecolor"] = "none"
 matplotlib.rcParams["axes.edgecolor"] = "grey"
 
 
-
-
-
-
-
+class WriteTextMidiToFile:  # utils saving miditext from teh class GenerateMidiText to file
+    def __init__(self, generate_midi, output_path):
+        self.generated_midi = generate_midi.generated_piece
+        self.output_path = output_path
+        self.hyperparameter_and_bars = generate_midi.piece_by_track
+
+    def hashing_seq(self):
+        self.current_time = get_datetime()
+        self.output_path_filename = f"{self.output_path}/{self.current_time}.json"
+
+    def wrapping_seq_hyperparameters_in_dict(self):
+        # assert type(self.generated_midi) is str, "error: generate_midi must be a string"
+        # assert (
+        #     type(self.hyperparameter_dict) is dict
+        # ), "error: feature_dict must be a dictionnary"
+        return {
+            "generated_midi": self.generated_midi,
+            "hyperparameters_and_bars": self.hyperparameter_and_bars,
+        }
+
+    def text_midi_to_file(self):
+        self.hashing_seq()
+        output_dict = self.wrapping_seq_hyperparameters_in_dict()
+        print(f"Token generate_midi written: {self.output_path_filename}")
+        writeToFile(self.output_path_filename, output_dict)
+        return self.output_path_filename
+
+
+def define_generation_dir(generation_dir):
+    if not os.path.exists(generation_dir):
+        os.makedirs(generation_dir)
+    return generation_dir
 
 
 def bar_count_check(sequence, n_bars):

@@ -64,7 +92,8 @@ def check_if_prompt_density_in_tokenizer_vocab(tokenizer, density_prompt_list):
 
 def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
     """Forcing the generated sequence to have the expected length
-    expected_length and bar_count refers to the length of newly_generated_only (without input prompt)
+    expected_length and bar_count refers to the length of newly_generated_only (without input prompt)
+    """
 
     if bar_count - expected_length > 0:  # Cut the sequence if too long
         full_piece = ""
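The new WriteTextMidiToFile class bundles the generated token string together with the per-track settings and bars, and dumps them to a timestamped JSON file using writeToFile and get_datetime from utils. A hedged usage sketch of that flow; the generator object and the two utils helpers are stubbed here since their implementations are not part of this diff, and the output directory name is made up:

import json
import os
from datetime import datetime
from types import SimpleNamespace

# Stand-ins for utils.get_datetime / utils.writeToFile (real signatures assumed).
def get_datetime():
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

def writeToFile(path, payload):
    with open(path, "w") as f:
        json.dump(payload, f)

# Stand-in for a GenerateMidiText instance with the two attributes the class reads.
generate_midi = SimpleNamespace(
    generated_piece="PIECE_START TRACK_START INST=DRUMS ... TRACK_END ",
    piece_by_track=[{"temperature": 0.75, "bars": []}],
)

output_path = "generated"  # hypothetical directory, cf. define_generation_dir above
os.makedirs(output_path, exist_ok=True)

# Mirrors WriteTextMidiToFile.text_midi_to_file(): timestamped filename, then dump
# the piece and its hyperparameters/bars as one JSON document.
output_path_filename = f"{output_path}/{get_datetime()}.json"
writeToFile(
    output_path_filename,
    {
        "generated_midi": generate_midi.generated_piece,
        "hyperparameters_and_bars": generate_midi.piece_by_track,
    },
)
print(f"Token generate_midi written: {output_path_filename}")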