nikhilchintawar commited on
Commit
c131f40
1 Parent(s): 1ae750e

Upload files

Browse files
Files changed (2) hide show
  1. inference.py +67 -0
  2. requirements.txt +18 -0
inference.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @title Imports
2
+ from diffusers import DiffusionPipeline
3
+ from riffusion.spectrogram_image_converter import SpectrogramImageConverter
4
+ from riffusion.spectrogram_params import SpectrogramParams
5
+ from io import BytesIO
6
+
7
+
8
+ # @title Define a `predict` function
9
+
10
+ params = SpectrogramParams()
11
+ converter = SpectrogramImageConverter(params)
12
+
13
+
14
+ def preprocess_function(text):
15
+ with open(text, "r", encoding="utf-8") as f:
16
+ data = f.read()
17
+ print(data)
18
+ # pass the textand the target tanguage to be translated separated by a ";" semicolon
19
+ # data = text_path.read().decode("utf-8")
20
+ prompt = data.split(";")[0]
21
+ negative_prompt = data.split(";")[1].strip()
22
+ print(negative_prompt.strip())
23
+ print(data)
24
+ return (prompt, negative_prompt)
25
+
26
+
27
+ def predict_function(params, pipe):
28
+ prompt, negative_prompt = params
29
+ spec = pipe(
30
+ prompt,
31
+ negative_prompt=negative_prompt,
32
+ width=768,
33
+ ).images[0]
34
+
35
+ wav = converter.audio_from_spectrogram_image(image=spec)
36
+ wav.export("output.wav", format="wav")
37
+ return ("output.wav", spec)
38
+
39
+
40
+ def model_load_function(model_path):
41
+ pipe = DiffusionPipeline.from_pretrained(model_path)
42
+ pipe = pipe.to("cuda")
43
+ return pipe
44
+
45
+
46
+ def postprocess_function(audio_file, content_type=None):
47
+ audio = open(audio_file, "rb")
48
+ audio = audio.read()
49
+ print(type(audio))
50
+ audio_bytes = BytesIO(audio)
51
+ response = dict()
52
+ audio_bytes.seek(0)
53
+ response["output"] = {"data": audio_bytes, "ext": "wav"}
54
+ return response
55
+
56
+
57
+ ## Test the script
58
+ """
59
+ if __name__ == '__main__':
60
+ text = ""
61
+ data = preprocess_function(text)
62
+ model_path = "./model_files"
63
+ path = model_load_function(model_path)
64
+ predictions = predict_function(data,path)
65
+ out = postprocess_function(audio_file)
66
+ print(out)
67
+ """
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ argh
3
+ dacite
4
+ demucs
5
+ diffusers
6
+ numpy
7
+ pillow
8
+ plotly
9
+ pydub
10
+ pysoundfile
11
+ scipy
12
+ soundfile
13
+ sox
14
+ torch
15
+ torchaudio
16
+ torchvision
17
+ transformers
18
+ riffusion