File size: 7,328 Bytes
b6bb234
2a68ddd
3fea55e
 
 
 
b6bb234
2a68ddd
b6bb234
037c7a6
2a68ddd
498b808
2a68ddd
 
 
3908eea
2a68ddd
 
b6bb234
a3e0baa
 
037c7a6
fbc5bf3
 
7b0fbfe
 
606e959
 
6c71f04
606e959
 
 
fbc5bf3
6c71f04
11f0dcb
b6bb234
 
 
 
 
a3e0baa
 
b6bb234
 
a3e0baa
b6bb234
498b808
 
6c71f04
a3e0baa
fbc5bf3
a3e0baa
037c7a6
 
 
 
 
 
 
 
 
ff7362d
037c7a6
 
 
 
 
 
399d36d
037c7a6
 
 
 
 
ff7362d
399d36d
037c7a6
 
 
 
 
 
fbc5bf3
de73f36
 
 
 
 
 
 
 
 
 
 
 
fbc5bf3
 
3fea55e
606e959
fbc5bf3
3fea55e
fbc5bf3
a3e0baa
11f0dcb
de73f36
95675e7
a3e0baa
fbc5bf3
606e959
de73f36
a3e0baa
fbc5bf3
a3e0baa
 
de73f36
a3e0baa
fbc5bf3
a3e0baa
 
f16b04e
 
fbc5bf3
f16b04e
 
 
 
 
fbc5bf3
a3e0baa
6fade3e
fbc5bf3
a3e0baa
6fade3e
 
 
 
 
 
fbc5bf3
 
 
 
 
 
b6bb234
fbc5bf3
6aeacba
86b7652
6aeacba
 
 
 
 
606e959
de73f36
7b0fbfe
 
6c71f04
606e959
 
2ff429b
2a68ddd
a3e0baa
 
 
6c71f04
 
 
606e959
6c71f04
 
 
2a68ddd
fbc5bf3
b6bb234
fbc5bf3
2a68ddd
fbc5bf3
2a68ddd
a3e0baa
 
 
fb9c37a
a3e0baa
 
 
 
aa762cf
a3e0baa
 
2a68ddd
aa762cf
bfd9f77
86b7652
aa762cf
2a68ddd
25f28a8
2a68ddd
25f28a8
2a68ddd
fbc5bf3
 
6aeacba
037c7a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os.path

import time
import datetime
from pytz import timezone

import torch

import gradio as gr
import spaces

from x_transformer import *
import tqdm

import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio

import matplotlib.pyplot as plt

# =================================================================================================

def _load_model(seq_len=2048):
    """Build the Allegro transformer, load its checkpoint on CUDA and return it.

    The returned object is the ``torch.nn.DataParallel`` wrapper, switched to
    eval mode; the ``AutoregressiveWrapper`` (which provides ``generate``) is
    reachable as ``model.module``.
    """
    print('Loading model...')

    # instantiate the model
    model = TransformerWrapper(
        num_tokens=3088,
        max_seq_len=seq_len,
        attn_layers=Decoder(dim=1024, depth=32, heads=8, attn_flash=True)
    )

    model = AutoregressiveWrapper(model)

    # The state dict is loaded into the DataParallel wrapper, so the checkpoint
    # keys presumably carry the 'module.' prefix — keep this wrapping.
    model = torch.nn.DataParallel(model)

    model.cuda()
    print('=' * 70)

    print('Loading model checkpoint...')

    model.load_state_dict(
        torch.load('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.pth',
                   map_location='cuda'))
    print('=' * 70)

    model.eval()

    print('Done!')
    print('=' * 70)

    return model


def _decode_tokens_to_score(tokens, patches):
    """Decode generated tokens into a list of TMIDIX-style 'note' events.

    Token ranges handled:
      * 1..255     -> time delta: advances the running start time by token * 8
      * 256..1279  -> sets the duration ((t - 256) // 8 * 32) and velocity
                      (((t - 256) % 8 + 1) * 15) used by following notes
      * 1280..2815 -> channel ((t - 1280) // 128) and pitch ((t - 1280) % 128);
                      emits one note with the current time, duration, velocity
    Every other token (0 and >= 2816) is ignored.

    Returns:
        list of ['note', start_time, duration, channel, pitch, velocity, patch].
    """
    score = []
    ctime = 0
    dur = 1
    vel = 90

    for tok in tokens:
        # The three ranges are disjoint, so an elif chain is equivalent.
        if 0 < tok < 256:
            ctime += tok * 8
        elif 256 <= tok < 1280:
            dur = ((tok - 256) // 8) * 32
            vel = (((tok - 256) % 8) + 1) * 15
        elif 1280 <= tok < 2816:
            channel, pitch = divmod(tok - 1280, 128)
            # Channel 9 gets patch 128 (the drum marker used by this app);
            # other channels take their patch from the table.
            pat = 128 if channel == 9 else patches[channel]
            score.append(['note', ctime, dur, channel, pitch, vel, pat])

    return score


@spaces.GPU
def GenerateMIDI(num_tok, idrums, iinstr, input_align):
    """Generate an improvised composition and render it for the Gradio UI.

    Args:
        num_tok: number of tokens to generate; converted to int and clamped
            to the range 1..1024.
        idrums: truthy selects start token 3074 (with drums), otherwise 3073.
        iinstr: lead instrument name; must be one of ``instruments_list``
            below (``list.index`` raises ValueError otherwise).
        input_align: "Start Times", "Start Times and Durations" or
            "Start Times and Split Durations" align the notes to bars via
            TMIDIX; any other value (e.g. "Do not align") leaves them as-is.

    Returns:
        (plot, MIDI file name, (16000, audio)) matching the Plot, File and
        Audio outputs wired up in the UI.
    """
    # Built locally rather than read from the module-level PDT, which only
    # exists when this file is executed as a script.
    pdt = timezone('US/Pacific')

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(pdt)))
    start_time = time.time()

    print('-' * 70)
    print('Req num tok:', num_tok)
    print('Req instr:', iinstr)
    print('Drums:', idrums)
    print('Align:', input_align)
    print('-' * 70)

    drums = 3074 if idrums else 3073

    instruments_list = ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", 'Drums',
                        "Choir", "Organ"]
    first_note_instrument_number = instruments_list.index(iinstr)

    start_tokens = [3087, drums, 3075 + first_note_instrument_number]

    print('Selected Improv sequence:')
    print(start_tokens)
    print('-' * 70)

    model = _load_model(seq_len=2048)

    print('Generating...')

    # UI sliders may deliver a float; the generator needs an integer length.
    gen_len = max(1, min(1024, int(num_tok)))

    inp = torch.LongTensor([start_tokens]).cuda()

    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
        with torch.inference_mode():
            out = model.module.generate(inp,
                                        gen_len,
                                        temperature=0.9,
                                        return_prime=False,
                                        verbose=False)

    out0 = out[0].tolist()

    # General MIDI patch per channel index; index 9 is the drum channel and
    # is remapped to patch 128 while decoding.
    patches = [0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0]

    output = _decode_tokens_to_score(out0, patches)

    if input_align == "Start Times":
        output = TMIDIX.align_escore_notes_to_bars(output)

    elif input_align == "Start Times and Durations":
        output = TMIDIX.align_escore_notes_to_bars(output, trim_durations=True)

    elif input_align == "Start Times and Split Durations":
        output = TMIDIX.align_escore_notes_to_bars(output, split_durations=True)

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output,
                                                              output_signature = 'Allegro Music Transformer',
                                                              output_file_name = 'Allegro-Music-Transformer-Composition',
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )

    output_plot = TMIDIX.plot_ms_SONG(output, plot_title='Allegro-Music-Transformer-Composition', return_plt=True)

    audio = midi_to_colab_audio('Allegro-Music-Transformer-Composition.mid',
                        soundfont_path="SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2",
                        sample_rate=16000,
                        volume_scale=10,
                        output_for_gradio=True
                        )

    # Slice instead of indexing: short generations can decode fewer than three
    # notes, and output[2] would raise IndexError after all the work is done.
    print('First generated MIDI events', output[:3])
    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(pdt)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')

    return output_plot, "Allegro-Music-Transformer-Composition.mid", (16000, audio)

# =================================================================================================

if __name__ == "__main__":

    # Pacific timezone used for start-time logging; kept at module level so
    # other code in this file can read it as a global.
    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    app = gr.Blocks()

    # Build the whole layout and event wiring inside the Blocks context...
    with app:

        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>")

        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
            "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
            "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
            "Special thanks go out to [SkyTNT](https://github.com/SkyTNT/midi-model) for fantastic FluidSynth Synthesizer and MIDI Visualizer code\n\n"
            "[Open In Colab]"
            "(https://colab.research.google.com/github/asigalov61/Allegro-Music-Transformer/blob/main/Allegro_Music_Transformer_Composer.ipynb)"
            " for faster execution and endless generation"
        )

        # 'Drums' is deliberately not offered as a lead instrument here;
        # drums are controlled by the separate checkbox.
        input_instrument = gr.Radio(
            ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", "Choir", "Organ"],
            value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
        input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
        input_align = gr.Radio(["Do not align", "Start Times", "Start Times and Durations", "Start Times and Split Durations"], label="Align output to bars", value="Do not align")
        input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")

        run_btn = gr.Button("generate", variant="primary")

        output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
        output_plot = gr.Plot(label='output plot')
        output_midi = gr.File(label="output midi", file_types=[".mid"])

        # Input order must match GenerateMIDI(num_tok, idrums, iinstr, input_align);
        # output order must match its (plot, midi, audio) return tuple.
        run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument, input_align],
                                  [output_plot, output_midi, output_audio])

    # ...then launch once the context is closed (conventional Gradio pattern).
    app.queue().launch()