Ryo Okada naotokui committed on
Commit
2857b6c
0 Parent(s):

Duplicate from naotokui/TR-ChatGPT

Browse files

Co-authored-by: Nao Tokui <naotokui@users.noreply.huggingface.co>

Files changed (5) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +159 -0
  4. packages.txt +1 -0
  5. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ChatGPT 707
3
+ emoji: 🌍
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.20.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: naotokui/TR-ChatGPT
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import openai
3
+ import numpy as np
4
+ import pretty_midi
5
+ import re
6
+ import numpy as np
7
+ import os
8
+ import gradio as gr
9
+
10
# Default to the server-side key from the OPENAI_API_KEY environment variable;
# os.environ.get yields None when the variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
11
+
12
# Sample data: example rhythm tables in the format the model is asked to emit.
# The first line states the time resolution, followed by a blank line and a
# Markdown table with one row per drum and one column per step.
# markdown_table_sample is sent to the model as the example assistant reply in
# get_answer(); these strings are runtime data, so their text must not change.
markdown_table_sample = """8th

| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|----|---|---|---|---|---|---|---|---|
| BD | x | | x | | | | x | |
| SD | | | | x | | | | x |
| CH | x | | x | | x | | x | |
| OH | | | | x | | | x | |
| LT | | | | | | x | | |
| MT | | x | | | x | | | |
| HT | x | | | x | | | | |
"""

# 16-step variant of the sample above.
# NOTE(review): not referenced by any code visible in this file.
markdown_table_sample2 = """16th

| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
|----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BD | x | | x | | | | x | | x | | x | | x | | x | |
| SD | | | | x | | | | x | | | x | | | | x | |
| CH | x | | x | | x | | x | | x | | x | | x | | x | |
| OH | | | | x | | | x | | | | | x | | | x | |
| LT | | | | | | x | | | | | | | | x | | |
| MT | | x | | | x | | | | | x | | | x | | | |
| HT | x | | | x | | | | | x | | | x | | | | |
"""
38
+
39
# Drum abbreviation -> MIDI note number, following the General MIDI
# percussion key map. Every abbreviation maps to a distinct note so that no
# two drums render with the same sound.
MIDI_NOTENUM = {
    "BD": 36,  # Bass Drum 1
    "SD": 38,  # Acoustic Snare
    "CH": 42,  # Closed Hi-Hat
    "HH": 44,  # Pedal Hi-Hat
    "OH": 46,  # Open Hi-Hat
    "LT": 45,  # Low Tom (was 48, which made LT identical to MT)
    "MT": 48,  # Hi-Mid Tom
    "HT": 50,  # High Tom
    "CP": 39,  # Hand Clap (was 50, which made CP identical to HT)
    "CB": 56,  # Cowbell
}

# Audio sample rate in Hz, used to trim the rendered loop and returned to
# Gradio together with the samples.
SR = 44100

# Number of generations allowed before a user-supplied API key is required.
MAX_QUERY = 5
54
+
55
def convert_table_to_audio(markdown_table, resolution=8, bpm = 120.0):
    """Render a Markdown drum-pattern table to one loop of audio samples.

    The first two lines of ``markdown_table`` (header row and separator row)
    are skipped. Every following line is split on ``|``; the first cell is
    the drum abbreviation (looked up case-insensitively in ``MIDI_NOTENUM``)
    and each remaining cell is one step: ``x`` is an accented hit
    (velocity 120), ``o`` a weak hit (velocity 65), anything else a rest.
    Rows with an unknown drum name produce no notes.

    Args:
        markdown_table: The table text, rows separated by newlines.
        resolution: Note value of one step (8 = eighth notes, 16 = sixteenth).
        bpm: Tempo in quarter-note beats per minute.

    Returns:
        A 1-D NumPy array of samples at ``SR`` Hz whose length is exactly one
        pass through the pattern, so it can be tiled into a seamless loop.

    Raises:
        ValueError: If ``resolution`` or ``bpm`` is not positive.
    """
    if resolution <= 0 or bpm <= 0:
        raise ValueError("resolution and bpm must be positive")

    # Velocity for each recognised step marker.
    velocity_by_mark = {"x": 120, "o": 65}

    # Table -> list of rows, each a list of raw cell strings.
    rhythm_pattern = [
        line.split('|')[1:-1] for line in markdown_table.split('\n')[2:]
    ]

    # Table -> MIDI.
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)  # midi object
    pm_inst = pretty_midi.Instrument(0, is_drum=True)  # midi instrument
    pm.instruments.append(pm_inst)

    note_length = (60. / bpm) * (4.0 / resolution)  # duration of one step

    # The loop length is the widest row, not whichever row came last, so a
    # short or malformed trailing row cannot truncate the loop.
    beat_num = 0
    for cells in rhythm_pattern:
        if len(cells) < 2:
            continue  # blank line or a row without any step cells
        beat_num = max(beat_num, len(cells) - 1)
        midinote = MIDI_NOTENUM.get(cells[0].strip().upper())
        if midinote is None:
            continue  # unknown drum name
        for step, cell in enumerate(cells[1:], start=1):
            velocity = velocity_by_mark.get(cell.strip().lower(), 0)
            if velocity > 0:
                # Small offset keeps the first note from starting at exactly 0.
                note = pretty_midi.Note(
                    velocity=velocity,
                    pitch=midinote,
                    start=note_length * (step - 1) + 0.0001,
                    end=note_length * step,
                )
                pm_inst.notes.append(note)
    if beat_num == 0:
        beat_num = resolution  # no usable rows: fall back to `resolution` steps

    loop_len = int(SR * note_length * beat_num)

    # Nothing to synthesize: return silence of the right length.
    # NOTE(review): pretty_midi appears to fail when normalising empty audio,
    # which is why the synthesizer is skipped here -- confirm.
    if not pm_inst.notes:
        return np.zeros(loop_len)

    # Convert to audio at the same rate that is reported to the caller.
    audio_data = pm.fluidsynth(fs=SR)

    # Pad when the render is shorter than the pattern (trailing rests), then
    # cut off the tail so the result loops cleanly.
    if len(audio_data) < loop_len:
        audio_data = np.pad(audio_data, (0, loop_len - len(audio_data)))
    return audio_data[:loop_len]
91
+
92
def get_answer(question):
    """Ask the chat model for a rhythm pattern and return its reply text.

    The conversation is primed with a system role, one example request and
    ``markdown_table_sample`` as the example answer, followed by the user's
    ``question``. The content of the first returned choice is passed back
    unmodified.
    """
    example_request = "Please generate a rhythm pattern in a Markdown table. Time resolution is the 8th note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."
    conversation = [
        {"role": "system", "content": "You are a rhythm generator. "},
        {"role": "user", "content": example_request},
        {"role": "assistant", "content": markdown_table_sample},
        {"role": "user", "content": question},
    ]
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=conversation,
    )
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]
105
+
106
def generate_rhythm(query, state):
    """Generate a drum loop for ``query`` and return it with the model's text.

    Args:
        query: Free-text request forwarded to ``get_answer``.
        state: Dict with ``"gen_count"`` (generations so far) and
            ``"user_token"`` (the user's own API key, empty when not set).
            ``"gen_count"`` is incremented in place on every accepted request.

    Returns:
        A two-item list ``[audio, text]``. ``audio`` is ``(SR, samples)`` with
        the pattern repeated four times, or ``None`` when the free quota is
        used up or the reply contains no table. ``text`` is the model's reply
        or the quota message.
    """
    # Allow exactly MAX_QUERY generations without a user-supplied key.
    # The state is deliberately not printed: it holds the user's API key.
    if state["gen_count"] >= MAX_QUERY and not state["user_token"]:
        return [None, "You need to set your ChatGPT API Key to try more than %d times" % MAX_QUERY]
    state["gen_count"] = state["gen_count"] + 1

    # Get the response from ChatGPT.
    text_output = get_answer(query)

    # Without a single '|' there is no table to render; show the text only.
    if '|' not in text_output:
        return [None, text_output]

    # The text before the table is expected to state the time resolution
    # (e.g. "8th"); fall back to 8 when no positive number is found there.
    found = re.search(r'\d+', text_output.split('|')[0])
    resolution = int(found.group()) if found else 8
    if resolution <= 0:
        resolution = 8  # a resolution of 0 would divide by zero downstream

    # Extract the rhythm table: everything between the first and last '|'.
    table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
    audio_data = convert_table_to_audio(table, resolution)

    # Repeat the one-pattern loop four times.
    audio_data = np.tile(audio_data, 4)

    return [(SR, audio_data), text_output]
131
+ # %%
132
+
133
def on_token_change(user_token, state):
    """Store the user's OpenAI API key and make it the active key.

    An empty ``user_token`` falls back to the ``OPENAI_API_KEY`` environment
    variable. The key is never logged, because it is a secret.

    Args:
        user_token: Text typed into the API-key box (may be empty).
        state: State dict; its ``"user_token"`` entry is overwritten.

    Returns:
        The same ``state`` dict, updated.
    """
    # NOTE(review): openai.api_key is a module-level attribute, so this
    # assignment is visible to every caller in the process, not just this user.
    openai.api_key = user_token or os.environ.get("OPENAI_API_KEY")
    state["user_token"] = user_token
    return state
138
+
139
# Build the Gradio UI and start the app.
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm it against the original file's indentation.
with gr.Blocks() as demo:
    # State dict holding the generation counter and the user-supplied API key.
    state = gr.State({"gen_count": 0, "user_token":""})
    # Row 1: title and instructions.
    with gr.Row():
        with gr.Column():
            # gr.Markdown("Ask ChatGPT to generate rhythm patterns")
            gr.Markdown("***Hey TR-ChatGPT, give me a drum pattern!***")
            gr.Markdown("You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. Use 'x' for an accented beat, 'o' for a weak beat. ", elem_id="label")
    # Row 2: request input and button on one side, audio and raw model text on the other.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
            btn = gr.Button("Generate")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")
    # Row 3: optional user API key, entered as a password field.
    with gr.Row():
        with gr.Column():
            gr.Markdown("Enter your own OpenAI API Key to try out more than 5 times. You can get it [here](https://platform.openai.com/account/api-keys).")
            user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)
    # Generate: pass the request text and state to generate_rhythm, show audio and text.
    btn.click(fn=generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
    # Key box edited: on_token_change stores the key and returns the updated state.
    user_token.change(on_token_change, inputs=[user_token, state], outputs=[state])
demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libfluidsynth1
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ pretty_midi
3
+ pyfluidsynth
4
+ openai
5
+ #brew install fluid-synth