Kyio committed on
Commit
1e1e30f
1 Parent(s): 50cc4e2

Create run_inference.py

Browse files
Files changed (1) hide show
  1. run_inference.py +100 -0
run_inference.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+ import random
5
+ import argparse
6
+ from unidecode import unidecode
7
+ from samplings import top_p_sampling, temperature_sampling
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+
10
def generate_abc(args):
    """Sample tunes from the text-to-music model and save them to disk.

    Reads the prompt from ``input_text.txt``, loads the
    ``sander-wood/text-to-music`` tokenizer and seq2seq model, and generates
    ``args.num_tunes`` tunes one token at a time. At every step the model's
    next-token distribution is passed through ``top_p_sampling`` and then
    ``temperature_sampling`` to pick the next token. Tokens are echoed to
    stdout as they are produced, and all tunes are finally written to
    ``output_tunes/<timestamp>.abc``.

    Args:
        args: Parsed command-line namespace providing ``num_tunes``,
            ``max_length``, ``top_p``, ``temperature`` and ``seed``.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0), '\n')
    else:
        print('No GPU available, using the CPU instead.\n')
        device = torch.device("cpu")

    num_tunes = args.num_tunes
    max_length = args.max_length
    top_p = args.top_p
    temperature = args.temperature
    seed = args.seed
    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    # Iterate a dict view of the namespace without rebinding ``args``.
    for key, value in vars(args).items():
        print(key+': '+str(value))

    with open('input_text.txt') as f:
        text = unidecode(f.read())
    print("\n"+" INPUT TEXT ".center(60, "#"))
    print('\n'+text+'\n')

    tokenizer = AutoTokenizer.from_pretrained('sander-wood/text-to-music')
    model = AutoModelForSeq2SeqLM.from_pretrained('sander-wood/text-to-music')
    model = model.to(device)

    input_ids = tokenizer(text,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length)['input_ids'].to(device)
    decoder_start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id
    random.seed(seed)
    # Collect output fragments and join once at the end.
    pieces = []
    print(" OUTPUT TUNES ".center(60, "#"))

    # Inference only: skip autograd bookkeeping for every decoding step.
    with torch.no_grad():
        for n_idx in range(num_tunes):
            header = "X:"+str(n_idx+1)+"\n"
            print("\n"+header, end="")
            pieces.append(header)
            # Build the decoder prefix directly on the target device so it
            # does not have to be copied over on every step.
            decoder_input_ids = torch.tensor([[decoder_start_token_id]],
                                             device=device)

            for _ in range(max_length):
                if seed is not None:
                    # Derive a fresh per-step seed from the seeded RNG so a
                    # given ``seed`` reproduces the same token sequence.
                    n_seed = random.randint(0, 1000000)
                    random.seed(n_seed)
                else:
                    n_seed = None

                outputs = model(input_ids=input_ids,
                                decoder_input_ids=decoder_input_ids)
                # Distribution over the vocabulary for the next position.
                probs = torch.softmax(outputs.logits[0][-1], dim=-1)
                probs = probs.cpu().numpy()
                sampled_id = temperature_sampling(
                    probs=top_p_sampling(probs,
                                         top_p=top_p,
                                         seed=n_seed,
                                         return_probs=True),
                    seed=n_seed,
                    temperature=temperature)
                next_token = torch.tensor([[sampled_id]], device=device)
                decoder_input_ids = torch.cat((decoder_input_ids, next_token), 1)

                if sampled_id == eos_token_id:
                    pieces.append('\n')
                    break

                sampled_token = tokenizer.decode([sampled_id])
                print(sampled_token, end="")
                pieces.append(sampled_token)
            else:
                # The tune hit max_length without emitting EOS. Terminate it
                # anyway so the next "X:" header starts on its own line in
                # the saved file.
                pieces.append('\n')

    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    # The output directory may not exist on a fresh checkout.
    os.makedirs('output_tunes', exist_ok=True)
    with open('output_tunes/'+timestamp+'.abc', 'w') as f:
        f.write(unidecode(''.join(pieces)))
85
+
86
def get_args(parser):
    """Register the generation options on ``parser`` and parse the command line.

    Args:
        parser: An ``argparse.ArgumentParser`` to which the options are added.

    Returns:
        The namespace produced by ``parser.parse_args()``, with attributes
        ``num_tunes``, ``max_length``, ``top_p``, ``temperature`` and ``seed``.
    """
    # (flag, type, default, help) for each supported option.
    option_specs = (
        ('-num_tunes', int, 3,
         'the number of independently computed returned tunes'),
        ('-max_length', int, 1024,
         'integer to define the maximum length in tokens of each tune'),
        ('-top_p', float, 0.9,
         'float to define the tokens that are within the sample operation of text generation'),
        ('-temperature', float, 1.,
         'the temperature of the sampling operation'),
        ('-seed', int, None,
         'seed for randomstate'),
    )
    for flag, value_type, fallback, description in option_specs:
        parser.add_argument(flag, type=value_type, default=fallback, help=description)

    return parser.parse_args()
96
+
97
if __name__ == '__main__':
    # Script entry point: parse the command-line options, then generate tunes.
    generate_abc(get_args(argparse.ArgumentParser()))