File size: 7,864 Bytes
b6bb234
2a68ddd
3fea55e
 
 
 
b6bb234
2a68ddd
 
b6bb234
037c7a6
2a68ddd
498b808
2a68ddd
 
 
3908eea
2a68ddd
 
b6bb234
 
 
a3e0baa
 
037c7a6
a914076
7b0fbfe
 
606e959
 
6c71f04
606e959
 
 
6c71f04
11f0dcb
b6bb234
 
 
 
 
a3e0baa
 
b6bb234
 
a3e0baa
b6bb234
498b808
 
6c71f04
a3e0baa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
037c7a6
 
 
 
 
 
 
 
 
ff7362d
037c7a6
 
 
 
 
 
399d36d
037c7a6
 
 
 
 
ff7362d
399d36d
037c7a6
 
 
 
 
 
 
de73f36
 
 
 
 
 
 
 
 
 
 
 
3fea55e
606e959
a3e0baa
3fea55e
a3e0baa
 
11f0dcb
de73f36
95675e7
a3e0baa
606e959
de73f36
a3e0baa
 
 
de73f36
a3e0baa
 
 
f16b04e
 
 
 
 
 
 
 
a3e0baa
 
ac1c10f
a3e0baa
2ff429b
b6bb234
 
0ccc4f3
6aeacba
86b7652
6aeacba
 
 
 
 
606e959
de73f36
7b0fbfe
 
6c71f04
606e959
 
2ff429b
2a68ddd
a3e0baa
 
 
6c71f04
 
 
606e959
6c71f04
 
 
2a68ddd
b6bb234
2a68ddd
 
a3e0baa
 
 
fb9c37a
a3e0baa
 
 
 
aa762cf
a3e0baa
 
2a68ddd
aa762cf
86b7652
aa762cf
2a68ddd
25f28a8
2a68ddd
25f28a8
2a68ddd
a914076
6aeacba
037c7a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os.path

import time
import datetime
from pytz import timezone

import torch
import torch.nn.functional as F

import gradio as gr
import spaces

from x_transformer import *
import tqdm

import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio

import matplotlib.pyplot as plt

# True when the SYSTEM environment variable is exactly "spaces"; False when it
# is unset or holds any other value.
in_space = os.environ.get("SYSTEM") == "spaces"

# =================================================================================================

def _tokens_to_notes(tokens, patches):
    """Decode generated model tokens into a list of 'note' events.

    Token layout, as implemented below:
      * 1..255     -> advance the running start time by ``token * 8``
      * 256..1279  -> set duration to ``((token - 256) // 8) * 32`` and
                      velocity to ``(((token - 256) % 8) + 1) * 15``
      * 1280..2815 -> emit a note; ``(token - 1280)`` splits into
                      ``channel = // 128`` and ``pitch = % 128``
    Any other token value (including 0) is ignored.

    Args:
        tokens: iterable of int token ids produced by the model.
        patches: sequence of patch numbers indexed by channel.

    Returns:
        A list of ``['note', start_time, duration, channel, pitch, velocity,
        patch]`` lists, in generation order.
    """
    notes = []

    # Running state: a note token uses whatever time/duration/velocity the
    # preceding tokens established.
    ctime = 0
    dur = 0
    vel = 90

    for tok in tokens:
        if 0 < tok < 256:
            ctime += tok * 8

        elif 256 <= tok < 1280:
            dur = ((tok - 256) // 8) * 32
            vel = (((tok - 256) % 8) + 1) * 15

        elif 1280 <= tok < 2816:
            channel, pitch = divmod(tok - 1280, 128)

            # Channel 9 is tagged with patch value 128 instead of a table entry
            # (presumably the drum channel marker expected downstream -- TODO confirm).
            pat = 128 if channel == 9 else patches[channel]

            notes.append(['note', ctime, dur, channel, pitch, vel, pat])

    return notes


@spaces.GPU
def GenerateMIDI(num_tok, idrums, iinstr):
    """Generate a composition and return its plot, MIDI file path and audio.

    Builds a three-token prompt from the drum flag and the lead instrument,
    loads the model checkpoint, samples up to 1024 tokens, decodes them into
    note events, writes ``Allegro-Music-Transformer-Composition.mid`` and
    renders it to audio.

    Args:
        num_tok: requested number of tokens to generate. Coerced to ``int`` and
            clamped to the range 1..1024.
        idrums: truthy to select prompt token 3074, falsy for 3073.
        iinstr: lead instrument name; must be one of the entries of
            ``instruments_list`` below.

    Returns:
        A tuple ``(plot, midi_path, (16000, audio))`` in the order expected by
        the Gradio outputs wired to this function.

    Raises:
        ValueError: if ``iinstr`` is not in ``instruments_list``.
    """
    # Resolve the log timezone locally rather than reading a module global that
    # is only assigned under the ``__main__`` guard; this keeps the function
    # usable when the module is imported instead of run as a script.
    pacific = timezone('US/Pacific')

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(pacific)))
    start_time = time.time()

    print('-' * 70)
    print('Req num tok:', num_tok)
    print('Req instr:', iinstr)
    print('Drums:', idrums)
    print('-' * 70)

    # UI sliders may deliver floats; the generation length must be an integer.
    num_gen_tokens = max(1, min(1024, int(num_tok)))

    drums = 3074 if idrums else 3073

    # 'Drums' occupies index 9 so that every other instrument's index lines up
    # with its slot in list_of_MIDI_patches below.
    instruments_list = ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", 'Drums',
                        "Choir", "Organ"]
    first_note_instrument_number = instruments_list.index(iinstr)

    start_tokens = [3087, drums, 3075 + first_note_instrument_number]

    print('Selected Improv sequence:')
    print(start_tokens)
    print('-' * 70)

    output_signature = 'Allegro Music Transformer'
    track_name = 'Project Los Angeles'
    list_of_MIDI_patches = [0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0]
    number_of_ticks_per_quarter = 500
    text_encoding = 'ISO-8859-1'
    midi_path = 'Allegro-Music-Transformer-Composition.mid'

    output_header = [number_of_ticks_per_quarter,
                     [['track_name', 0, bytes(output_signature, text_encoding)]]]

    # One patch_change event per channel (0..15), followed by the track name.
    patch_list = [['patch_change', 0, channel, patch]
                  for channel, patch in enumerate(list_of_MIDI_patches)]
    patch_list.append(['track_name', 0, bytes(track_name, text_encoding)])

    # Score layout: [ticks, header_track, events_track]; notes are appended to
    # the last track further down.
    output = output_header + [patch_list]

    print('Loading model...')

    SEQ_LEN = 2048

    # instantiate the model
    # NOTE(review): the model is rebuilt and the checkpoint reloaded on every
    # request; consider hoisting if the hosting runtime allows it.

    model = TransformerWrapper(
        num_tokens=3088,
        max_seq_len=SEQ_LEN,
        attn_layers=Decoder(dim=1024, depth=32, heads=8, attn_flash=True)
    )

    model = AutoregressiveWrapper(model)

    # The checkpoint is loaded into the DataParallel wrapper, so the wrapper is
    # applied before load_state_dict and generation goes through model.module.
    model = torch.nn.DataParallel(model)

    model.cuda()
    print('=' * 70)

    print('Loading model checkpoint...')

    model.load_state_dict(
        torch.load('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.pth',
                   map_location='cuda'))
    print('=' * 70)

    model.eval()

    print('Done!')
    print('=' * 70)

    inp = torch.LongTensor([start_tokens]).cuda()

    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
        with torch.inference_mode():
            out = model.module.generate(inp,
                                        num_gen_tokens,
                                        temperature=0.9,
                                        return_prime=False,
                                        verbose=False)

    out0 = out[0].tolist()

    output[-1].extend(_tokens_to_notes(out0, list_of_MIDI_patches))

    midi_data = TMIDIX.score2midi(output, text_encoding)

    with open(midi_path, 'wb') as f:
        f.write(midi_data)

    output_plot = TMIDIX.plot_ms_SONG(output[2], plot_title='Allegro-Music-Transformer-Composition', return_plt=True)

    audio = midi_to_colab_audio(midi_path,
                                soundfont_path="SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2",
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('First generated MIDI events', output[2][:3])
    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(pacific)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')

    return output_plot, midi_path, (16000, audio)

# =================================================================================================

if __name__ == "__main__":

    # Timezone used for log timestamps. Assigned at module scope (this is not a
    # function), so other code in this module can read it as a global.
    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    # Declare the UI: header, description, input controls, button and outputs.
    with gr.Blocks() as app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
            "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
            "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
            "Special thanks go out to [SkyTNT](https://github.com/SkyTNT/midi-model) for fantastic FluidSynth Synthesizer and MIDI Visualizer code\n\n"
            "[Open In Colab]"
            "(https://colab.research.google.com/github/asigalov61/Allegro-Music-Transformer/blob/main/Allegro_Music_Transformer_Composer.ipynb)"
            " for faster execution and endless generation"
        )

        # Drums are toggled by the checkbox below, so they are not offered as a
        # lead instrument here.
        input_instrument = gr.Radio(
            ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", "Choir", "Organ"],
            value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
        input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
        input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")

        run_btn = gr.Button("generate", variant="primary")

        output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
        output_plot = gr.Plot(label='output plot')
        output_midi = gr.File(label="output midi", file_types=[".mid"])

        # Input order matches GenerateMIDI(num_tok, idrums, iinstr); output
        # order matches its (plot, midi path, audio) return tuple.
        run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument],
                      [output_plot, output_midi, output_audio])

    # Launch only after the Blocks context has closed, so the layout is fully
    # assembled before the server starts.
    app.queue().launch()