Commit
·
61f246b
1
Parent(s):
1fa4e61
Delete src/frontend
Browse files- src/frontend/.streamlit/config.toml +0 -10
- src/frontend/__init__.py +0 -0
- src/frontend/ui.py +0 -97
- src/frontend/ui_backend.py +0 -254
src/frontend/.streamlit/config.toml
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
[theme]
|
2 |
-
base = "dark"
|
3 |
-
primaryColor = "#FFFFFF"
|
4 |
-
backgroundColor = "#212121"
|
5 |
-
secondaryBackgroundColor = "#757575"
|
6 |
-
textColor = "#FFFFFF"
|
7 |
-
font = "sans serif"
|
8 |
-
|
9 |
-
[browser]
|
10 |
-
gatherUsageStats = false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/frontend/__init__.py
DELETED
File without changes
|
src/frontend/ui.py
DELETED
@@ -1,97 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
|
3 |
-
import streamlit as st
|
4 |
-
from ui_backend import (
|
5 |
-
check_for_api,
|
6 |
-
cut_audio_file,
|
7 |
-
display_predictions,
|
8 |
-
load_audio,
|
9 |
-
predict_multiple,
|
10 |
-
predict_single,
|
11 |
-
)
|
12 |
-
|
13 |
-
|
14 |
-
def main():
    """
    Entry point for the Streamlit front end.

    Builds the page (settings sidebar, title, upload widget), waits for the
    backend API to come up, and — on demand — cuts the audio, requests
    predictions, and renders/downloads the results.
    """
    # Page settings
    st.set_page_config(
        page_title="Music Instrument Recognition", page_icon="🎸", layout="wide", initial_sidebar_state="collapsed"
    )

    # Sidebar: model choice is a human-readable label the backend maps to a model
    with st.sidebar:
        st.title("⚙️ Settings")
        selected_model = st.selectbox(
            "Select Model",
            ("Accuracy", "Speed"),
            index=0,
            help="Select a slower but more accurate model or a faster but less accurate model",
        )

    # Main title
    st.markdown(
        "<h1 style='text-align: center; color: #FFFFFF; font-size: 3rem;'>Instrument Recognition 🎶</h1>",
        unsafe_allow_html=True,
    )

    # Upload widget — returns a list of uploaded files, or None
    audio_file = load_audio()

    # Send a health check request to the API in a loop until it is running
    api_running = check_for_api(10)

    # Enable or disable a button based on API status
    predict_valid = False
    cut_valid = False

    if api_running:
        st.info("API is running", icon="🤖")

    if audio_file:
        num_files = len(audio_file)
        st.write(f"Number of uploaded files: {num_files}")
        predict_valid = True
        # Cutting is only offered for a single file; with one upload the list
        # is unwrapped so audio_file becomes a single UploadedFile
        if len(audio_file) > 1:
            cut_valid = False
        else:
            audio_file = audio_file[0]
            cut_valid = True
            name = audio_file.name

    if cut_valid:
        cut_audio = st.checkbox(
            "✂️ Cut duration",
            disabled=not predict_valid,
            help="Cut a long audio file. Model works best if audio is around 15 seconds",
        )

        if cut_audio:
            # Replaces audio_file with a (name, BytesIO) tuple of the cut clip
            audio_file = cut_audio_file(audio_file, name)

    result = st.button("Predict", disabled=not predict_valid, help="Send the audio to API to get a prediction")

    if result:
        predictions = {}
        # A list means multiple uploads; otherwise a single (possibly cut) file.
        # NOTE(review): `name` is only bound in the single-file branch above —
        # presumably `result` can't be clicked with no upload; verify.
        if isinstance(audio_file, list):
            predictions = predict_multiple(audio_file, selected_model)

        else:
            predictions = predict_single(audio_file, name, selected_model)

        # Sort the dictionary alphabetically by key
        sorted_predictions = dict(sorted(predictions.items()))

        # Convert the sorted dictionary to a JSON string
        json_string = json.dumps(sorted_predictions)
        st.download_button(
            label="Download JSON",
            file_name="predictions.json",
            mime="application/json",
            data=json_string,
            help="Download the predictions in JSON format",
        )

        display_predictions(sorted_predictions)


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/frontend/ui_backend.py
DELETED
@@ -1,254 +0,0 @@
|
|
1 |
-
import io
|
2 |
-
import os
|
3 |
-
import time
|
4 |
-
from json import JSONDecodeError
|
5 |
-
import math
|
6 |
-
|
7 |
-
import requests
|
8 |
-
import soundfile as sf
|
9 |
-
import streamlit as st
|
10 |
-
|
11 |
-
# Inside Docker the API is reachable via the compose service name "api".
# IS_DOCKER is read from the environment; env values are strings, so any
# non-empty value selects the Docker URL.
if os.environ.get("IS_DOCKER", False):
    backend = "http://api:7860"
else:
    backend = "http://0.0.0.0:7860"

# Mapping from the model's 3-letter instrument codes to display names.
INSTRUMENTS = {
    "tru": "Trumpet",
    "sax": "Saxophone",
    "vio": "Violin",
    "gac": "Acoustic Guitar",
    "org": "Organ",
    "cla": "Clarinet",
    "flu": "Flute",
    "voi": "Voice",
    "gel": "Electric Guitar",
    "cel": "Cello",
    "pia": "Piano",
}
|
29 |
-
|
30 |
-
|
31 |
-
def load_audio():
    """
    Show a WAV file uploader and preview the first uploaded file.

    :return: The list of uploaded files, or None if no file was uploaded.
    :rtype: Optional[list]
    """

    uploaded = st.file_uploader(label="Upload audio file", type="wav", accept_multiple_files=True)
    # The uploader returns an empty list until something is uploaded
    if not uploaded:
        return None
    st.audio(uploaded[0])
    return uploaded
|
45 |
-
|
46 |
-
|
47 |
-
@st.cache_data(show_spinner=False)
def check_for_api(max_tries: int):
    """
    Check if the API is running by polling its health-check endpoint.

    Fixes over the previous version: the `while trial_count <= max_tries`
    loop performed max_tries + 1 attempts, and because the counter was only
    incremented on ConnectionError, a reachable API that kept returning a
    non-200 status looped forever. A bounded for-loop removes both issues.

    :param max_tries: The maximum number of attempts to check the API's health.
    :type max_tries: int
    :return: True if the API is running; otherwise the app is stopped.
    :rtype: bool
    """
    with st.spinner("Waiting for API..."):
        for _ in range(max_tries):
            try:
                if health_check():
                    return True
            except requests.exceptions.ConnectionError:
                # API not yet reachable (e.g. container still starting)
                time.sleep(5)  # Sleep for 5 seconds before retrying
        st.error("API is not running. Please refresh the page to try again.", icon="🚨")
        st.stop()
|
71 |
-
|
72 |
-
|
73 |
-
def cut_audio_file(audio_file, name):
    """
    Cut an audio file to a user-selected time range and return the result.

    Fixes over the previous version: the `except RuntimeError as e: raise e`
    wrapper was a no-op and is removed (the exception still propagates), and
    the in-memory buffer is rewound after writing — `sf.write` leaves the
    stream position at EOF, so downstream readers (playback, upload) would
    otherwise see an empty stream.

    :param audio_file: The audio file to be cut.
    :type audio_file: UploadedFile
    :param name: The name of the audio file to be cut.
    :type name: str
    :raises RuntimeError: If the audio file cannot be read.
    :return: A tuple containing the name and the cut audio data as a BytesIO object.
    :rtype: tuple
    """
    audio_data, sample_rate = sf.read(audio_file)

    # Display audio duration
    duration = round(len(audio_data) / sample_rate, 2)
    st.info(f"Audio Duration: {duration} seconds")

    # Get start and end time for cutting
    start_time = st.number_input("Start Time (seconds)", min_value=0.0, max_value=duration - 1, step=0.1)
    end_time = st.number_input("End Time (seconds)", min_value=start_time, value=duration, max_value=duration, step=0.1)

    # Convert start and end time to sample indices
    start_sample = int(start_time * sample_rate)
    end_sample = int(end_time * sample_rate)

    # Cut audio
    cut_audio_data = audio_data[start_sample:end_sample]

    # Create a temporary in-memory file for cut audio
    buffer = io.BytesIO()
    sf.write(buffer, cut_audio_data, sample_rate, format="wav")
    buffer.seek(0)  # rewind so playback and upload read from the start

    # Display cut audio
    st.audio(buffer, format="audio/wav")

    return (name, buffer)
|
114 |
-
|
115 |
-
|
116 |
-
def display_predictions(predictions: dict):
    """
    Render prediction results, mapping instrument codes to readable names.

    :param predictions: A dictionary containing the filenames and instruments detected in them.
    :type predictions: dict
    """

    for filename, instruments in predictions.items():
        st.subheader(filename)

        # Error paths store a plain message string instead of a code->flag mapping
        if isinstance(instruments, str):
            st.write(instruments)
            continue

        with st.container():
            col1, col2 = st.columns([1, 3])
            detected = [INSTRUMENTS[code] for code, present in instruments.items() if present]
            if not detected:
                st.write("No instruments found in this file.")
            else:
                for label in detected:
                    with col1:
                        st.write(label)
                    with col2:
                        st.write("✔️")
|
145 |
-
|
146 |
-
|
147 |
-
def health_check():
    """
    Sends a health check request to the API and checks if it's running.

    :return: Returns True if the API is running, else False.
    :rtype: bool
    """

    # Send a health check request to the API
    response = requests.get(f"{backend}/health-check", timeout=100)

    # A 200 response means the API is up; return the comparison directly
    # instead of the previous if/else returning True/False
    return response.status_code == 200
|
163 |
-
|
164 |
-
|
165 |
-
def predict(data, model_name):
    """
    POST the audio payload to the API's predict endpoint.

    :param data: The audio data to be used for prediction.
    :type data: bytes
    :param model_name: The name of the model to be used for prediction.
    :type model_name: str
    :return: The response from the API.
    :rtype: requests.Response
    """

    files = {"file": data}
    params = {"model_name": model_name}

    # Replace with your API endpoint URL
    return requests.post(f"{backend}/predict", params=params, files=files, timeout=300)
|
185 |
-
|
186 |
-
|
187 |
-
@st.cache_data(show_spinner=False)
def predict_single(audio_file, name, selected_model):
    """
    Predicts the instruments in a single audio file using the selected model.

    :param audio_file: The audio file to be used for prediction.
    :type audio_file: bytes
    :param name: The name of the audio file.
    :type name: str
    :param selected_model: The name of the selected model.
    :type selected_model: str
    :return: A dictionary containing the predicted instruments for the audio file.
    :rtype: dict
    """

    with st.spinner("Predicting instruments..."):
        response = predict(audio_file, selected_model)

    # On failure, surface as much detail as possible and halt the script
    if response.status_code != 200:
        st.write(response)
        try:
            st.json(response.json())
        except JSONDecodeError:
            # Body was not JSON; show it raw
            st.error(response.text)
        st.stop()

    prediction = response.json()["prediction"]
    return {name: prediction.get(name, "Error making prediction")}
|
218 |
-
|
219 |
-
|
220 |
-
@st.cache_data(show_spinner=False)
def predict_multiple(audio_files, selected_model):
    """
    Generates predictions for multiple audio files using the selected model.

    Fixes over the previous version: the progress bar now advances on failed
    responses as well (it previously stalled on errors), and the prediction
    is looked up with ``.get`` so a missing filename key yields an error
    message instead of raising KeyError — consistent with predict_single.

    :param audio_files: A list of audio files to make predictions on.
    :type audio_files: List[UploadedFile]
    :param selected_model: The model to use for making predictions.
    :type selected_model: str
    :return: A dictionary where the keys are the names of the audio files and the values are the predicted labels.
    :rtype: Dict[str, str]
    """

    predictions = {}
    progress_text = "Getting predictions for all files. Please wait."
    progress_bar = st.empty()
    progress_bar.progress(0, text=progress_text)

    num_files = len(audio_files)

    for i, file in enumerate(audio_files):
        name = file.name
        response = predict(file, selected_model)
        if response.status_code == 200:
            prediction = response.json()["prediction"]
            predictions[name] = prediction.get(name, "Error making prediction.")
        else:
            predictions[name] = "Error making prediction."
        # Advance the bar on every file, successful or not
        progress_bar.progress((i + 1) / num_files, text=progress_text)
    progress_bar.empty()
    return predictions
|
251 |
-
|
252 |
-
|
253 |
-
# This module is a library for ui.py; nothing to run when executed directly.
if __name__ == "__main__":
    pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|