File size: 1,805 Bytes
c131f40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# @title Imports
from diffusers import DiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from io import BytesIO


# @title Define a `predict` function

# Module-level spectrogram settings, built with SpectrogramParams' own defaults
# (no arguments are passed), and the converter that predict_function uses to
# turn a generated spectrogram image back into audio.
# NOTE: predict_function's `params` argument shadows this global inside that
# function; only `converter` is actually read there.
params = SpectrogramParams()
converter = SpectrogramImageConverter(params)


def preprocess_function(text):
    """Read a request file and split it into a prompt and a negative prompt.

    The file is expected to contain ``"<prompt>;<negative prompt>"``. Only the
    first ``";"`` acts as the separator, so the negative prompt may itself
    contain semicolons. If the file has no ``";"`` at all, the negative prompt
    is the empty string.

    Args:
        text: Path to a UTF-8 text file (despite the name, this is a path,
            not the prompt text itself).

    Returns:
        A ``(prompt, negative_prompt)`` tuple of whitespace-stripped strings.
    """
    with open(text, "r", encoding="utf-8") as f:
        data = f.read()
    # partition() never raises: a missing separator yields an empty tail,
    # whereas the previous split(";")[1] raised IndexError.
    prompt, _sep, negative_prompt = data.partition(";")
    return (prompt.strip(), negative_prompt.strip())


def predict_function(params, pipe, *, width=768, output_path="output.wav"):
    """Generate a spectrogram from a prompt pair and render it to a WAV file.

    Args:
        params: ``(prompt, negative_prompt)`` tuple, e.g. as returned by
            ``preprocess_function``. Inside this function the name shadows
            the module-level ``params`` spectrogram settings.
        pipe: Loaded diffusion pipeline, e.g. from ``model_load_function``.
        width: Width in pixels of the generated spectrogram image. Defaults
            to 768, the value previously hard-coded here.
        output_path: Where to write the WAV file. Defaults to ``"output.wav"``
            in the current working directory. Pass a unique path when several
            predictions can run at once; otherwise they overwrite one file.

    Returns:
        ``(output_path, spec)``: the WAV path and the generated spectrogram
        image.
    """
    prompt, negative_prompt = params
    spec = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=width,
    ).images[0]

    # Relies on the module-level `converter` built from the default params.
    wav = converter.audio_from_spectrogram_image(image=spec)
    wav.export(output_path, format="wav")
    return (output_path, spec)


def model_load_function(model_path):
    """Load the diffusion pipeline and place it on the CUDA device.

    Args:
        model_path: Passed unchanged to ``DiffusionPipeline.from_pretrained``
            (the test snippet in this file uses ``"./model_files"``).

    Returns:
        The loaded pipeline after ``.to("cuda")``.
    """
    return DiffusionPipeline.from_pretrained(model_path).to("cuda")


def postprocess_function(audio_file, content_type=None):
    """Load a rendered audio file into an in-memory response payload.

    Args:
        audio_file: Path to the WAV file to return (e.g. the first element of
            the tuple returned by ``predict_function``).
        content_type: Accepted for interface compatibility; currently unused.

    Returns:
        ``{"output": {"data": <BytesIO>, "ext": "wav"}}``. The buffer holds the
        file's raw bytes and is positioned at offset 0.
    """
    # The context manager guarantees the file is closed; the previous
    # version opened it and never closed the handle.
    with open(audio_file, "rb") as f:
        audio_bytes = BytesIO(f.read())
    # A freshly constructed BytesIO already starts at position 0, so no seek().
    return {"output": {"data": audio_bytes, "ext": "wav"}}


## Test the script
# Manual smoke test. It is kept inside an inert string so that importing this
# module never loads the model. To run it, copy it into a script (or a real
# `if __name__ == "__main__":` guard) and point `text` at a file containing
# "<prompt>;<negative prompt>".
"""
if __name__ == '__main__':
    text = "./input.txt"
    data = preprocess_function(text)
    model_path = "./model_files"
    pipe = model_load_function(model_path)
    audio_file, spec = predict_function(data, pipe)
    out = postprocess_function(audio_file)
    print(out)
"""