Sandesh Bharadwaj commited on
Commit
30b0ee8
·
unverified ·
2 Parent(s): 5978ae3 0c4c7bf

Merge pull request #1 from animikhaich/web-app-dev

Browse files
.gitignore CHANGED
@@ -167,4 +167,7 @@ cython_debug/
167
  *.mp3
168
  *.mp4
169
 
170
- creds.json
 
 
 
 
167
  *.mp3
168
  *.mp4
169
 
170
+ creds.json
171
+
172
+ # Ignore the test file
173
+ test.py
engine/__init__.py CHANGED
@@ -1 +1,2 @@
1
  from .video_descriptor import DescribeVideo
 
 
1
  from .video_descriptor import DescribeVideo
2
+ from .audio_generator import GenerateAudio
engine/audio_generator.py CHANGED
@@ -1 +1,128 @@
1
- # TODO: Add from model server
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+
4
+ warnings.simplefilter("ignore")
5
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
+ import io
7
+ import torch
8
+ import numpy as np
9
+ from audiocraft.models import musicgen
10
+ from scipy.io.wavfile import write as wav_write
11
+
12
+ try:
13
+ from logger import logging
14
+ except:
15
+ import logging
16
+
17
+
18
+ class GenerateAudio:
19
+ def __init__(self, model="musicgen-stereo-small"):
20
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ self.model_name = self.get_model_name(model)
22
+ self.model = self.get_model(self.model_name, self.device)
23
+ self.generated_audio = None
24
+ self.sampling_rate = None
25
+
26
+ @staticmethod
27
+ def get_model(model, device):
28
+ try:
29
+ model = musicgen.MusicGen.get_pretrained(model, device=device)
30
+ logging.info(f"Loaded model: {model}")
31
+ return model
32
+ except Exception as e:
33
+ logging.error(f"Failed to load model: {e}")
34
+ raise ValueError(f"Failed to load model: {e}")
35
+ return
36
+
37
+ @staticmethod
38
+ def get_model_name(model_name):
39
+ if model_name.startswith("facebook/"):
40
+ return model_name
41
+ return f"facebook/{model_name}"
42
+
43
+ @staticmethod
44
+ def duration_sanity_check(duration):
45
+ if duration < 1:
46
+ logging.warning("Duration is less than 1 second. Setting duration to 1 second.")
47
+ return 1
48
+ elif duration > 30:
49
+ logging.warning("Duration is greater than 30 seconds. Setting duration to 30 seconds.")
50
+ return 30
51
+ return duration
52
+
53
+ @staticmethod
54
+ def prompts_sanity_check(prompts):
55
+ if isinstance(prompts, str):
56
+ prompts = [prompts]
57
+ elif not isinstance(prompts, list):
58
+ raise ValueError("Prompts should be a string or a list of strings.")
59
+ else:
60
+ for prompt in prompts:
61
+ if not isinstance(prompt, str):
62
+ raise ValueError("Prompts should be a string or a list of strings.")
63
+ if len(prompts) > 8: # Too many prompts will cause OOM error
64
+ raise ValueError("Maximum number of prompts allowed is 8.")
65
+ return prompts
66
+
67
+
68
+ def generate_audio(self, prompts, duration=10):
69
+ duration = self.duration_sanity_check(duration)
70
+ prompts = self.prompts_sanity_check(prompts)
71
+
72
+ try:
73
+ self.model.set_generation_params(duration=duration)
74
+ result = self.model.generate(prompts, progress=False)
75
+ self.result = result.cpu().numpy().T
76
+ self.result = self.result.transpose((2, 0, 1))
77
+ self.sampling_rate = self.model.sample_rate
78
+ logging.info(
79
+ f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
80
+ )
81
+ print(f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz")
82
+ return self.sampling_rate, self.result
83
+ except Exception as e:
84
+ logging.error(f"Failed to generate audio: {e}")
85
+ raise ValueError(f"Failed to generate audio: {e}")
86
+
87
+ def save_audio(self, audio_dir="generated_audio"):
88
+ if self.result is None:
89
+ raise ValueError("Audio is not generated yet.")
90
+ if self.sampling_rate is None:
91
+ raise ValueError("Sampling rate is not available.")
92
+
93
+ paths = []
94
+ os.makedirs(audio_dir, exist_ok=True)
95
+ for i, audio in enumerate(self.result):
96
+ path = os.path.join(audio_dir, f"audio_{i}.wav")
97
+ wav_write(path, self.sampling_rate, audio)
98
+ paths.append(path)
99
+ return paths
100
+
101
+ def get_audio_buffer(self):
102
+ if self.result is None:
103
+ raise ValueError("Audio is not generated yet.")
104
+ if self.sampling_rate is None:
105
+ raise ValueError("Sampling rate is not available.")
106
+
107
+ buffers = []
108
+ for audio in self.result:
109
+ buffer = io.BytesIO()
110
+ wav_write(buffer, self.sampling_rate, audio)
111
+ buffer.seek(0)
112
+ buffers.append(buffer)
113
+ return buffers
114
+
115
+ if __name__ == "__main__":
116
+ audio_gen = GenerateAudio()
117
+ sample_rate, result = audio_gen.generate_audio(
118
+ [
119
+ "A piano playing a jazz melody",
120
+ "A guitar playing a rock riff",
121
+ "A LoFi music for coding"
122
+ ],
123
+ duration=10
124
+ )
125
+ paths = audio_gen.save_audio()
126
+ print(f"Saved audio to: {paths}")
127
+ buffers = audio_gen.get_audio_buffer()
128
+ print(f"Audio buffers: {buffers}")
engine/video_descriptor.py CHANGED
@@ -1,8 +1,7 @@
 
1
  from warnings import simplefilter
2
 
3
  simplefilter("ignore")
4
- import os
5
-
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
  import json
8
  import time
@@ -78,6 +77,9 @@ class DescribeVideo:
78
 
79
  return json.loads(cleaned_response.text.strip("```json\n"))
80
 
 
 
 
81
  def reset_safety_settings(self):
82
  logging.info("Resetting safety settings")
83
  self.is_safety_set = False
 
1
+ import os
2
  from warnings import simplefilter
3
 
4
  simplefilter("ignore")
 
 
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
  import json
7
  import time
 
77
 
78
  return json.loads(cleaned_response.text.strip("```json\n"))
79
 
80
+ def __call__(self, video_path):
81
+ return self.describe_video(video_path)
82
+
83
  def reset_safety_settings(self):
84
  logging.info("Resetting safety settings")
85
  self.is_safety_set = False
main.py CHANGED
@@ -1,67 +1,91 @@
1
  import streamlit as st
 
2
 
3
- def main():
4
- st.set_page_config(page_title="VidTune: Where Videos Find Their Melody", layout="centered")
5
 
6
- # Title and Description
7
- st.title("VidTune: Where Videos Find Their Melody")
8
- st.write("VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Main Page (Page 1)
11
- if 'page' not in st.session_state:
12
- st.session_state.page = 'main'
 
 
 
 
13
 
14
  if st.session_state.page == 'main':
15
- st.header("Video to Music")
16
- uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
17
- if uploaded_video is not None:
18
- st.session_state.uploaded_video = uploaded_video
19
- st.session_state.page = 'video_to_music'
20
-
21
- if st.session_state.page == 'main':
22
- st.header("Prompt to Music")
23
- prompt = st.text_area("Prompt")
24
- if st.button("Generate"):
25
- st.session_state.prompt = prompt
26
- st.session_state.page = 'prompt_to_music'
27
 
28
- # Page 2a (If the user uploads a video)
29
- if st.session_state.page == 'video_to_music':
30
- st.sidebar.title("Settings")
31
- device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
32
- num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
33
-
34
- st.video(st.session_state.uploaded_video)
35
-
36
- st.text_area("Video Description", "This is a fixed video description", disabled=True)
37
- st.text_area("Music Description")
38
-
39
- if st.button("Generate Music"):
40
- st.session_state.page = 'result'
41
- st.session_state.device = device
42
- st.session_state.num_samples = num_samples
43
 
44
- # Page 2b (If user selects "Prompt to Music" in Page 1)
45
- if st.session_state.page == 'prompt_to_music':
46
- st.sidebar.title("Settings")
47
- device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
48
- num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
49
-
50
- if st.button("Generate Music"):
51
- st.session_state.page = 'result'
52
- st.session_state.device = device
53
- st.session_state.num_samples = num_samples
54
 
55
- # Page 3 (Results Page)
56
- if st.session_state.page == 'result':
57
- st.header("Generated Music")
58
- for i in range(st.session_state.num_samples):
59
- st.write(f"Music Sample {i+1}")
60
- st.audio(f"Generated Music {i+1}.mp3", format='audio/mp3')
61
- st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
62
-
63
- if st.button("Start Over"):
64
- st.session_state.page = 'main'
65
 
66
- if __name__ == "__main__":
67
- main()
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from engine import DescribeVideo, GenerateAudio
3
 
 
 
4
 
5
+ video_model_map = {
6
+ "Fast": "flash",
7
+ "Quality": "pro",
8
+ }
9
+
10
+ music_model_map = {
11
+ "Fast": "musicgen-stereo-small",
12
+ "Balanced": "musicgen-stereo-medium",
13
+ "Quality": "musicgen-stereo-large",
14
+ }
15
+
16
+
17
+ st.set_page_config(page_title="VidTune: Where Videos Find Their Melody", layout="centered")
18
+
19
+ # Title and Description
20
+ st.title("VidTune: Where Videos Find Their Melody")
21
+ st.write("VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video.")
22
+
23
+
24
+ # Sidebar
25
+ st.sidebar.title("Settings")
26
+ video_model = st.sidebar.selectbox("Select Video Descriptor", ["Fast", "Balanced", "Quality"], index=0)
27
+ music_model = st.sidebar.selectbox("Select Music Generator", ["Fast", "Balanced", "Quality"], index=0)
28
+ num_samples = st.sidebar.slider("Number of samples", 1, 8, 3)
29
+ generate_button = st.sidebar.button("Generate Music")
30
+
31
+ video_descriptor = DescribeVideo(model=video_model_map[video_model])
32
+ audio_generator = GenerateAudio(model=music_model_map[music_model])
33
+
34
+ video_description = None
35
+
36
+ # Main Page (Page 1)
37
+ if 'page' not in st.session_state:
38
+ st.session_state.page = 'main'
39
+
40
+ if st.session_state.page == 'main':
41
+ st.header("Video to Music")
42
+ uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
43
 
44
+ if uploaded_video is not None:
45
+ st.session_state.uploaded_video = uploaded_video
46
+ with open("temp.mp4", mode='wb') as w:
47
+ w.write(uploaded_video.getvalue())
48
+ video_description = video_descriptor.describe_video("temp.mp4")
49
+
50
+ st.session_state.page = 'video_to_music'
51
 
52
  if st.session_state.page == 'main':
53
+ st.header("Prompt to Music")
54
+ prompt = st.text_area("Prompt")
55
+ if generate_button:
56
+ st.session_state.prompt = prompt
57
+ st.session_state.page = 'prompt_to_music'
58
+
59
+ # Page 2a (If the user uploads a video)
60
+ if st.session_state.page == 'video_to_music':
61
+ st.video(st.session_state.uploaded_video)
 
 
 
62
 
63
+ st.text_area("Video Description", "This is a fixed video description", disabled=True)
64
+ st.text_area("Music Description")
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ if generate_button:
67
+ st.session_state.page = 'result'
68
+ st.session_state.device = device
69
+ st.session_state.num_samples = num_samples
70
+
71
+ # Page 2b (If user selects "Prompt to Music" in Page 1)
72
+ if st.session_state.page == 'prompt_to_music':
73
+ st.sidebar.title("Settings")
74
+ device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
75
+ num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
76
 
77
+ if generate_button:
78
+ st.session_state.page = 'result'
79
+ st.session_state.device = device
80
+ st.session_state.num_samples = num_samples
 
 
 
 
 
 
81
 
82
+ # Page 3 (Results Page)
83
+ if st.session_state.page == 'result':
84
+ st.header("Generated Music")
85
+ for i in range(st.session_state.num_samples):
86
+ st.write(f"Music Sample {i+1}")
87
+ st.audio(f"Generated Music {i+1}.mp3", format='audio/mp3')
88
+ st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
89
+
90
+ if st.button("Start Over"):
91
+ st.session_state.page = 'main'