sander-wood commited on
Commit
b287f62
1 Parent(s): bd048a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +362 -85
app.py CHANGED
@@ -2,10 +2,22 @@ import gradio as gr
2
  import torch
3
  import random
4
  from unidecode import unidecode
5
- from transformers import GPT2LMHeadModel
6
- from samplings import top_p_sampling, temperature_sampling
 
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  description = """
11
  <div>
@@ -45,114 +57,379 @@ The first control code `[SECS_3]` specifies there are 3 sections in the tune, an
45
 
46
  """
47
 
48
-
49
- class ABCTokenizer():
 
 
50
  def __init__(self):
 
 
51
  self.pad_token_id = 0
52
- self.bos_token_id = 2
53
- self.eos_token_id = 3
54
- self.merged_tokens = []
55
-
56
- for i in range(8):
57
- self.merged_tokens.append('[SECS_'+str(i+1)+']')
58
- for i in range(32):
59
- self.merged_tokens.append('[BARS_'+str(i+1)+']')
60
- for i in range(11):
61
- self.merged_tokens.append('[SIM_'+str(i)+']')
62
-
63
- def __len__(self):
64
- return 128+len(self.merged_tokens)
 
 
 
 
 
 
 
 
 
 
65
 
66
- def encode(self, text):
67
- encodings = {}
68
- encodings['input_ids'] = torch.tensor(self.txt2ids(text, self.merged_tokens))
69
- encodings['attention_mask'] = torch.tensor([1]*len(encodings['input_ids']))
70
- return encodings
71
-
72
- def decode(self, ids, skip_special_tokens=False):
73
- txt = ""
74
- for i in ids:
75
- if i>=128:
76
- if not skip_special_tokens:
77
- txt += self.merged_tokens[i-128]
78
- elif i!=self.bos_token_id and i!=self.eos_token_id:
79
- txt += chr(i)
80
- return txt
81
-
82
- def txt2ids(self, text, merged_tokens):
83
- ids = ["\""+str(ord(c))+"\"" for c in text]
84
- txt_ids = ' '.join(ids)
85
- for t_idx, token in enumerate(merged_tokens):
86
- token_ids = ["\""+str(ord(c))+"\"" for c in token]
87
- token_txt_ids = ' '.join(token_ids)
88
- txt_ids = txt_ids.replace(token_txt_ids, "\""+str(t_idx+128)+"\"")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- txt_ids = txt_ids.split(' ')
91
- txt_ids = [int(i[1:-1]) for i in txt_ids]
92
- return [self.bos_token_id]+txt_ids+[self.eos_token_id]
 
 
93
 
94
- def generate_abc(control_codes, prefix, num_tunes, max_length, top_p, temperature, seed):
95
 
96
- try:
97
- seed = int(seed)
98
- except:
99
- seed = None
 
 
100
 
101
- prefix = unidecode(control_codes + prefix)
102
- tokenizer = ABCTokenizer()
103
- model = GPT2LMHeadModel.from_pretrained('sander-wood/tunesformer').to(device)
 
 
 
 
 
 
104
 
105
- if prefix:
106
- ids = tokenizer.encode(prefix)['input_ids'][:-1]
107
- else:
108
- ids = torch.tensor([tokenizer.bos_token_id])
 
 
109
 
110
- random.seed(seed)
111
- tunes = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- for c_idx in range(num_tunes):
114
- print("\nX:"+str(c_idx+1)+"\n", end="")
115
- print(tokenizer.decode(ids[1:], skip_special_tokens=True), end="")
116
- input_ids = ids.unsqueeze(0)
117
- for t_idx in range(max_length):
118
  if seed!=None:
119
  n_seed = random.randint(0, 1000000)
120
  random.seed(n_seed)
121
  else:
122
  n_seed = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- outputs = model(input_ids=input_ids.to(device))
125
- probs = outputs.logits[0][-1]
126
- probs = torch.nn.Softmax(dim=-1)(probs).cpu().detach().numpy()
127
- sampled_id = temperature_sampling(probs=top_p_sampling(probs,
128
- top_p=top_p,
129
- seed=n_seed,
130
- return_probs=True),
131
- seed=n_seed,
132
- temperature=temperature)
133
- input_ids = torch.cat((input_ids, torch.tensor([[sampled_id]])), 1)
134
- if sampled_id!=tokenizer.eos_token_id:
135
- print(tokenizer.decode([sampled_id], skip_special_tokens=True), end="")
136
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  else:
138
- tune = "X:"+str(c_idx+1)+"\n"+tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True)
139
- tunes += tune+"\n\n"
140
- print("\n")
141
  break
 
 
142
 
143
  return tunes
144
 
145
- input_control_codes = gr.inputs.Textbox(lines=5, label="Control Codes", default="[SECS_2][BARS_9][SIM_3][BARS_9]")
146
- input_prefix = gr.inputs.Textbox(lines=5, label="Prefix", default="L:1/8\nQ:1/4=114\nM:3/4\nK:D\nde | \"D\"")
 
 
 
 
 
 
 
 
147
  input_num_tunes = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Tunes")
148
- input_max_length = gr.inputs.Slider(minimum=10, maximum=1000, step=10, default=500, label="Max Length")
149
- input_top_p = gr.inputs.Slider(minimum=0.0, maximum=1.0, step=0.05, default=0.9, label="Top P")
150
- input_temperature = gr.inputs.Slider(minimum=0.0, maximum=2.0, step=0.1, default=1.0, label="Temperature")
 
151
  input_seed = gr.inputs.Textbox(lines=1, label="Seed (int)", default="None")
152
  output_abc = gr.outputs.Textbox(label="Generated Tunes")
153
 
154
  gr.Interface(generate_abc,
155
- [input_control_codes, input_prefix, input_num_tunes, input_max_length, input_top_p, input_temperature, input_seed],
156
  output_abc,
157
- title="TunesFormer: Forming Tunes with Control Codes",
158
  description=description).launch()
 
2
  import torch
3
  import random
4
  from unidecode import unidecode
5
+ import re
6
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
7
+ from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ PATCH_LENGTH = 128 # Patch Length
11
+ PATCH_SIZE = 32 # Patch Size
12
+
13
+ PATCH_NUM_LAYERS = 9 # Number of layers in the encoder
14
+ CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
15
+
16
+ NUM_EPOCHS = 32 # Number of epochs to train for (if early stopping doesn't intervene)
17
+ LEARNING_RATE = 5e-5 # Learning rate for the optimizer
18
+ PATCH_SAMPLING_BATCH_SIZE = 0 # Batch size for patch during training, 0 for full context
19
+ LOAD_FROM_CHECKPOINT = False # Whether to load weights from a checkpoint
20
+ SHARE_WEIGHTS = False # Whether to share weights between the encoder and decoder
21
 
22
  description = """
23
  <div>
 
57
 
58
  """
59
 
60
+ class Patchilizer:
61
+ """
62
+ A class for converting music bars to patches and vice versa.
63
+ """
64
  def __init__(self):
65
+ self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
66
+ self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
67
  self.pad_token_id = 0
68
+ self.bos_token_id = 1
69
+ self.eos_token_id = 2
70
+
71
+ def split_bars(self, body):
72
+ """
73
+ Split a body of music into individual bars.
74
+ """
75
+ bars = re.split(self.regexPattern, ''.join(body))
76
+ bars = list(filter(None, bars)) # remove empty strings
77
+ if bars[0] in self.delimiters:
78
+ bars[1] = bars[0] + bars[1]
79
+ bars = bars[1:]
80
+ bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
81
+ return bars
82
+
83
+ def bar2patch(self, bar, patch_size=PATCH_SIZE):
84
+ """
85
+ Convert a bar into a patch of specified length.
86
+ """
87
+ patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
88
+ patch = patch[:patch_size]
89
+ patch += [self.pad_token_id] * (patch_size - len(patch))
90
+ return patch
91
 
92
+ def patch2bar(self, patch):
93
+ """
94
+ Convert a patch into a bar.
95
+ """
96
+ return ''.join(chr(idx) if idx > self.eos_token_id else '' for idx in patch if idx != self.eos_token_id)
97
+
98
+ def encode(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=False):
99
+ """
100
+ Encode music into patches of specified length.
101
+ """
102
+ lines = unidecode(abc_code).split('\n')
103
+ lines = list(filter(None, lines)) # remove empty lines
104
+
105
+ body = ""
106
+ patches = []
107
+
108
+ for line in lines:
109
+ if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')):
110
+ if body:
111
+ bars = self.split_bars(body)
112
+ patches.extend(self.bar2patch(bar + '\n' if idx == len(bars) - 1 else bar, patch_size)
113
+ for idx, bar in enumerate(bars))
114
+ body = ""
115
+ patches.append(self.bar2patch(line + '\n', patch_size))
116
+ else:
117
+ body += line + '\n'
118
+
119
+ if body:
120
+ patches.extend(self.bar2patch(bar, patch_size) for bar in self.split_bars(body))
121
+
122
+ if add_special_patches:
123
+ bos_patch = [self.bos_token_id] * (patch_size-1) + [self.eos_token_id]
124
+ eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size-1)
125
+ patches = [bos_patch] + patches + [eos_patch]
126
+
127
+ return patches[:patch_length]
128
+
129
+ def decode(self, patches):
130
+ """
131
+ Decode patches into music.
132
+ """
133
+ return ''.join(self.patch2bar(patch) for patch in patches)
134
+
135
+ class PatchLevelDecoder(PreTrainedModel):
136
+ """
137
+ An Patch-level Decoder model for generating patch features in an auto-regressive manner.
138
+ It inherits PreTrainedModel from transformers.
139
+ """
140
+
141
+ def __init__(self, config):
142
+ super().__init__(config)
143
+ self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
144
+ torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
145
+ self.base = GPT2Model(config)
146
+
147
+ def forward(self, patches: torch.Tensor) -> torch.Tensor:
148
+ """
149
+ The forward pass of the patch-level decoder model.
150
+ :param patches: the patches to be encoded
151
+ :return: the encoded patches
152
+ """
153
+ patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
154
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
155
+ patches = self.patch_embedding(patches.to(self.device))
156
+
157
+ return self.base(inputs_embeds=patches)
158
+
159
+ class CharLevelDecoder(PreTrainedModel):
160
+ """
161
+ A Char-level Decoder model for generating the characters within each bar patch sequentially.
162
+ It inherits PreTrainedModel from transformers.
163
+ """
164
+ def __init__(self, config):
165
+ super().__init__(config)
166
+ self.pad_token_id = 0
167
+ self.bos_token_id = 1
168
+ self.eos_token_id = 2
169
+ self.base = GPT2LMHeadModel(config)
170
+
171
+ def forward(self, encoded_patches: torch.Tensor, target_patches: torch.Tensor, patch_sampling_batch_size: int):
172
+ """
173
+ The forward pass of the char-level decoder model.
174
+ :param encoded_patches: the encoded patches
175
+ :param target_patches: the target patches
176
+ :return: the decoded patches
177
+ """
178
+ # preparing the labels for model training
179
+ target_masks = target_patches == self.pad_token_id
180
+ labels = target_patches.clone().masked_fill_(target_masks, -100)
181
+
182
+ # masking the labels for model training
183
+ target_masks = torch.ones_like(labels)
184
+ target_masks = target_masks.masked_fill_(labels == -100, 0)
185
+
186
+ # select patches
187
+ if patch_sampling_batch_size!=0 and patch_sampling_batch_size<target_patches.shape[0]:
188
+ indices = list(range(len(target_patches)))
189
+ random.shuffle(indices)
190
+ selected_indices = sorted(indices[:patch_sampling_batch_size])
191
+
192
+ target_patches = target_patches[selected_indices,:]
193
+ target_masks = target_masks[selected_indices,:]
194
+ encoded_patches = encoded_patches[selected_indices,:]
195
+ labels = labels[selected_indices,:]
196
+
197
+ # get input embeddings
198
+ inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
199
+
200
+ # concatenate the encoded patches with the input embeddings
201
+ inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)
202
+
203
+ return self.base(inputs_embeds=inputs_embeds,
204
+ attention_mask=target_masks,
205
+ labels=labels)
206
+
207
+ def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
208
+ """
209
+ The generate function for generating a patch based on the encoded patch and already generated tokens.
210
+ :param encoded_patch: the encoded patch
211
+ :param tokens: already generated tokens in the patch
212
+ :return: the probability distribution of next token
213
+ """
214
+ encoded_patch = encoded_patch.reshape(1, 1, -1)
215
+ tokens = tokens.reshape(1, -1)
216
+
217
+ # Get input embeddings
218
+ tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
219
+
220
+ # Concatenate the encoded patch with the input embeddings
221
+ tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)
222
 
223
+ # Get output from model
224
+ outputs = self.base(inputs_embeds=tokens)
225
+
226
+ # Get probabilities of next token
227
+ probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
228
 
229
+ return probs
230
 
231
+ class TunesFormer(PreTrainedModel):
232
+ """
233
+ TunesFormer is a hierarchical music generation model based on bar patching.
234
+ It includes a patch-level decoder and a character-level decoder.
235
+ It inherits PreTrainedModel from transformers.
236
+ """
237
 
238
+ def __init__(self, encoder_config, decoder_config, share_weights=False):
239
+ super().__init__(encoder_config)
240
+ self.pad_token_id = 0
241
+ self.bos_token_id = 1
242
+ self.eos_token_id = 2
243
+ if share_weights:
244
+ max_layers = max(encoder_config.num_hidden_layers, decoder_config.num_hidden_layers)
245
+ max_context_size = max(encoder_config.max_length, decoder_config.max_length)
246
+ max_position_embeddings = max(encoder_config.max_position_embeddings, decoder_config.max_position_embeddings)
247
 
248
+ encoder_config.num_hidden_layers = max_layers
249
+ encoder_config.max_length = max_context_size
250
+ encoder_config.max_position_embeddings = max_position_embeddings
251
+ decoder_config.num_hidden_layers = max_layers
252
+ decoder_config.max_length = max_context_size
253
+ decoder_config.max_position_embeddings = max_position_embeddings
254
 
255
+ self.patch_level_decoder = PatchLevelDecoder(encoder_config)
256
+ self.char_level_decoder = CharLevelDecoder(decoder_config)
257
+
258
+ if share_weights:
259
+ self.patch_level_decoder.base = self.char_level_decoder.base.transformer
260
+
261
+ def forward(self, patches: torch.Tensor, patch_sampling_batch_size: int=PATCH_SAMPLING_BATCH_SIZE):
262
+ """
263
+ The forward pass of the TunesFormer model.
264
+ :param patches: the patches to be both encoded and decoded
265
+ :return: the decoded patches
266
+ """
267
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE)
268
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
269
+
270
+ return self.char_level_decoder(encoded_patches.squeeze(0)[:-1, :], patches.squeeze(0)[1:, :], patch_sampling_batch_size)
271
+
272
+ def generate(self, patches: torch.Tensor,
273
+ tokens: torch.Tensor,
274
+ top_p: float=1,
275
+ top_k: int=0,
276
+ temperature: float=1,
277
+ seed: int=None):
278
+ """
279
+ The generate function for generating patches based on patches.
280
+ :param patches: the patches to be encoded
281
+ :return: the generated patches
282
+ """
283
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE)
284
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
285
+ if tokens==None:
286
+ tokens = torch.tensor([self.bos_token_id], device=self.device)
287
+ generated_patch = []
288
+ random.seed(seed)
289
 
290
+ while True:
 
 
 
 
291
  if seed!=None:
292
  n_seed = random.randint(0, 1000000)
293
  random.seed(n_seed)
294
  else:
295
  n_seed = None
296
+ prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()
297
+ prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
298
+ prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
299
+ token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
300
+ generated_patch.append(token)
301
+ if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
302
+ break
303
+ else:
304
+ tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)
305
+
306
+ return generated_patch, n_seed
307
+
308
+ def generate_abc(prompt, num_tunes, max_patch, top_p, top_k, temperature, seed, show_control_code):
309
+
310
+ if torch.cuda.is_available():
311
+ device = torch.device("cuda")
312
+ else:
313
+ device = torch.device("cpu")
314
+
315
+ patchilizer = Patchilizer()
316
 
317
+ patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
318
+ max_length=PATCH_LENGTH,
319
+ max_position_embeddings=PATCH_LENGTH,
320
+ vocab_size=1)
321
+ char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
322
+ max_length=PATCH_SIZE,
323
+ max_position_embeddings=PATCH_SIZE,
324
+ vocab_size=128)
325
+ model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
326
+
327
+ filename = "weights.pth"
328
+
329
+ if os.path.exists(filename):
330
+ print(f"Weights already exist at '{filename}'. Loading...")
331
+ else:
332
+ print(f"Downloading weights to '{filename}' from huggingface.co...")
333
+ try:
334
+ url = 'https://huggingface.co/sander-wood/tunesformer/resolve/main/weights.pth'
335
+ response = requests.get(url, stream=True)
336
+
337
+ total_size = int(response.headers.get('content-length', 0))
338
+ chunk_size = 1024
339
+
340
+ with open(filename, 'wb') as file, tqdm(
341
+ desc=filename,
342
+ total=total_size,
343
+ unit='B',
344
+ unit_scale=True,
345
+ unit_divisor=1024,
346
+ ) as bar:
347
+ for data in response.iter_content(chunk_size=chunk_size):
348
+ size = file.write(data)
349
+ bar.update(size)
350
+ except Exception as e:
351
+ print(f"Error: {e}")
352
+ exit()
353
+
354
+ checkpoint = torch.load('weights.pth')
355
+ model.load_state_dict(checkpoint['model'])
356
+ model = model.to(device)
357
+ model.eval()
358
+
359
+ tunes = ""
360
+
361
+ print("\n"+" OUTPUT TUNES ".center(60, "#"))
362
+
363
+ start_time = time.time()
364
+
365
+ for i in range(num_tunes):
366
+ tune = "X:"+str(i+1) + "\n" + prompt
367
+ lines = re.split(r'(\n)', tune)
368
+ tune = ""
369
+ skip = False
370
+ for line in lines:
371
+ if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
372
+ if not skip:
373
+ print(line, end="")
374
+ tune += line
375
+ skip = False
376
+ else:
377
+ skip = True
378
+
379
+ input_patches = torch.tensor([patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device)
380
+ if tune=="":
381
+ tokens = None
382
+ else:
383
+ prefix = patchilizer.decode(input_patches[0])
384
+ remaining_tokens = prompt[len(prefix):]
385
+ tokens = torch.tensor([patchilizer.bos_token_id]+[ord(c) for c in remaining_tokens], device=device)
386
+
387
+ while input_patches.shape[1]<max_patch:
388
+ predicted_patch, seed = model.generate(input_patches,
389
+ tokens,
390
+ top_p=top_p,
391
+ top_k=top_k,
392
+ temperature=temperature,
393
+ seed=seed)
394
+ tokens = None
395
+ if predicted_patch[0]!=patchilizer.eos_token_id:
396
+ next_bar = patchilizer.decode([predicted_patch])
397
+ if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
398
+ print(next_bar, end="")
399
+ tune += next_bar
400
+ if next_bar=="":
401
+ break
402
+ next_bar = remaining_tokens+next_bar
403
+ remaining_tokens = ""
404
+ predicted_patch = torch.tensor(patchilizer.bar2patch(next_bar), device=device).unsqueeze(0)
405
+ input_patches = torch.cat([input_patches, predicted_patch.unsqueeze(0)], dim=1)
406
  else:
 
 
 
407
  break
408
+
409
+ tunes += tune+"\n\n"
410
 
411
  return tunes
412
 
413
+ default_prompt = """S:2
414
+ B:9
415
+ E:4
416
+ B:9
417
+ L:1/8
418
+ M:3/4
419
+ K:D
420
+ de |"D" """
421
+
422
+ input_prompt = gr.inputs.Textbox(lines=5, label="Prefix", default=default_prompt)
423
  input_num_tunes = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Tunes")
424
+ input_max_patch = gr.inputs.Slider(minimum=10, maximum=128, step=1, default=128, label="Max Patch")
425
+ input_top_p = gr.inputs.Slider(minimum=0.0, maximum=1.0, step=0.05, default=0.8, label="Top P")
426
+ input_top_k = gr.inputs.Slider(minimum=1, maximum=20, step=1, default=8, label="Top K")
427
+ input_temperature = gr.inputs.Slider(minimum=0.0, maximum=2.0, step=0.05, default=1.2, label="Temperature")
428
  input_seed = gr.inputs.Textbox(lines=1, label="Seed (int)", default="None")
429
  output_abc = gr.outputs.Textbox(label="Generated Tunes")
430
 
431
  gr.Interface(generate_abc,
432
+ [input_prompt, input_num_tunes, input_max_patch, input_top_p, input_top_k, input_temperature, input_seed],
433
  output_abc,
434
+ title="TunesFormer: Forming Irish Tunes with Control Codes by Bar Patching",
435
  description=description).launch()