File size: 1,910 Bytes
0d9f09c
 
 
 
 
 
511c7b4
0d9f09c
 
 
ba4b027
0d9f09c
b995697
 
52c481a
 
b995697
0d9f09c
 
 
 
 
 
ba4b027
 
0d9f09c
 
 
 
 
 
5587862
 
 
 
 
 
 
0d9f09c
d043662
511c7b4
d043662
 
 
 
0d9f09c
 
d043662
 
 
0d9f09c
5587862
0d9f09c
 
c049bdf
 
d043662
 
 
aa4c24f
3a35f81
a2ceada
 
 
3a35f81
c049bdf
511c7b4
b995697
52c481a
c049bdf
0d9f09c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
import openai
from t2a import text_to_audio
import joblib
from sentence_transformers import SentenceTransformer
import numpy as np
import os

# Artifacts loaded once, at import time.
# NOTE(review): `reg` and `model` are not called by any live code visible in
# this file -- confirm they are still needed before paying their load cost.
reg = joblib.load('text_reg.joblib')
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Identifier of the fine-tuned completion model queried by get_note_text().
finetune = "davinci:ft-personal:autodrummer-v5-2022-11-04-22-34-07"

# UI copy handed to gr.Interface as `description` and `article`.
# Decode explicitly as UTF-8 so the text does not depend on the host's
# locale encoding (the default when `encoding` is omitted).
with open('description.txt', 'r', encoding='utf-8') as f:
    description = f.read()
with open('article.txt', 'r', encoding='utf-8') as f:
    article = f.read()

def get_note_text(prompt):
    """Return the fine-tuned model's completion for *prompt*.

    The prompt is sent with the suffix " ->" appended, and generation stops
    at the "###" stop sequence. The text of the first returned choice is
    handed back with leading and trailing whitespace stripped.
    """
    suffixed_prompt = prompt + " ->"
    # Request parameters for the completion call against the fine-tune.
    request = {
        "engine": finetune,
        "prompt": suffixed_prompt,
        "temperature": 0.5,
        "max_tokens": 200,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": ["###"],
    }
    completion = openai.Completion.create(**request)
    first_choice = completion.choices[0]
    return first_choice.text.strip()
    
def increment_count():
    """Add one to the integer usage counter stored in ``count.txt``.

    The file is read from, and written back to, the current working
    directory. A missing or blank file is treated as a count of zero, so the
    first call on a fresh checkout creates the file containing ``1`` instead
    of raising.

    Raises:
        ValueError: if the file holds non-blank text that is not an integer.
    """
    try:
        with open('count.txt', 'r') as f:
            raw = f.read().strip()
    except FileNotFoundError:
        raw = ''
    # int('') raises ValueError, so a blank file falls back to zero.
    count = int(raw) if raw else 0
    with open('count.txt', 'w') as f:
        f.write(str(count + 1))

def get_drummer_output(prompt, tempo):
    """Turn a text prompt into audio for the Gradio "audio" output.

    The prompt is converted to note text by ``get_note_text`` and rendered by
    ``text_to_audio``. A ``tempo`` of "fast" is replaced by 138 and "slow" by
    100; any other value is passed to ``text_to_audio`` unchanged. Each call
    also increments the usage counter.

    Returns:
        A ``(96000, samples)`` tuple, where ``samples`` is a float32 NumPy
        array built from the rendered audio's ``get_array_of_samples()``.

    Raises:
        KeyError: if the ``key`` environment variable is not set.
    """
    # The OpenAI API key is read from the environment variable named "key".
    openai.api_key = os.environ['key']

    # Map the UI label to a number; unrecognised values fall through as-is.
    tempo_value = 138 if tempo == "fast" else 100 if tempo == "slow" else tempo

    # NOTE: the original also carries disabled (commented-out) experiments
    # here: repeating the note text twice, and deriving the tempo from
    # `model.encode` + `reg.predict` plus an offset of 20.
    rendered = text_to_audio(get_note_text(prompt), tempo_value)
    samples = np.array(rendered.get_array_of_samples(), dtype=np.float32)
    increment_count()

    # NOTE(review): 96000 is hard-coded as the first tuple element, presumably
    # the sample rate Gradio should play `samples` at -- confirm it matches
    # what text_to_audio actually produces.
    return 96000, samples

# Build the web UI: a free-text prompt and a fast/slow tempo choice go in,
# the audio produced by get_drummer_output comes out.
iface = gr.Interface(
    title='Autodrummer',
    description=description,
    article=article,
    fn=get_drummer_output,
    inputs=[
        "text",
        gr.Radio(["fast", "slow"], label="Tempo", default="fast"),
    ],
    outputs="audio",
    # Each example row supplies one value per input: [prompt, tempo].
    examples=[
        ["hiphop groove 808", "fast"],
        ["rock metal", "fast"],
        ["disco funk", "fast"],
    ],
)

# Start serving the interface.
iface.launch()