Karlo Pintaric committed
Commit: fdc1efd
1 Parent(s): 48fb9cc
Upload 25 files
- .gitattributes +1 -0
- DockerFile.backend +17 -0
- setup.py +39 -0
- src/__init__.py +0 -0
- src/api/ModelService.py +93 -0
- src/api/__init__.py +0 -0
- src/api/main.py +133 -0
- src/api/main_test.py +63 -0
- src/api/models/acc_model_ast.pth +3 -0
- src/api/models/acc_model_thresh.npy +3 -0
- src/api/models/speed_model_ast.pth +3 -0
- src/api/models/speed_model_thresh.npy +3 -0
- src/api/test_files/test.wav +3 -0
- src/frontend/.streamlit/config.toml +10 -0
- src/frontend/__init__.py +0 -0
- src/frontend/ui.py +97 -0
- src/frontend/ui_backend.py +254 -0
- src/modeling/__init__.py +2 -0
- src/modeling/dataset.py +162 -0
- src/modeling/learner.py +333 -0
- src/modeling/loss.py +96 -0
- src/modeling/metrics.py +179 -0
- src/modeling/models.py +313 -0
- src/modeling/preprocess.py +336 -0
- src/modeling/transforms.py +398 -0
- src/modeling/utils.py +336 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/api/test_files/test.wav filter=lfs diff=lfs merge=lfs -text
DockerFile.backend
ADDED
@@ -0,0 +1,17 @@
+# Use an official Python runtime as the base image
+FROM python:3.9-slim
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Copy the setup.py file and the package directory into the container
+COPY ./setup.py .
+
+# Install the package and its dependencies
+COPY ./src ./src
+
+RUN pip install --no-cache-dir .[backend] torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu
+
+EXPOSE 7860
+
+CMD ["uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "7860"]
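Note: assuming this Dockerfile is used as-is, the backend image can be built and started with something like `docker build -f DockerFile.backend -t irmas-backend .` followed by `docker run -p 7860:7860 irmas-backend` (the image tag is only illustrative); the container then serves the FastAPI app on port 7860 via uvicorn.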
setup.py
ADDED
@@ -0,0 +1,39 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="lumen-irmas",
+    version="0.1.0",
+    description="LUMEN Data Science nagradni zadatak",
+    author="Karlo Pintaric i Tatjana Cigula",
+    packages=find_packages(include=["src"]),
+    python_requires=">=3.9",
+    install_requires=[
+        "numpy==1.23.5",
+        "transformers==4.27.4",
+    ],
+    extras_require={
+        "backend": ["fastapi==0.95.1", "uvicorn==0.21.1", "pydantic==1.10.7", "python-multipart==0.0.6"],
+        "frontend": ["streamlit==1.21.0", "requests==2.28.2", "soundfile==0.12.1"],
+        "user": [
+            "lumen-irmas[backend]",
+            "lumen-irmas[frontend]",
+            "torch==1.13.1",
+            "torchaudio==0.13.1",
+            "torchvision==0.14.1",
+        ],
+        "dev": [
+            "lumen-irmas[user]",
+            "librosa==0.10.0.post2",
+            "pandas==1.5.3",
+            "scikit-learn==1.2.2",
+            "tqdm==4.65.0",
+            "wandb==0.14.2",
+            "pytest==7.3.1",
+            "joblib==1.2.0",
+            "PyYAML==6.0",
+            "flake8==6.0.0",
+            "isort==5.12.0",
+            "black==23.3.0"
+        ]
+    },
+)
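Note: the package splits its dependencies into extras, so a local install would typically be `pip install -e .[backend]` for the API, `.[frontend]` for the Streamlit UI, `.[user]` for both plus PyTorch, or `.[dev]` for the full training/tooling stack (the editable `-e` flag is optional; the Dockerfile above installs the plain `.[backend]` extra together with CPU-only torch wheels).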
src/__init__.py
ADDED
File without changes
src/api/ModelService.py
ADDED
@@ -0,0 +1,93 @@
+from pathlib import Path
+
+import numpy as np
+import torch
+from torchvision import transforms
+
+from src.modeling import ASTPretrained, FeatureExtractor, PreprocessPipeline, StudentAST
+
+MODELS_FOLDER = Path(__file__).parent / "models"
+
+CLASSES = ["tru", "sax", "vio", "gac", "org", "cla", "flu", "voi", "gel", "cel", "pia"]
+
+
+def load_model(model_type: str):
+    """
+    Loads a pre-trained AST model of the specified type.
+
+    :param model_type: The type of model to load
+    :type model_type: str
+    :return: The loaded pre-trained AST model.
+    :rtype: ASTPretrained
+    """
+
+    if model_type == "accuracy":
+        model = ASTPretrained(n_classes=11, download_weights=False)
+        model.load_state_dict(torch.load(f"{MODELS_FOLDER}/acc_model_ast.pth", map_location=torch.device("cpu")))
+    else:
+        model = StudentAST(n_classes=11, hidden_size=192, num_heads=3)
+        model.load_state_dict(torch.load(f"{MODELS_FOLDER}/speed_model_ast.pth", map_location=torch.device("cpu")))
+    model.eval()
+    return model
+
+
+def load_labels():
+    """
+    Loads a dictionary of class labels for the AST model.
+
+    :return: A dictionary where the keys are the class indices and the values are the class labels.
+    :rtype: Dict[int, str]
+    """
+
+    labels = {i: CLASSES[i] for i in range(len(CLASSES))}
+    return labels
+
+
+def load_thresholds(model_type: str):
+    """
+    Loads the prediction thresholds for the AST model.
+
+    :return: The prediction thresholds for each class.
+    :rtype: np.ndarray
+    """
+    if model_type == "accuracy":
+        thresholds = np.load(f"{MODELS_FOLDER}/acc_model_thresh.npy", allow_pickle=True)
+    else:
+        thresholds = np.load(f"{MODELS_FOLDER}/speed_model_thresh.npy", allow_pickle=True)
+    return thresholds
+
+
+class ModelServiceAST:
+    def __init__(self, model_type: str):
+        """
+        Initializes a ModelServiceAST instance with the specified model type.
+
+        :param model_type: The type of model to load
+        :type model_type: str
+        """
+
+        self.model = load_model(model_type)
+        self.labels = load_labels()
+        self.thresholds = load_thresholds(model_type)
+        self.transform = transforms.Compose([PreprocessPipeline(target_sr=16000), FeatureExtractor(sr=16000)])
+
+    def get_prediction(self, audio):
+        """
+        Gets the binary predictions for the given audio file.
+
+        :param audio_file: The file object for the input audio to make predictions for.
+        :type audio_file: file object
+        :return: A dictionary where the keys are the class labels and the values are binary predictions (0 or 1).
+        :rtype: Dict[str, int]
+        """
+        processed = self.transform(audio)
+        with torch.no_grad():
+            # Don't forget to transpose the output to seq_len x num_features!!!
+            output = torch.sigmoid(self.model(processed.mT))
+            output = output.squeeze().numpy().astype(float)
+
+        binary_predictions = {}
+        for i, label in enumerate(CLASSES):
+            binary_predictions[label] = int(output[i] >= self.thresholds[i])
+
+        return binary_predictions
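For context, a minimal sketch of driving this service directly from Python, outside FastAPI; it assumes the package is installed with its backend dependencies and that the LFS model weights under src/api/models are present:

from src.api.ModelService import ModelServiceAST

# "speed" loads the distilled StudentAST checkpoint; "accuracy" loads the full ASTPretrained one.
service = ModelServiceAST(model_type="speed")

# Any readable WAV file object works; test.wav is the sample shipped under src/api/test_files.
with open("src/api/test_files/test.wav", "rb") as f:
    predictions = service.get_prediction(f)

print(predictions)  # e.g. {"tru": 0, "sax": 1, ...}: one 0/1 flag per instrument code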
src/api/__init__.py
ADDED
File without changes
src/api/main.py
ADDED
@@ -0,0 +1,133 @@
+import os
+import logging
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+from typing import Dict
+
+
+from fastapi import Depends, FastAPI, File, UploadFile
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+from src.api.ModelService import ModelServiceAST
+from pydantic import BaseModel, validator
+
+LOG_SAVE_DIR = Path(__file__).parent / "logs"
+if not os.path.exists(LOG_SAVE_DIR):
+    os.makedirs(LOG_SAVE_DIR)
+
+ml_models = {}
+ml_models["Accuracy"] = ModelServiceAST(model_type="accuracy")
+ml_models["Speed"] = ModelServiceAST(model_type="speed")
+
+app = FastAPI()
+
+# Define the allowed file formats and maximum file size (in bytes)
+ALLOWED_FILE_FORMATS = ["wav"]
+
+# Configure logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+# Create a rotating file handler to save logs to a file
+handler = RotatingFileHandler(f"{LOG_SAVE_DIR}/app.log", maxBytes=100000, backupCount=5)
+handler.setLevel(logging.DEBUG)
+
+# Define the log format
+formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+
+# Add the handler to the logger
+logger.addHandler(handler)
+
+
+class InvalidFileTypeError(Exception):
+    def __init__(self):
+        self.message = "Only wav files are supported"
+        super().__init__(self.message)
+
+
+class InvalidModelError(Exception):
+    def __init__(self):
+        self.message = "Selected model doesn't exist"
+        super().__init__(self.message)
+
+
+class MissingFileError(Exception):
+    def __init__(self):
+        self.message = "File cannot be None"
+        super().__init__(self.message)
+
+
+class PredictionRequest(BaseModel):
+    model_name: str
+
+    @validator("model_name")
+    @classmethod
+    def valid_model(cls, v):
+        if v not in ml_models.keys():
+            raise InvalidModelError
+        return v
+
+
+class PredictionResult(BaseModel):
+    prediction: Dict[str, Dict[str, int]]
+
+
+@app.exception_handler(RequestValidationError)
+def validation_exception_handler(request, ex):
+    logger.error(f"Request validation error: {ex}")
+    return JSONResponse(content={"error": "Bad Request", "detail": ex.errors()}, status_code=400)
+
+
+@app.exception_handler(InvalidFileTypeError)
+def filetype_exception_handler(request, ex):
+    logger.error(f"Invalid file type error: {ex}")
+    return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)
+
+
+@app.exception_handler(InvalidModelError)
+def model_exception_handler(request, ex):
+    logger.error(f"Invalid model error: {ex}")
+    return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)
+
+
+@app.exception_handler(MissingFileError)
+def handle_missing_file_error(request, ex):
+    logger.error(f"Missing file error: {ex}")
+    return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)
+
+
+@app.exception_handler(Exception)
+def handle_exceptions(request, ex):
+    logger.exception(f"Internal server error: {ex}")
+    # If an exception occurs during processing, return a JSON response with an error message
+    return JSONResponse(content={"error": "Internal Server Error", "detail": str(ex)}, status_code=500)
+
+
+@app.get("/")
+def root():
+    logger.info("Received request to root endpoint")
+    return {"message": "Welcome to my API. Go to /docs to view the documentation."}
+
+
+@app.get("/health-check")
+def health_check():
+    """
+    Health check endpoint to verify if the API is running.
+    """
+    logger.info("Health check endpoint was hit")
+    return {"status": "API is running"}
+
+
+@app.post("/predict")
+def predict(request: PredictionRequest = Depends(), file: UploadFile = File(...)) -> PredictionResult:  # noqa
+    if not file:
+        raise MissingFileError
+    if file.filename.split(".")[-1].lower() not in ALLOWED_FILE_FORMATS:
+        raise InvalidFileTypeError
+    logger.info(f"Prediction request received: {request}")
+    output = ml_models[request.model_name].get_prediction(file.file)
+    logger.info(f"Prediction result: {output}")
+    prediction_result = PredictionResult(prediction={file.filename: output})
+
+    return prediction_result
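For reference, a minimal sketch of calling the running backend from Python with requests, mirroring how ui_backend.py and main_test.py do it; the host and port assume the uvicorn command from DockerFile.backend running locally:

import requests

with open("src/api/test_files/test.wav", "rb") as f:
    response = requests.post(
        "http://localhost:7860/predict",
        params={"model_name": "Speed"},  # must be "Accuracy" or "Speed", the keys of ml_models
        files={"file": ("test.wav", f, "audio/wav")},
        timeout=300,
    )

response.raise_for_status()
print(response.json())  # {"prediction": {"test.wav": {"tru": 0, "sax": 1, ...}}}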
src/api/main_test.py
ADDED
@@ -0,0 +1,63 @@
+import io
+import sys
+from pathlib import Path
+
+import soundfile as sf
+from fastapi.testclient import TestClient
+
+sys.path.append(".")
+
+from src.api.main import app  # noqa
+
+TEST_FILES_DIR = Path(__file__).parent / "test_files"
+TEST_WAV_FILE = TEST_FILES_DIR / "test.wav"
+
+client = TestClient(app)
+
+
+def test_health_check():
+    response = client.get("/health-check")
+    assert response.status_code == 200
+    assert response.json() == {"status": "API is running"}
+
+
+def test_predict_valid_cut_file():
+    audio_data, sample_rate = sf.read(TEST_WAV_FILE)
+    audio_file = io.BytesIO()
+    sf.write(audio_file, audio_data, sample_rate, format="wav")
+    audio_file = ("test.wav", audio_file)
+
+    file = {"file": audio_file}
+    request_data = {"model_name": "Accuracy"}
+    # Make a request to the /predict endpoint
+    response = client.post("/predict", params=request_data, files=file)
+
+    # Check that the response is successful
+    assert response.status_code == 200
+    assert response.json()["prediction"]["test.wav"] is not None
+
+
+def test_predict_valid_file():
+    with open(TEST_WAV_FILE, "rb") as file:
+        data = {"model_name": "Accuracy"}
+        response = client.post("/predict", params=data, files={"file": file})
+        assert response.status_code == 200
+        assert response.json()["prediction"]["test.wav"] is not None
+
+
+def test_predict_invalid_file_type():
+    file_data = io.BytesIO(b"dummy txt data")
+    file = ("test.txt", file_data)
+    data = {"model_name": "Accuracy"}
+    response = client.post("/predict", params=data, files={"file": file})
+    assert response.status_code == 400
+    assert "Only wav files are supported" in response.json()["detail"]
+
+
+def test_predict_invalid_model():
+    file_data = io.BytesIO(b"dummy wav data")
+    file = ("test.wav", file_data)
+    data = {"model_name": "InvalidModel"}
+    response = client.post("/predict", params=data, files={"file": file})
+    assert response.status_code == 400
+    assert "Selected model doesn't exist" in response.json()["detail"]
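Note: with the dev extra installed (it pins pytest==7.3.1), these tests can presumably be run from the repository root with `pytest src/api/main_test.py`; importing the app loads both model checkpoints, so the backend dependencies and the LFS weight files need to be available.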
src/api/models/acc_model_ast.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2305b1d04ed918b6d6428f86dfde162d6912b5021741ff58785fa7b020094ec0
+size 344860756
src/api/models/acc_model_thresh.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3034a1e953618280465b52b4104184b577e783afdf6231add9b96d119e12addf
+size 216
src/api/models/speed_model_ast.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e529b7b85881d249f455b5386cdb5306915ad34cd5fc5fafeca35fc965573637
+size 22573905
src/api/models/speed_model_thresh.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56838178f12bccc05cf5ffc92a7ff570a70d3a42f3f87c977ad8c9ae0f4a3359
+size 216
src/api/test_files/test.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60f854cc407877512a3e68a286cfd26e95dc2f0a4e76ba313fbb3e21ddf2d2f9
+size 3492764
src/frontend/.streamlit/config.toml
ADDED
@@ -0,0 +1,10 @@
+[theme]
+base = "dark"
+primaryColor = "#FFFFFF"
+backgroundColor = "#212121"
+secondaryBackgroundColor = "#757575"
+textColor = "#FFFFFF"
+font = "sans serif"
+
+[browser]
+gatherUsageStats = false
src/frontend/__init__.py
ADDED
File without changes
src/frontend/ui.py
ADDED
@@ -0,0 +1,97 @@
+import json
+
+import streamlit as st
+from ui_backend import (
+    check_for_api,
+    cut_audio_file,
+    display_predictions,
+    load_audio,
+    predict_multiple,
+    predict_single,
+)
+
+
+def main():
+    # Page settings
+    st.set_page_config(
+        page_title="Music Instrument Recognition", page_icon="🎸", layout="wide", initial_sidebar_state="collapsed"
+    )
+
+    # Sidebar
+    with st.sidebar:
+        st.title("⚙️ Settings")
+        selected_model = st.selectbox(
+            "Select Model",
+            ("Accuracy", "Speed"),
+            index=0,
+            help="Select a slower but more accurate model or a faster but less accurate model",
+        )
+
+    # Main title
+    st.markdown(
+        "<h1 style='text-align: center; color: #FFFFFF; font-size: 3rem;'>Instrument Recognition 🎶</h1>",
+        unsafe_allow_html=True,
+    )
+
+    # Upload widget
+    audio_file = load_audio()
+
+    # Send a health check request to the API in a loop until it is running
+    api_running = check_for_api(10)
+
+    # Enable or disable a button based on API status
+    predict_valid = False
+    cut_valid = False
+
+    if api_running:
+        st.info("API is running", icon="🤖")
+
+    if audio_file:
+        num_files = len(audio_file)
+        st.write(f"Number of uploaded files: {num_files}")
+        predict_valid = True
+        if len(audio_file) > 1:
+            cut_valid = False
+        else:
+            audio_file = audio_file[0]
+            cut_valid = True
+            name = audio_file.name
+
+    if cut_valid:
+        cut_audio = st.checkbox(
+            "✂️ Cut duration",
+            disabled=not predict_valid,
+            help="Cut a long audio file. Model works best if audio is around 15 seconds",
+        )
+
+        if cut_audio:
+            audio_file = cut_audio_file(audio_file, name)
+
+    result = st.button("Predict", disabled=not predict_valid, help="Send the audio to API to get a prediction")
+
+    if result:
+        predictions = {}
+        if isinstance(audio_file, list):
+            predictions = predict_multiple(audio_file, selected_model)
+
+        else:
+            predictions = predict_single(audio_file, name, selected_model)
+
+        # Sort the dictionary alphabetically by key
+        sorted_predictions = dict(sorted(predictions.items()))
+
+        # Convert the sorted dictionary to a JSON string
+        json_string = json.dumps(sorted_predictions)
+        st.download_button(
+            label="Download JSON",
+            file_name="predictions.json",
+            mime="application/json",
+            data=json_string,
+            help="Download the predictions in JSON format",
+        )
+
+        display_predictions(sorted_predictions)
+
+
+if __name__ == "__main__":
+    main()
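Note: the UI imports ui_backend as a plain module, so it is presumably launched from inside src/frontend, e.g. `streamlit run ui.py` (or with the import path adjusted, `streamlit run src/frontend/ui.py`); the .streamlit/config.toml above is then picked up automatically for the dark theme.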
src/frontend/ui_backend.py
ADDED
@@ -0,0 +1,254 @@
+import io
+import os
+import time
+from json import JSONDecodeError
+import math
+
+import requests
+import soundfile as sf
+import streamlit as st
+
+if os.environ.get("IS_DOCKER", False):
+    backend = "http://api:7860"
+else:
+    backend = "http://0.0.0.0:7860"
+
+INSTRUMENTS = {
+    "tru": "Trumpet",
+    "sax": "Saxophone",
+    "vio": "Violin",
+    "gac": "Acoustic Guitar",
+    "org": "Organ",
+    "cla": "Clarinet",
+    "flu": "Flute",
+    "voi": "Voice",
+    "gel": "Electric Guitar",
+    "cel": "Cello",
+    "pia": "Piano",
+}
+
+
+def load_audio():
+    """
+    Upload a WAV audio file and display it in a Streamlit app.
+
+    :return: A BytesIO object representing the uploaded audio file, or None if no file was uploaded.
+    :rtype: Optional[BytesIO]
+    """
+
+    audio_file = st.file_uploader(label="Upload audio file", type="wav", accept_multiple_files=True)
+    if len(audio_file) > 0:
+        st.audio(audio_file[0])
+        return audio_file
+    else:
+        return None
+
+
+@st.cache_data(show_spinner=False)
+def check_for_api(max_tries: int):
+    """
+    Check if the API is running by making a health check request.
+
+    :param max_tries: The maximum number of attempts to check the API's health.
+    :type max_tries: int
+    :return: True if the API is running, False otherwise.
+    :rtype: bool
+    """
+    trial_count = 0
+
+    with st.spinner("Waiting for API..."):
+        while trial_count <= max_tries:
+            try:
+                response = health_check()
+                if response:
+                    return True
+            except requests.exceptions.ConnectionError:
+                trial_count += 1
+                # Handle connection error, e.g. API not yet running
+                time.sleep(5)  # Sleep for 1 second before retrying
+    st.error("API is not running. Please refresh the page to try again.", icon="🚨")
+    st.stop()
+
+
+def cut_audio_file(audio_file, name):
+    """
+    Cut an audio file and return the cut audio data as a tuple.
+
+    :param audio_file: The path of the audio file to be cut.
+    :type audio_file: str
+    :param name: The name of the audio file to be cut.
+    :type name: str
+    :raises RuntimeError: If the audio file cannot be read.
+    :return: A tuple containing the name and the cut audio data as a BytesIO object.
+    :rtype: tuple
+    """
+    try:
+        audio_data, sample_rate = sf.read(audio_file)
+    except RuntimeError as e:
+        raise e
+
+    # Display audio duration
+    duration = round(len(audio_data) / sample_rate, 2)
+    st.info(f"Audio Duration: {duration} seconds")
+
+    # Get start and end time for cutting
+    start_time = st.number_input("Start Time (seconds)", min_value=0.0, max_value=duration - 1, step=0.1)
+    end_time = st.number_input("End Time (seconds)", min_value=start_time, value=duration, max_value=duration, step=0.1)
+
+    # Convert start and end time to sample indices
+    start_sample = int(start_time * sample_rate)
+    end_sample = int(end_time * sample_rate)
+
+    # Cut audio
+    cut_audio_data = audio_data[start_sample:end_sample]
+
+    # Create a temporary in-memory file for cut audio
+    audio_file = io.BytesIO()
+    sf.write(audio_file, cut_audio_data, sample_rate, format="wav")
+
+    # Display cut audio
+    st.audio(audio_file, format="audio/wav")
+    audio_file = (name, audio_file)
+
+    return audio_file
+
+
+def display_predictions(predictions: dict):
+    """
+    Display the predictions using instrument names instead of codes.
+
+    :param predictions: A dictionary containing the filenames and instruments detected in them.
+    :type predictions: dict
+    """
+
+    # Display the results using instrument names instead of codes
+    for filename, instruments in predictions.items():
+        st.subheader(filename)
+
+        if isinstance(instruments, str):
+            st.write(instruments)
+
+        else:
+            with st.container():
+                col1, col2 = st.columns([1, 3])
+                present_instruments = [
+                    INSTRUMENTS[instrument_code] for instrument_code, presence in instruments.items() if presence
+                ]
+                if present_instruments:
+                    for instrument_name in present_instruments:
+                        with col1:
+                            st.write(instrument_name)
+                        with col2:
+                            st.write("✔️")
+                else:
+                    st.write("No instruments found in this file.")
+
+
+def health_check():
+    """
+    Sends a health check request to the API and checks if it's running.
+
+    :return: Returns True if the API is running, else False.
+    :rtype: bool
+    """
+
+    # Send a health check request to the API
+    response = requests.get(f"{backend}/health-check", timeout=100)
+
+    # Check if the API is running
+    if response.status_code == 200:
+        return True
+    else:
+        return False
+
+
+def predict(data, model_name):
+    """
+    Sends a POST request to the API with the provided data and model name.
+
+    :param data: The audio data to be used for prediction.
+    :type data: bytes
+    :param model_name: The name of the model to be used for prediction.
+    :type model_name: str
+    :return: The response from the API.
+    :rtype: requests.Response
+    """
+
+    file = {"file": data}
+    request_data = {"model_name": model_name}
+
+    response = requests.post(
+        f"{backend}/predict", params=request_data, files=file, timeout=300
+    )  # Replace with your API endpoint URL
+
+    return response
+
+
+@st.cache_data(show_spinner=False)
+def predict_single(audio_file, name, selected_model):
+    """
+    Predicts the instruments in a single audio file using the selected model.
+
+    :param audio_file: The audio file to be used for prediction.
+    :type audio_file: bytes
+    :param name: The name of the audio file.
+    :type name: str
+    :param selected_model: The name of the selected model.
+    :type selected_model: str
+    :return: A dictionary containing the predicted instruments for the audio file.
+    :rtype: dict
+    """
+
+    predictions = {}
+
+    with st.spinner("Predicting instruments..."):
+        response = predict(audio_file, selected_model)
+
+        if response.status_code == 200:
+            prediction = response.json()["prediction"]
+            predictions[name] = prediction.get(name, "Error making prediction")
+        else:
+            st.write(response)
+            try:
+                st.json(response.json())
+            except JSONDecodeError:
+                st.error(response.text)
+            st.stop()
+    return predictions
+
+
+@st.cache_data(show_spinner=False)
+def predict_multiple(audio_files, selected_model):
+    """
+    Generates predictions for multiple audio files using the selected model.
+
+    :param audio_files: A list of audio files to make predictions on.
+    :type audio_files: List[UploadedFile]
+    :param selected_model: The model to use for making predictions.
+    :type selected_model: str
+    :return: A dictionary where the keys are the names of the audio files and the values are the predicted labels.
+    :rtype: Dict[str, str]
+    """
+
+    predictions = {}
+    progress_text = "Getting predictions for all files. Please wait."
+    progress_bar = st.empty()
+    progress_bar.progress(0, text=progress_text)
+
+    num_files = len(audio_files)
+
+    for i, file in enumerate(audio_files):
+        name = file.name
+        response = predict(file, selected_model)
+        if response.status_code == 200:
+            prediction = response.json()["prediction"]
+            predictions[name] = prediction[name]
+            progress_bar.progress((i + 1) / num_files, text=progress_text)
+        else:
+            predictions[name] = "Error making prediction."
+    progress_bar.empty()
+    return predictions
+
+
+if __name__ == "__main__":
+    pass
src/modeling/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from src.modeling.models import ASTPretrained, StudentAST
+from src.modeling.transforms import FeatureExtractor, PreprocessPipeline
src/modeling/dataset.py
ADDED
@@ -0,0 +1,162 @@
+from pathlib import Path
+from typing import List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import Compose
+
+import modeling.transforms as transform_module
+from modeling.transforms import (
+    LabelsFromTxt,
+    OneHotEncode,
+    ParentMultilabel,
+    Preprocess,
+    Transform,
+)
+from modeling.utils import CLASSES, get_wav_files, init_obj, init_transforms
+
+
+class IRMASDataset(Dataset):
+    """Dataset class for IRMAS dataset.
+
+    :param audio_dir: Directory containing the audio files
+    :type audio_dir: Union[str, Path]
+    :param preprocess: Preprocessing method to apply to the audio files
+    :type preprocess: Type[Preprocess]
+    :param signal_augments: Signal augmentation method to apply to the audio files, defaults to None
+    :type signal_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
+    :param transforms: Transform method to apply to the audio files, defaults to None
+    :type transforms: Optional[Union[Type[Compose], Type[Transform]]], optional
+    :param spec_augments: Spectrogram augmentation method to apply to the audio files, defaults to None
+    :type spec_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
+    :param subset: Subset of the data to load (train, valid, or test), defaults to "train"
+    :type subset: str, optional
+    :raises AssertionError: Raises an assertion error if subset is not train, valid or test
+    :raises OSError: Raises an OS error if test_songs.txt is not found in the data folder
+    :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
+    :rtype: Tuple[Tensor, Tensor]
+    """
+
+    def __init__(
+        self,
+        audio_dir: Union[str, Path],
+        preprocess: Type[Preprocess],
+        signal_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
+        transforms: Optional[Union[Type[Compose], Type[Transform]]] = None,
+        spec_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
+        subset: str = "train",
+    ):
+        self.files = get_wav_files(audio_dir)
+        assert subset in ["train", "valid", "test"], "Subset can only be train, valid or test"
+        self.subset = subset
+
+        if self.subset != "train":
+            try:
+                test_songs = np.genfromtxt("../data/test_songs.txt", dtype=str, ndmin=1, delimiter="\n")
+            except OSError as e:
+                print("Error: {e}")
+                print("test_songs.txt not found in data/. Please generate a split before training")
+                raise e
+
+            if self.subset == "valid":
+                self.files = [file for file in self.files if Path(file).stem not in test_songs]
+            if self.subset == "test":
+                self.files = [file for file in self.files if Path(file).stem in test_songs]
+
+        self.preprocess = preprocess
+        self.transforms = transforms
+        self.signal_augments = signal_augments
+        self.spec_augments = spec_augments
+
+    def __len__(self):
+        """Return the length of the dataset.
+
+        :return: The length of the dataset
+        :rtype: int
+        """
+
+        return len(self.files)
+
+    def __getitem__(self, index):
+        """Get an item from the dataset.
+
+        :param index: The index of the item to get
+        :type index: int
+        :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
+        :rtype: Tuple[Tensor, Tensor]
+        """
+
+        sample_path = self.files[index]
+        signal = self.preprocess(sample_path)
+
+        if self.subset == "train":
+            target_transforms = Compose([ParentMultilabel(sep="-"), OneHotEncode(CLASSES)])
+        else:
+            target_transforms = Compose([LabelsFromTxt(), OneHotEncode(CLASSES)])
+
+        label = target_transforms(sample_path)
+
+        if self.signal_augments is not None and self.subset == "train":
+            signal = self.signal_augments(signal)
+
+        if self.transforms is not None:
+            signal = self.transforms(signal)
+
+        if self.spec_augments is not None and self.subset == "train":
+            signal = self.spec_augments(signal)
+
+        return signal, label.float()
+
+
+def collate_fn(data: List[Tuple[torch.Tensor, torch.Tensor]]):
+    """
+    Function to collate a batch of audio signals and their corresponding labels.
+
+    :param data: A list of tuples containing the audio signals and their corresponding labels.
+    :type data: List[Tuple[torch.Tensor, torch.Tensor]]
+
+    :return: A tuple containing the batch of audio signals and their corresponding labels.
+    :rtype: Tuple[torch.Tensor, torch.Tensor]
+    """
+
+    features, labels = zip(*data)
+    features = [item.squeeze().T for item in features]
+    # Pads items to same length if they're not
+    features = pad_sequence(features, batch_first=True)
+    labels = torch.stack(labels)
+
+    return features, labels
+
+
+def get_loader(config: dict, subset: str):
+    """
+    Function to create a PyTorch DataLoader for a given subset of the IRMAS dataset.
+
+    :param config: A configuration object.
+    :type config: Any
+    :param subset: The subset of the dataset to use. Can be "train" or "valid".
+    :type subset: str
+
+    :return: A PyTorch DataLoader for the specified subset of the dataset.
+    :rtype: torch.utils.data.DataLoader
+    """
+
+    dst = IRMASDataset(
+        config.train_dir if subset == "train" else config.valid_dir,
+        preprocess=init_obj(config.preprocess, transform_module),
+        transforms=init_obj(config.transforms, transform_module),
+        signal_augments=init_transforms(config.signal_augments, transform_module),
+        spec_augments=init_transforms(config.spec_augments, transform_module),
+        subset=subset,
+    )
+
+    return DataLoader(
+        dst,
+        batch_size=config.batch_size,
+        shuffle=True if subset == "train" else False,
+        pin_memory=True if torch.cuda.is_available() else False,
+        num_workers=torch.get_num_threads() - 1,
+        collate_fn=collate_fn,
+    )
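For reference, a small self-contained sketch of the padding behaviour collate_fn relies on: each feature is squeezed and transposed to (time, n_mels), then torch.nn.utils.rnn.pad_sequence with batch_first=True zero-pads the batch to the longest clip (the shapes below are illustrative, not taken from the repo's configs):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two spectrogram "clips" of different lengths, already in (time, n_mels) layout
# as collate_fn produces just before padding.
short_clip = torch.randn(95, 128)
long_clip = torch.randn(130, 128)

batch = pad_sequence([short_clip, long_clip], batch_first=True)
print(batch.shape)  # torch.Size([2, 130, 128]); the shorter clip is zero-padded along time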
src/modeling/learner.py
ADDED
@@ -0,0 +1,333 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import wandb
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+import modeling.loss as loss_module
+import modeling.metrics as metrics_module
+from modeling.loss import HardDistillationLoss
+from modeling.models import freeze, layerwise_lr_decay
+from modeling.utils import init_obj
+
+
+class BaseLearner(ABC):
+    """
+    Abstract base class for a learner.
+
+    :param train_dl: DataLoader for training data
+    :type train_dl: Type[DataLoader]
+    :param valid_dl: DataLoader for validation data
+    :type valid_dl: Type[DataLoader]
+    :param model: Model to be trained
+    :type model: Type[nn.Module]
+    :param config: Configuration object
+    :type config: Any
+    """
+
+    def __init__(self, train_dl: DataLoader, valid_dl: DataLoader, model: nn.Module, config):
+        self.train_dl = train_dl
+        self.valid_dl = valid_dl
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = model.to(self.device)
+        self.config = config
+
+    @abstractmethod
+    def fit(
+        self,
+    ):
+        """Abstract method for fitting the model."""
+
+        pass
+
+    @abstractmethod
+    def _train_epoch(
+        self,
+    ):
+        """Abstract method for training the model for one epoch."""
+        pass
+
+    @abstractmethod
+    def _test_epoch(
+        self,
+    ):
+        """Abstract method for testing the model for one epoch."""
+        pass
+
+
+class Learner(BaseLearner):
+    def __init__(self, train_dl: DataLoader, valid_dl: DataLoader, model: nn.Module, config):
+        """
+        A class that inherits from the BaseLearner class and represents a learner object.
+
+        :param train_dl: DataLoader for training data
+        :type train_dl: DataLoader
+        :param valid_dl: DataLoader for validation data
+        :type valid_dl: DataLoader
+        :param model: Model to be trained
+        :type model: nn.Module
+        :param config: Configuration object
+        :type config: Any
+        """
+
+        super().__init__(train_dl, valid_dl, model, config)
+
+        self.model = torch.nn.DataParallel(module=self.model, device_ids=list(range(config.num_gpus)))
+        self.loss_fn = init_obj(self.config.loss, loss_module)
+        params = layerwise_lr_decay(self.config, self.model)
+        self.optimizer = init_obj(self.config.optimizer, optim, params)
+        self.scheduler = init_obj(
+            self.config.scheduler,
+            optim.lr_scheduler,
+            self.optimizer,
+            max_lr=[param["lr"] for param in params],
+            epochs=self.config.epochs,
+            steps_per_epoch=int(np.ceil(len(train_dl) / self.config.num_accum)),
+        )
+
+        self.verbose = self.config.verbose
+        self.metrics = MetricTracker(self.config.metrics, self.verbose)
+        self.scaler = torch.cuda.amp.GradScaler()
+
+        self.train_step = 0
+        self.test_step = 0
+
+    def fit(self, model_name: str = "model"):
+        """
+        Method to train the model.
+
+        :param model_name: Name of the model to be saved, defaults to "model"
+        :type model_name: str, optional
+        """
+
+        loop = tqdm(range(self.config.epochs), leave=False)
+
+        for epoch in loop:
+            train_loss = self._train_epoch()
+            val_loss = self._test_epoch()
+
+            wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch + 1})
+
+            if self.verbose:
+                print(f"| EPOCH: {epoch+1} | train_loss: {train_loss:.3f} | val_loss: {val_loss:.3f} |\n")
+                self.metrics.display()
+
+            if self.config.save_last_checkpoint:
+                torch.save(self.model.module.state_dict(), f"{model_name}.pth")
+
+    def _train_epoch(self, distill: bool = False):
+        """
+        Method to perform one epoch of training.
+
+        :param distill: Flag to indicate if knowledge distillation is used, defaults to False
+        :type distill: bool, optional
+        :return: Average training loss for the epoch
+        :rtype: float
+        """
+
+        if distill:
+            print("Distilling knowledge...", flush=True)
+
+        loop = tqdm(self.train_dl, leave=False)
+        self.model.train()
+
+        num_batches = len(self.train_dl)
+        train_loss = 0
+
+        for idx, (xb, yb) in enumerate(loop):
+            xb = xb.to(self.device)
+            yb = yb.to(self.device)
+
+            # forward
+            with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=not distill):
+                predictions = self.model(xb)
+
+                if distill:
+                    loss = self.KDloss_fn(xb, predictions, yb)
+                else:
+                    loss = self.loss_fn(predictions, yb)
+
+                loss /= self.config.num_accum
+
+            # backward
+            self.scaler.scale(loss).backward()
+            wandb.log({f"lr_param_group_{i}": lr for i, lr in enumerate(self.scheduler.get_last_lr())})
+
+            if ((idx + 1) % self.config.num_accum == 0) or (idx + 1 == num_batches):
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.scheduler.step()
+                self.optimizer.zero_grad()
+
+            # update loop
+            loop.set_postfix(loss=loss.item())
+            self.train_step += 1
+            wandb.log({"train_loss_per_batch": loss.item(), "train_step": self.train_step})
+            train_loss += loss.item()
+
+            if distill:
+                if ((idx + 1) % 2500 == 0) and not (idx + 1 == num_batches):
+                    val_loss = self._test_epoch()
+                    wandb.log({"val_loss": val_loss})
+                    self.model.train()
+
+        train_loss /= num_batches
+
+        return train_loss
+
+    def _test_epoch(self):
+        """
+        Method to perform one epoch of validation/testing.
+
+        :return: Average validation/test loss for the epoch
+        :rtype: float
+        """
+
+        loop = tqdm(self.valid_dl, leave=False)
+        self.model.eval()
+
+        num_batches = len(self.valid_dl)
+        preds = []
+        targets = []
+        test_loss = 0
+
+        with torch.no_grad():
+            for xb, yb in loop:
+                xb, yb = xb.to(self.device), yb.to(self.device)
+                pred = self.model(xb)
+                loss = self.loss_fn(pred, yb).item()
+                self.test_step += 1
+                wandb.log({"valid_loss_per_batch": loss, "test_step": self.test_step})
+                test_loss += loss
+
+                pred = torch.sigmoid(pred)
+                preds.extend(pred.cpu().numpy())
+                targets.extend(yb.cpu().numpy())
+
+        preds, targets = np.array(preds), np.array(targets)
+        self.metrics.update(preds, targets)
+        test_loss /= num_batches
+
+        return test_loss
+
+
+class KDLearner(Learner):
+    """
+    Knowledge Distillation Learner class for training a student model with knowledge distillation.
+
+    :param train_dl: Train data loader
+    :type train_dl: DataLoader
+    :param valid_dl: Validation data loader
+    :type valid_dl: DataLoader
+    :param student_model: Student model to be trained
+    :type student_model: nn.Module
+    :param teacher: Teacher model for knowledge distillation
+    :type teacher: nn.Module
+    :param thresholds: Thresholds for HardDistillationLoss
+    :type thresholds: List[float]
+    :param config: Configuration object for training
+    :type config: Config
+    """
+
+    def __init__(self, train_dl, valid_dl, student_model, teacher, thresholds, config):
+        super().__init__(train_dl, valid_dl, student_model, config)
+
+        self.teacher = nn.DataParallel(freeze(teacher).to(self.device))
+        self.KDloss_fn = HardDistillationLoss(self.teacher, self.loss_fn, thresholds, self.device)
+        self.scaler = torch.cuda.amp.GradScaler(enabled=False)
+
+    def _train_epoch(self):
+        """
+        Method to perform one epoch of training with knowledge distillation.
+
+        :return: Average training loss for the epoch
+        :rtype: float
+        """
+
+        return super()._train_epoch(distill=True)
+
+
+class MetricTracker:
+    """
+    Metric Tracker class for tracking evaluation metrics during model validation.
+    This class is used to track and display evaluation metrics during model validation.
+    It keeps track of the results of the provided metric functions for each validation batch,
+    and logs them to Weights & Biases using wandb.log(). The display() method can be used
+    to print the tracked metric results, if verbose is set to True during initialization.
+
+    :param metrics: List of metric functions to track
+    :type metrics: List[Callable]
+    :param verbose: Flag to indicate whether to print the results or not, defaults to True
+    :type verbose: bool, optional
+    """
+
+    def __init__(self, metrics, verbose: bool = True):
+        self.metrics_fn = [getattr(metrics_module, metric) for metric in metrics]
+        self.verbose = verbose
+        self.result = None
+
+    def update(self, preds, targets):
+        """
+        Update the metric tracker with the latest predictions and targets.
+
+        :param preds: Model predictions
+        :type preds: torch.Tensor
+        :param targets: Ground truth targets
+        :type targets: torch.Tensor
+        """
+
+        self.result = {metric.__name__: metric(preds, targets) for metric in self.metrics_fn}
+        wandb.log(self.result)
+
+    def display(self):
+        """Display the tracked metric results."""
+
+        for k, v in self.result.items():
+            print(f"{k}: {v:.2f}")
+
+
+def get_preds(data: DataLoader, model: nn.Module, device: str = "cpu") -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Get predictions and targets from a data loader and a PyTorch model.
+
+    :param data: A PyTorch DataLoader containing the data to predict on.
+    :type data: torch.utils.data.DataLoader
+    :param model: A PyTorch model to use for predictions.
+    :type model: torch.nn.Module
+    :param device: The device to use for predictions (default is "cpu").
+    :type device: str
+    :raises TypeError: If any of the input arguments is of an incorrect type.
+    :return: A tuple containing two NumPy arrays: the predictions and the targets.
+    :rtype: Tuple[numpy.ndarray, numpy.ndarray]
+    """
+
+    if not isinstance(data, DataLoader):
+        raise TypeError("The 'data' argument must be a PyTorch DataLoader.")
+    if not isinstance(model, nn.Module):
+        raise TypeError("The 'model' argument must be a PyTorch model.")
+    if not isinstance(device, str):
+        raise TypeError("The 'device' argument must be a string.")
+
+    loop = tqdm(data, leave=False)
+    model = model.to(device)
+    model.eval()
+
+    preds = []
+    targets = []
+
+    with torch.no_grad():
+        for xb, yb in loop:
+            xb, yb = xb.to(device), yb.to(device)
+            pred = model(xb)
+            pred = torch.sigmoid(pred)
+            preds.extend(pred.cpu().numpy())
+            targets.extend(yb.cpu().numpy())
+
+    preds, targets = np.array(preds), np.array(targets)
+
+    return preds, targets
src/modeling/loss.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision.ops import sigmoid_focal_loss
|
8 |
+
|
9 |
+
|
10 |
+
class FocalLoss(nn.Module):
|
11 |
+
"""
|
12 |
+
Focal Loss implementation.
|
13 |
+
|
14 |
+
This class defines the Focal Loss, which is a variant of the Binary Cross Entropy (BCE) loss that is
|
15 |
+
designed to address the problem of class imbalance in binary classification tasks.
|
16 |
+
The Focal Loss introduces two hyperparameters, alpha and gamma, to control the balance between easy
|
17 |
+
and hard examples during training.
|
18 |
+
|
19 |
+
:param alpha: The balancing parameter between positive and negative examples. A float value between 0 and 1.
|
20 |
+
If set to -1, no balancing is applied. Default is 0.25.
|
21 |
+
:type alpha: float
|
22 |
+
:param gamma: The focusing parameter to control the emphasis on hard examples. A positive integer. Default is 2.
|
23 |
+
:type gamma: int
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, alpha: float = 0.25, gamma: int = 2):
|
27 |
+
super().__init__()
|
28 |
+
self.loss_fn = partial(sigmoid_focal_loss, alpha=alpha, gamma=gamma, reduction="mean")
|
29 |
+
|
30 |
+
def forward(self, inputs, targets):
|
31 |
+
"""
|
32 |
+
Compute the Focal Loss.
|
33 |
+
|
34 |
+
:param inputs: The predicted inputs from the model.
|
35 |
+
:type inputs: torch.Tensor
|
36 |
+
:param targets: The ground truth targets.
|
37 |
+
:type targets: torch.Tensor
|
38 |
+
:return: The computed Focal Loss.
|
39 |
+
:rtype: torch.Tensor
|
40 |
+
:raises ValueError: If the inputs and targets have different shapes.
|
41 |
+
"""
|
42 |
+
|
43 |
+
return self.loss_fn(inputs=inputs, targets=targets)
|
44 |
+
|
45 |
+
|
46 |
+
class HardDistillationLoss(nn.Module):
|
47 |
+
"""Hard Distillation Loss implementation.
|
48 |
+
|
49 |
+
This class defines the Hard Distillation Loss, which is used for model distillation,
|
50 |
+
a technique used to transfer knowledge from a large, complex teacher model to a smaller,
|
51 |
+
simpler student model. The Hard Distillation Loss computes the loss by comparing the outputs
|
52 |
+
of the student model and the teacher model using a provided loss function. It also introduces a
|
53 |
+
threshold parameter to convert the teacher model outputs to binary labels for the distillation process.
|
54 |
+
|
55 |
+
:param teacher: The teacher model used for distillation.
|
56 |
+
:type teacher: torch.nn.Module
|
57 |
+
:param loss_fn: The loss function used for computing the distillation loss.
|
58 |
+
:type loss_fn: torch.nn.Module
|
59 |
+
:param threshold: The threshold value used to convert teacher model outputs to binary labels.
|
60 |
+
Can be a list or numpy array of threshold values.
|
61 |
+
:type threshold: Union[list, np.array]
|
62 |
+
:param device: The device to be used for computation. Default is "cuda".
|
63 |
+
:type device: str
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self, teacher: nn.Module, loss_fn: nn.Module, threshold: Union[list, np.array], device: str = "cuda"):
|
67 |
+
super().__init__()
|
68 |
+
self.teacher = teacher
|
69 |
+
self.loss_fn = loss_fn
|
70 |
+
self.threshold = torch.tensor(threshold).to(device)
|
71 |
+
|
72 |
+
def forward(self, inputs, student_outputs, targets):
|
73 |
+
"""
|
74 |
+
Compute the Hard Distillation Loss.
|
75 |
+
|
76 |
+
:param inputs: The input data fed to the student model.
|
77 |
+
:type inputs: torch.Tensor
|
78 |
+
:param student_outputs: The output predictions from the student model, which consists of
|
79 |
+
both classification and distillation outputs.
|
80 |
+
:type student_outputs: tuple
|
81 |
+
:param targets: The ground truth targets.
|
82 |
+
:type targets: torch.Tensor
|
83 |
+
:return: The computed Hard Distillation Loss.
|
84 |
+
:rtype: torch.Tensor
|
85 |
+
:raises ValueError: If the inputs and targets have different shapes.
|
86 |
+
"""
|
87 |
+
|
88 |
+
outputs_cls, outputs_dist = student_outputs
|
89 |
+
|
90 |
+
teacher_outputs = torch.sigmoid(self.teacher(inputs))
|
91 |
+
teacher_labels = (teacher_outputs > self.threshold).float()
|
92 |
+
|
93 |
+
base_loss = self.loss_fn(outputs_cls, targets)
|
94 |
+
teacher_loss = self.loss_fn(outputs_dist, teacher_labels)
|
95 |
+
|
96 |
+
return (base_loss + teacher_loss) / 2
|
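As a quick orientation for how these two losses fit together, here is a minimal usage sketch: a frozen ASTPretrained teacher supplies hard labels while FocalLoss scores both heads of the student. The import paths, batch shapes, and threshold values are illustrative assumptions, not part of this commit.

# Hedged sketch: wiring FocalLoss and HardDistillationLoss into one training step.
# Assumes the classes defined in this commit; paths and shapes below are placeholders.
import numpy as np
import torch

from src.modeling.loss import FocalLoss, HardDistillationLoss
from src.modeling.models import ASTPretrained, StudentAST, freeze

n_classes = 11  # matches CLASSES in src/modeling/utils.py
teacher = freeze(ASTPretrained(n_classes=n_classes, download_weights=False))
student = StudentAST(n_classes=n_classes).train()

thresholds = np.full(n_classes, 0.5)  # illustrative per-class teacher thresholds
criterion = HardDistillationLoss(teacher, FocalLoss(alpha=0.25, gamma=2), thresholds, device="cpu")

spec = torch.randn(2, 1024, 128)                       # [batch, time, mel] AST-style input
targets = torch.randint(0, 2, (2, n_classes)).float()  # multi-hot instrument labels

cls_out, dist_out = student(spec)                      # student returns both heads in train mode
loss = criterion(spec, (cls_out, dist_out), targets)
loss.backward()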
src/modeling/metrics.py
ADDED
@@ -0,0 +1,179 @@
|
1 |
+
import numpy as np
|
2 |
+
from sklearn.metrics import (
|
3 |
+
accuracy_score,
|
4 |
+
average_precision_score,
|
5 |
+
f1_score,
|
6 |
+
hamming_loss,
|
7 |
+
precision_recall_curve,
|
8 |
+
zero_one_loss,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
def hamming_score(preds, targets, thresholds: np.array = None):
|
13 |
+
"""Compute Hamming Score.
|
14 |
+
|
15 |
+
This function computes the Hamming Score, a performance metric used for multi-label classification tasks.
|
16 |
+
The Hamming Score measures the similarity between the predicted labels and the ground truth labels, where
|
17 |
+
a higher score indicates better prediction accuracy.
|
18 |
+
|
19 |
+
:param preds: The predicted labels.
|
20 |
+
:type preds: numpy array
|
21 |
+
:param targets: The ground truth labels.
|
22 |
+
:type targets: numpy array
|
23 |
+
:return: The computed Hamming Score.
|
24 |
+
:rtype: float
|
25 |
+
"""
|
26 |
+
if thresholds is None:
|
27 |
+
thresholds = optimize_accuracy(preds, targets)
|
28 |
+
|
29 |
+
preds = (preds > thresholds).astype(int)
|
30 |
+
return 1 - hamming_loss(targets, preds)
|
31 |
+
|
32 |
+
|
33 |
+
def zero_one_score(preds, targets, thresholds: np.array = None):
|
34 |
+
"""
|
35 |
+
Compute Zero-One Score.
|
36 |
+
|
37 |
+
This function computes the Zero-One Score, a performance metric used for
|
38 |
+
multi-label classification tasks. The Zero-One Score measures the similarity
|
39 |
+
between the predicted labels and the ground truth labels, where a higher score
|
40 |
+
indicates better prediction accuracy. The Zero-One Score ranges from 0 to 1, with 1 being a perfect match.
|
41 |
+
|
42 |
+
:param preds: The predicted labels.
|
43 |
+
:type preds: numpy array
|
44 |
+
:param targets: The ground truth labels.
|
45 |
+
:type targets: numpy array
|
46 |
+
:return: The computed Zero-One Score.
|
47 |
+
:rtype: float
|
48 |
+
"""
|
49 |
+
|
50 |
+
if thresholds is None:
|
51 |
+
thresholds = optimize_accuracy(preds, targets)
|
52 |
+
|
53 |
+
preds = (preds > thresholds).astype(int)
|
54 |
+
return 1 - zero_one_loss(targets, preds, normalize=True)
|
55 |
+
|
56 |
+
|
57 |
+
def mean_f1_score(preds, targets, thresholds: np.array = None):
|
58 |
+
"""Compute Mean F1 Score.
|
59 |
+
|
60 |
+
This function computes the Mean F1 Score, a performance metric used for multi-label
|
61 |
+
classification tasks. The Mean F1 Score measures the trade-off between precision and recall,
|
62 |
+
where a higher score indicates better prediction accuracy. The Mean F1 Score ranges from
|
63 |
+
0 to 1, with 1 being a perfect match.
|
64 |
+
|
65 |
+
:param preds: The predicted labels.
|
66 |
+
:type preds: numpy array
|
67 |
+
:param targets: The ground truth labels.
|
68 |
+
:type targets: numpy array
|
69 |
+
:return: The computed Mean F1 Score.
|
70 |
+
:rtype: float
|
71 |
+
"""
|
72 |
+
if thresholds is None:
|
73 |
+
thresholds = optimize_f1_score(preds, targets)
|
74 |
+
|
75 |
+
preds = (preds > thresholds).astype(int)
|
76 |
+
return f1_score(targets, preds, average="samples", zero_division=0)
|
77 |
+
|
78 |
+
|
79 |
+
def per_instr_f1_score(preds, targets, thresholds: np.array = None):
|
80 |
+
"""Compute Per-Instrument F1 Score.
|
81 |
+
|
82 |
+
This function computes the F1 Score for each instrument separately in a multi-label
|
83 |
+
classification task. The Per-Instrument F1 Score measures the prediction accuracy for
|
84 |
+
each instrument class independently. The F1 Score is the harmonic mean of precision and recall,
|
85 |
+
where a higher score indicates better prediction accuracy. The Per-Instrument F1 Score ranges
|
86 |
+
from 0 to 1, with 1 being a perfect match.
|
87 |
+
|
88 |
+
:param preds: The predicted labels.
|
89 |
+
:type preds: numpy array
|
90 |
+
:param targets: The ground truth labels.
|
91 |
+
:type targets: numpy array
|
92 |
+
:return: The computed Per-Instrument F1 Score.
|
93 |
+
:rtype: numpy array
|
94 |
+
"""
|
95 |
+
|
96 |
+
if thresholds is None:
|
97 |
+
thresholds = optimize_f1_score(preds, targets)
|
98 |
+
|
99 |
+
preds = (preds > thresholds).astype(int)
|
100 |
+
return f1_score(targets, preds, average=None, zero_division=0)
|
101 |
+
|
102 |
+
|
103 |
+
def mean_average_precision(preds, targets):
|
104 |
+
"""
|
105 |
+
Compute mean Average Precision (mAP).
|
106 |
+
|
107 |
+
This function computes the mean Average Precision (mAP), a performance metric used
|
108 |
+
for multi-label classification tasks. The mAP measures the average precision across
|
109 |
+
all classes, taking into account the precision-recall trade-off, where a higher score
|
110 |
+
indicates better prediction accuracy.
|
111 |
+
|
112 |
+
:param preds: The predicted probabilities or scores.
|
113 |
+
:type preds: numpy array
|
114 |
+
:param targets: The ground truth labels.
|
115 |
+
:type targets: numpy array
|
116 |
+
:return: The computed mAP score.
|
117 |
+
:rtype: float
|
118 |
+
"""
|
119 |
+
|
120 |
+
return average_precision_score(targets, preds, average="samples")
|
121 |
+
|
122 |
+
|
123 |
+
def optimize_f1_score(preds, targets):
|
124 |
+
"""
|
125 |
+
Optimize Threshold.
|
126 |
+
|
127 |
+
This function optimizes the threshold for binary classification based on the predicted probabilities
|
128 |
+
and ground truth labels. It computes the precision, recall, and F1 Score for each class separately
|
129 |
+
using the precision_recall_curve function from sklearn.metrics module. It then selects the threshold
|
130 |
+
that maximizes the F1 Score for each class.
|
131 |
+
|
132 |
+
:param preds: The predicted probabilities.
|
133 |
+
:type preds: numpy array
|
134 |
+
:param targets: The ground truth labels.
|
135 |
+
:type targets: numpy array
|
136 |
+
:return: The optimized thresholds for binary classification.
|
137 |
+
:rtype: numpy array
|
138 |
+
"""
|
139 |
+
|
140 |
+
label_thresholds = np.empty(preds.shape[1])
|
141 |
+
|
142 |
+
for i in range(preds.shape[1]):
|
143 |
+
precision, recall, thresholds = precision_recall_curve(targets[:, i], preds[:, i])
|
144 |
+
fscore = np.divide(2 * precision * recall, precision + recall, out=np.zeros_like(precision), where=(precision + recall) > 0)
|
145 |
+
ix = np.argmax(fscore)
|
146 |
+
best_thresh = thresholds[ix]
|
147 |
+
label_thresholds[i] = best_thresh
|
148 |
+
|
149 |
+
return label_thresholds
|
150 |
+
|
151 |
+
|
152 |
+
def optimize_accuracy(preds, targets):
|
153 |
+
"""
|
154 |
+
Determine the optimal threshold for each label, based on the predicted probabilities and the true targets,
|
155 |
+
in order to maximize the accuracy of the predictions.
|
156 |
+
|
157 |
+
:param preds: A 2D NumPy array containing the predicted probabilities for each label.
|
158 |
+
:type preds: numpy.ndarray
|
159 |
+
:param targets: A 2D NumPy array containing the true binary targets for each label.
|
160 |
+
:type targets: numpy.ndarray
|
161 |
+
:raises ValueError: If the input arrays are not 2D arrays or have incompatible shapes.
|
162 |
+
:return: A 1D NumPy array containing the optimal threshold for each label.
|
163 |
+
:rtype: numpy.ndarray
|
164 |
+
"""
|
165 |
+
|
166 |
+
# Vary the threshold for each label and calculate accuracy for each threshold
|
167 |
+
thresholds = np.arange(0.0001, 1, 0.0001)
|
168 |
+
best_thresholds = np.empty(preds.shape[1])
|
169 |
+
for i in range(preds.shape[1]):
|
170 |
+
accuracies = []
|
171 |
+
for th in thresholds:
|
172 |
+
y_pred = (preds[:, i] >= th).astype(int) # Convert probabilities to binary predictions using the threshold
|
173 |
+
acc = accuracy_score(targets[:, i], y_pred)
|
174 |
+
accuracies.append(acc)
|
175 |
+
# Find the threshold that gives the highest accuracy for this label
|
176 |
+
best_idx = np.argmax(accuracies)
|
177 |
+
best_thresholds[i] = thresholds[best_idx]
|
178 |
+
|
179 |
+
return best_thresholds
|
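Since every metric above accepts precomputed thresholds, the intended flow appears to be: tune thresholds once on validation predictions, then reuse them at evaluation time. The sketch below illustrates that on random arrays standing in for sigmoid outputs and multi-hot labels; nothing here is part of the committed code.

# Hedged sketch: tune per-class thresholds on validation data, reuse them for evaluation.
import numpy as np

from src.modeling.metrics import (
    hamming_score,
    mean_average_precision,
    mean_f1_score,
    optimize_f1_score,
    per_instr_f1_score,
)

rng = np.random.default_rng(0)
val_preds = rng.random((64, 11))                        # stand-in for sigmoid outputs
val_targets = (rng.random((64, 11)) > 0.5).astype(int)  # stand-in for multi-hot labels

thresholds = optimize_f1_score(val_preds, val_targets)  # one threshold per class

test_preds = rng.random((64, 11))
test_targets = (rng.random((64, 11)) > 0.5).astype(int)
print("mean F1      :", mean_f1_score(test_preds, test_targets, thresholds))
print("hamming score:", hamming_score(test_preds, test_targets, thresholds))
print("per-instr F1 :", per_instr_f1_score(test_preds, test_targets, thresholds))
print("mAP          :", mean_average_precision(test_preds, test_targets))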
src/modeling/models.py
ADDED
@@ -0,0 +1,313 @@
|
1 |
+
from warnings import warn
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from transformers import ASTConfig, ASTModel
|
7 |
+
|
8 |
+
|
9 |
+
class StudentAST(nn.Module):
|
10 |
+
"""
|
11 |
+
A student model for audio classification using the AST architecture.
|
12 |
+
|
13 |
+
:param n_classes: The number of classes to classify.
|
14 |
+
:type n_classes: int
|
15 |
+
:param hidden_size: The number of units in the hidden layers, defaults to 384.
|
16 |
+
:type hidden_size: int, optional
|
17 |
+
:param num_heads: The number of attention heads to use, defaults to 6.
|
18 |
+
:type num_heads: int, optional
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, n_classes: int, hidden_size: int = 384, num_heads: int = 6):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
config = ASTConfig(hidden_size=hidden_size, num_attention_heads=num_heads, intermediate_size=hidden_size * 4)
|
25 |
+
self.base_model = ASTModel(config=config)
|
26 |
+
self.classifier = StudentClassificationHead(hidden_size, n_classes)
|
27 |
+
|
28 |
+
def forward(self, x: torch.Tensor):
|
29 |
+
"""
|
30 |
+
Forward pass of the student model.
|
31 |
+
|
32 |
+
:param x: The input tensor of shape [batch_size, sequence_length, input_dim].
|
33 |
+
:type x: torch.Tensor
|
34 |
+
:return: The output tensor of shape [batch_size, n_classes].
|
35 |
+
:rtype: torch.Tensor
|
36 |
+
"""
|
37 |
+
|
38 |
+
x = self.base_model(x)[0]
|
39 |
+
x = self.classifier(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class StudentClassificationHead(nn.Module):
|
44 |
+
"""
|
45 |
+
A classification head for the student model.
|
46 |
+
|
47 |
+
:param emb_size: The size of the embedding.
|
48 |
+
:type emb_size: int
|
49 |
+
:param n_classes: The number of classes to classify.
|
50 |
+
:type n_classes: int
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, emb_size: int, n_classes: int):
|
54 |
+
super().__init__()
|
55 |
+
|
56 |
+
self.cls_head = nn.Linear(emb_size, n_classes)
|
57 |
+
self.dist_head = nn.Linear(emb_size, n_classes)
|
58 |
+
|
59 |
+
def forward(self, x: torch.Tensor):
|
60 |
+
"""
|
61 |
+
Forward pass of the classification head.
|
62 |
+
|
63 |
+
:param x: The input tensor of shape [batch_size, seq_len, emb_size], whose first two tokens are the class and distillation tokens.
|
64 |
+
:type x: torch.Tensor
|
65 |
+
:return: The output tensor of shape [batch_size, n_classes].
|
66 |
+
:rtype: torch.Tensor
|
67 |
+
"""
|
68 |
+
|
69 |
+
x_cls, x_dist = x[:, 0], x[:, 1]
|
70 |
+
x_cls_head = self.cls_head(x_cls)
|
71 |
+
x_dist_head = self.dist_head(x_dist)
|
72 |
+
|
73 |
+
if self.training:
|
74 |
+
x = x_cls_head, x_dist_head
|
75 |
+
else:
|
76 |
+
x = (x_cls_head + x_dist_head) / 2
|
77 |
+
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class ASTPretrained(nn.Module):
|
82 |
+
"""
|
83 |
+
This class implements a PyTorch module for a pre-trained Audio Spectrogram Transformer (AST) model
|
84 |
+
fine-tuned on the AudioSet dataset (the MIT/ast-finetuned-audioset-10-10-0.4593 checkpoint) for audio event classification.
|
85 |
+
|
86 |
+
:param n_classes: The number of classes for audio event classification.
|
87 |
+
:type n_classes: int
|
88 |
+
:param dropout: The dropout probability for the fully connected layer, defaults to 0.5.
|
89 |
+
:type dropout: float, optional
|
90 |
+
:raises ValueError: If n_classes is not positive.
|
91 |
+
:raises TypeError: If dropout is not a float or is not between 0 and 1.
|
92 |
+
:return: The output tensor of shape [batch_size, n_classes] containing the logits for each class.
|
93 |
+
:rtype: torch.Tensor
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self, n_classes: int, download_weights: bool = True, freeze_body: bool = False, dropout: float = 0.5):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
if download_weights:
|
100 |
+
self.base_model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
|
101 |
+
else:
|
102 |
+
config = ASTConfig()
|
103 |
+
self.base_model = ASTModel(config=config)
|
104 |
+
|
105 |
+
if freeze_body:
|
106 |
+
self.base_model = freeze(self.base_model)
|
107 |
+
|
108 |
+
fc_in = self.base_model.config.hidden_size
|
109 |
+
|
110 |
+
self.classifier = nn.Sequential(
|
111 |
+
nn.LayerNorm((fc_in,), eps=1e-12), nn.Dropout(p=dropout), nn.Linear(fc_in, n_classes)
|
112 |
+
)
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
"""Passes the input tensor through the pre-trained Audio Set Transformer (AST) model
|
116 |
+
followed by a fully connected layer.
|
117 |
+
|
118 |
+
:param x: The input tensor of shape [batch_size, seq_len, num_features].
|
119 |
+
:type x: torch.Tensor
|
120 |
+
:return: The output tensor of shape [batch_size, n_classes] containing the logits for each class.
|
121 |
+
:rtype: torch.Tensor
|
122 |
+
:raises ValueError: If the shape of x is not [batch_size, seq_len, num_features].
|
123 |
+
"""
|
124 |
+
|
125 |
+
x = self.base_model(x)[1]
|
126 |
+
x = self.classifier(x)
|
127 |
+
return x
|
128 |
+
|
129 |
+
|
130 |
+
def layerwise_lr_decay(config, model: ASTModel):
|
131 |
+
"""
|
132 |
+
LLRD (Layer-wise Learning Rate Decay) function computes the learning rate for each layer in a deep neural network
|
133 |
+
using a specific decay rate and a base learning rate for the optimizer.
|
134 |
+
|
135 |
+
:param config: A configuration object that contains the parameters required for LLRD.
|
136 |
+
:type config: Any
|
137 |
+
:param model: A PyTorch neural network model.
|
138 |
+
:type model: ASTModel
|
139 |
+
|
140 |
+
:raises Warning: If the configuration object does not contain the LLRD parameters.
|
141 |
+
|
142 |
+
:return: A dictionary containing the optimizer parameters (parameters, weight decay, and learning rate)
|
143 |
+
for each layer.
|
144 |
+
:rtype: dict
|
145 |
+
"""
|
146 |
+
|
147 |
+
try:
|
148 |
+
config = config.LLRD
|
149 |
+
except Exception:
|
150 |
+
warn("No LLRD found in config. Learner will use single lr for whole model.")
|
151 |
+
return None
|
152 |
+
|
153 |
+
lr = config["base_lr"]
|
154 |
+
weight_decay = config["weight_decay"]
|
155 |
+
no_decay = ["bias", "layernorm"]
|
156 |
+
body = ["embeddings", "encoder.layer"]
|
157 |
+
head_params = [(n, p) for n, p in model.named_parameters() if not any(body_param in n for body_param in body)]
|
158 |
+
optimizer_grouped_parameters = [
|
159 |
+
{
|
160 |
+
"params": [p for n, p in head_params if not any(nd in n for nd in no_decay)],
|
161 |
+
"weight_decay": weight_decay,
|
162 |
+
"lr": lr,
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"params": [p for n, p in head_params if any(nd in n for nd in no_decay)],
|
166 |
+
"weight_decay": 0.0,
|
167 |
+
"lr": lr,
|
168 |
+
},
|
169 |
+
]
|
170 |
+
|
171 |
+
# initialize lrs for every layer
|
172 |
+
layers = [getattr(model.module, config["body"]).embeddings] + list(
|
173 |
+
getattr(model.module, config["body"]).encoder.layer
|
174 |
+
)
|
175 |
+
layers.reverse()
|
176 |
+
for layer in layers:
|
177 |
+
lr *= config["lr_decay_rate"]
|
178 |
+
optimizer_grouped_parameters += [
|
179 |
+
{
|
180 |
+
"params": [p for n, p in layer.named_parameters() if not any(nd in n for nd in no_decay)],
|
181 |
+
"weight_decay": weight_decay,
|
182 |
+
"lr": lr,
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"params": [p for n, p in layer.named_parameters() if any(nd in n for nd in no_decay)],
|
186 |
+
"weight_decay": 0.0,
|
187 |
+
"lr": lr,
|
188 |
+
},
|
189 |
+
]
|
190 |
+
|
191 |
+
return optimizer_grouped_parameters
|
192 |
+
|
193 |
+
|
194 |
+
def freeze(model: nn.Module):
|
195 |
+
"""
|
196 |
+
Freeze function sets the requires_grad attribute to False for all parameters
|
197 |
+
in the given PyTorch neural network model. This is used to freeze the weights of
|
198 |
+
the model during training or inference.
|
199 |
+
|
200 |
+
:param model: A PyTorch neural network model.
|
201 |
+
:type model: nn.Module
|
202 |
+
|
203 |
+
:return: The same model with requires_grad attribute set to False for all parameters.
|
204 |
+
:rtype: nn.Module
|
205 |
+
"""
|
206 |
+
|
207 |
+
model.eval()
|
208 |
+
for param in model.parameters():
|
209 |
+
param.requires_grad = False
|
210 |
+
|
211 |
+
return model
|
212 |
+
|
213 |
+
|
214 |
+
def unfreeze(model: nn.Module):
|
215 |
+
"""
|
216 |
+
Unfreeze the model by setting requires_grad to True for all parameters.
|
217 |
+
|
218 |
+
:param model: The model to unfreeze.
|
219 |
+
:type model: nn.Module
|
220 |
+
:return: The unfrozen model.
|
221 |
+
:rtype: nn.Module
|
222 |
+
"""
|
223 |
+
|
224 |
+
model.train()
|
225 |
+
for param in model.parameters():
|
226 |
+
param.requires_grad = True
|
227 |
+
|
228 |
+
return model
|
229 |
+
|
230 |
+
|
231 |
+
def interpolate_params(student: nn.Module, teacher: nn.Module):
|
232 |
+
"""
|
233 |
+
Interpolate parameters between two models. This function scales the parameters of the
|
234 |
+
teacher model to match the shape of the corresponding parameters in the student model
|
235 |
+
using bilinear interpolation. If the shapes of the parameters in the two models are already the same,
|
236 |
+
the parameters are unchanged.
|
237 |
+
|
238 |
+
:param student: The student model.
|
239 |
+
:type student: nn.Module
|
240 |
+
:param teacher: The teacher model.
|
241 |
+
:type teacher: nn.Module
|
242 |
+
:return: A dictionary of interpolated parameters for the student model.
|
243 |
+
:rtype: dict
|
244 |
+
"""
|
245 |
+
|
246 |
+
new_params = {}
|
247 |
+
|
248 |
+
# Iterate over the parameters in the first model
|
249 |
+
for name, param in teacher.base_model.named_parameters():
|
250 |
+
# Scale the parameter using interpolate if its shape is different from that of the second model
|
251 |
+
target_param = student.base_model.state_dict()[name]
|
252 |
+
if param.shape != target_param.shape:
|
253 |
+
squeeze_count = 0
|
254 |
+
permuted = False
|
255 |
+
while param.ndim < 4:
|
256 |
+
param = param.unsqueeze(0)
|
257 |
+
squeeze_count += 1
|
258 |
+
|
259 |
+
if param.shape[0] > 1:
|
260 |
+
param = param.permute(1, 2, 3, 0)
|
261 |
+
target_param = target_param.permute(1, 2, 3, 0)
|
262 |
+
permuted = True
|
263 |
+
|
264 |
+
if target_param.ndim < 2:
|
265 |
+
target_param = target_param.unsqueeze(0)
|
266 |
+
|
267 |
+
scaled_param = F.interpolate(param, size=(target_param.shape[-2:]), mode="bilinear")
|
268 |
+
|
269 |
+
while squeeze_count > 0:
|
270 |
+
scaled_param = scaled_param.squeeze(0)
|
271 |
+
squeeze_count -= 1
|
272 |
+
|
273 |
+
if permuted:
|
274 |
+
scaled_param = scaled_param.permute(-1, 0, 1, 2)
|
275 |
+
|
276 |
+
else:
|
277 |
+
scaled_param = param
|
278 |
+
new_params[name] = scaled_param
|
279 |
+
|
280 |
+
return new_params
|
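One plausible use of interpolate_params is warm-starting the smaller student backbone from resized teacher weights; a hedged sketch follows. strict=False is an assumption to tolerate keys (for example buffers) that the interpolated dict does not cover.

# Hedged sketch: warm-starting StudentAST from a resized ASTPretrained teacher.
from src.modeling.models import ASTPretrained, StudentAST, interpolate_params

teacher = ASTPretrained(n_classes=11, download_weights=False)
student = StudentAST(n_classes=11, hidden_size=384, num_heads=6)

resized = interpolate_params(student, teacher)  # teacher params rescaled to student shapes
missing, unexpected = student.base_model.load_state_dict(resized, strict=False)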
281 |
+
|
282 |
+
|
283 |
+
def average_model_weights(model_weights_list):
|
284 |
+
"""
|
285 |
+
Compute the average weights of a list of PyTorch models.
|
286 |
+
|
287 |
+
:param model_weights_list: A list of file paths to PyTorch model weight files.
|
288 |
+
:type model_weights_list: List[str]
|
289 |
+
:raises ValueError: If the input list is empty.
|
290 |
+
:return: A dictionary containing the average weights of the models.
|
291 |
+
:rtype: Dict[str, torch.Tensor]
|
292 |
+
"""
|
293 |
+
|
294 |
+
if not model_weights_list:
|
295 |
+
raise ValueError("The input list cannot be empty.")
|
296 |
+
|
297 |
+
num_models = len(model_weights_list)
|
298 |
+
averaged_weights = {}
|
299 |
+
|
300 |
+
# Load the first model weights
|
301 |
+
state_dict = torch.load(model_weights_list[0])
|
302 |
+
|
303 |
+
# Iterate through the remaining models and add their weights to the first model's weights
|
304 |
+
for i in range(1, num_models):
|
305 |
+
state_dict_i = torch.load(model_weights_list[i])
|
306 |
+
for key in state_dict.keys():
|
307 |
+
state_dict[key] += state_dict_i[key]
|
308 |
+
|
309 |
+
# Compute the average of the weights
|
310 |
+
for key in state_dict.keys():
|
311 |
+
averaged_weights[key] = state_dict[key] / num_models
|
312 |
+
|
313 |
+
return averaged_weights
|
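To show how layerwise_lr_decay is meant to feed an optimizer, here is a hedged sketch. The LLRD keys and the DataParallel wrapper are assumptions inferred from the function body (it reads config.LLRD and model.module); the learning rates are placeholders.

# Hedged sketch: building AdamW parameter groups with layer-wise LR decay.
from types import SimpleNamespace

import torch
import torch.nn as nn

from src.modeling.models import ASTPretrained, layerwise_lr_decay

config = SimpleNamespace(
    LLRD={"base_lr": 1e-4, "weight_decay": 1e-2, "lr_decay_rate": 0.8, "body": "base_model"}
)

# layerwise_lr_decay accesses model.module, so the model is wrapped here.
model = nn.DataParallel(ASTPretrained(n_classes=11, download_weights=False))
param_groups = layerwise_lr_decay(config, model)

# Falls back to a single learning rate if the config has no LLRD block.
optimizer = torch.optim.AdamW(param_groups or model.parameters(), lr=1e-4)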
src/modeling/preprocess.py
ADDED
@@ -0,0 +1,336 @@
|
1 |
+
import itertools
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import soundfile as sf
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
from sklearn.model_selection import StratifiedGroupKFold
|
12 |
+
from tqdm.autonotebook import tqdm
|
13 |
+
|
14 |
+
from modeling.transforms import LabelsFromTxt, ParentMultilabel
|
15 |
+
from modeling.utils import get_file_info, sync_bpm, sync_onset, sync_pitch
|
16 |
+
|
17 |
+
|
18 |
+
def generate_metadata(
|
19 |
+
data_dir: Union[str, Path],
|
20 |
+
save_path: str = ".",
|
21 |
+
subset: str = "train",
|
22 |
+
extract_music_features: bool = False,
|
23 |
+
n_jobs: int = -2,
|
24 |
+
):
|
25 |
+
"""
|
26 |
+
Generate metadata CSV file containing information about audio files in a directory.
|
27 |
+
|
28 |
+
:param data_dir: Directory containing audio files.
|
29 |
+
:type data_dir: Union[str, Path]
|
30 |
+
:param save_path: Directory path to save metadata CSV file.
|
31 |
+
:type save_path: str
|
32 |
+
:param subset: Subset of the dataset (train or test), defaults to 'train'.
|
33 |
+
:type subset: str
|
34 |
+
:param extract_music_features: Flag to indicate whether to extract music features or not, defaults to False.
|
35 |
+
:type extract_music_features: bool
|
36 |
+
:param n_jobs: Number of parallel jobs to run, defaults to -2.
|
37 |
+
:type n_jobs: int
|
38 |
+
:raises FileNotFoundError: If the provided data directory does not exist.
|
39 |
+
:return: DataFrame containing the metadata information.
|
40 |
+
:rtype: pandas.DataFrame
|
41 |
+
"""
|
42 |
+
|
43 |
+
data_dir = Path(data_dir) if isinstance(data_dir, str) else data_dir
|
44 |
+
|
45 |
+
if subset == "train":
|
46 |
+
pattern = r"(.*)__[\d]+$"
|
47 |
+
label_extractor = ParentMultilabel()
|
48 |
+
else:
|
49 |
+
pattern = r"(.*)-[\d]+$"
|
50 |
+
label_extractor = LabelsFromTxt()
|
51 |
+
|
52 |
+
sound_files = list(data_dir.glob("**/*.wav"))
|
53 |
+
output = Parallel(n_jobs=n_jobs)(delayed(get_file_info)(path, extract_music_features) for path in tqdm(sound_files))
|
54 |
+
|
55 |
+
df = pd.DataFrame(data=output)
|
56 |
+
|
57 |
+
df["fname"] = df.path.map(lambda x: Path(x).stem)
|
58 |
+
df["song_name"] = df.fname.str.extract(pattern)
|
59 |
+
df["inst"] = df.path.map(lambda x: "-".join(sorted(list(label_extractor(x)))))
|
60 |
+
df["label_count"] = df.inst.map(lambda x: len(x.split("-")))
|
61 |
+
|
62 |
+
df.to_csv(f"{save_path}/metadata_{subset}.csv", index=False)
|
63 |
+
|
64 |
+
return df
|
65 |
+
|
66 |
+
|
67 |
+
def create_test_split(metadata_path: str, txt_save_path: str, random_state: Optional[int] = None):
|
68 |
+
"""Create test split by generating a list of test songs and saving them to a text file.
|
69 |
+
|
70 |
+
:param metadata_path: Path to the CSV file containing metadata of all songs
|
71 |
+
:type metadata_path: str
|
72 |
+
:param txt_save_path: Path to the directory where the text file containing test songs will be saved
|
73 |
+
:type txt_save_path: str
|
74 |
+
:param random_state: Seed value for the random number generator, defaults to None
|
75 |
+
:type random_state: int, optional
|
76 |
+
:raises TypeError: If metadata_path or txt_save_path is not a string or if random_state is not an integer or None
|
77 |
+
:raises FileNotFoundError: If metadata_path does not exist
|
78 |
+
:raises PermissionError: If the program does not have permission to write to txt_save_path
|
79 |
+
:return: None
|
80 |
+
:rtype: None
|
81 |
+
"""
|
82 |
+
|
83 |
+
df = pd.read_csv(metadata_path)
|
84 |
+
kf = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=random_state)
|
85 |
+
splits = kf.split(df.fname, df.inst, groups=df.song_name)
|
86 |
+
_, test = list(splits)[0]
|
87 |
+
|
88 |
+
test_songs = df.iloc[test].fname.sort_values().to_numpy()
|
89 |
+
|
90 |
+
with open(f"{txt_save_path}/test_songs.txt", "w") as f:
|
91 |
+
# iterate over the list of names and write each one to a new line in the file
|
92 |
+
for song in test_songs:
|
93 |
+
f.write(song + "\n")
|
94 |
+
|
95 |
+
|
96 |
+
class IRMASPreprocessor:
|
97 |
+
"""
|
98 |
+
A class to preprocess IRMAS dataset metadata and create a mapping between
|
99 |
+
file paths and their corresponding instrument labels.
|
100 |
+
|
101 |
+
:param metadata: A pandas DataFrame or path to csv file containing metadata, defaults to None
|
102 |
+
:type metadata: Union[pd.DataFrame, str], optional
|
103 |
+
:param data_dir: Path to the directory containing the IRMAS dataset, defaults to None
|
104 |
+
:type data_dir: Union[str, Path], optional
|
105 |
+
:param sample_rate: Sample rate of the audio files, defaults to 16000
|
106 |
+
:type sample_rate: int, optional
|
107 |
+
|
108 |
+
:raises AssertionError: Raised when metadata is None and data_dir is also None.
|
109 |
+
|
110 |
+
:return: An instance of IRMASPreprocessor
|
111 |
+
:rtype: IRMASPreprocessor
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self, metadata: Union[pd.DataFrame, str] = None, data_dir: Union[str, Path] = None, sample_rate: int = 16000
|
116 |
+
):
|
117 |
+
if metadata is not None:
|
118 |
+
self.metadata = pd.read_csv(metadata) if isinstance(metadata, str) else metadata
|
119 |
+
if data_dir is not None:
|
120 |
+
self.metadata["path"] = self.metadata.apply(lambda x: f"{data_dir}/{x.inst}/{x.fname}.wav", axis=1)
|
121 |
+
else:
|
122 |
+
assert data_dir is not None, "No metadata found. Need to provide data directory"
|
123 |
+
self.metadata = generate_metadata(data_dir=data_dir, subset="train", extract_music_features=True)
|
124 |
+
|
125 |
+
self.instruments = self.metadata.inst.unique()
|
126 |
+
self.sample_rate = sample_rate
|
127 |
+
|
128 |
+
def preprocess_and_mix(self, save_dir: str, sync: str, ordered: bool, num_track_to_mix: int, n_jobs: int = -2):
|
129 |
+
"""
|
130 |
+
A method to preprocess and mix audio tracks from the IRMAS dataset.
|
131 |
+
|
132 |
+
:param save_dir: The directory to save the preprocessed and mixed tracks
|
133 |
+
:type save_dir: str
|
134 |
+
:param sync: The column name used to synchronize the audio tracks during mixing
|
135 |
+
:type sync: str
|
136 |
+
:param ordered: Whether to order the metadata by the sync column before mixing the tracks
|
137 |
+
:type ordered: bool
|
138 |
+
:param num_track_to_mix: The number of tracks to mix together
|
139 |
+
:type num_track_to_mix: int
|
140 |
+
:param n_jobs: The number of parallel jobs to run, defaults to -2
|
141 |
+
:type n_jobs: int, optional
|
142 |
+
|
143 |
+
:raises None
|
144 |
+
|
145 |
+
:return: None
|
146 |
+
:rtype: None
|
147 |
+
"""
|
148 |
+
|
149 |
+
combs = itertools.combinations(self.instruments, r=num_track_to_mix)
|
150 |
+
|
151 |
+
if ordered:
|
152 |
+
self.metadata = self.metadata.sort_values(by=sync)
|
153 |
+
else:
|
154 |
+
self.metadata = self.metadata.sample(frac=1)
|
155 |
+
|
156 |
+
Parallel(n_jobs=n_jobs)(delayed(self._mix)(insts, save_dir, sync) for (insts) in tqdm(combs))
|
157 |
+
print("Parallel preprocessing done!")
|
158 |
+
|
159 |
+
def _mix(self, insts: Tuple[str], save_dir: str, sync: str):
|
160 |
+
"""
|
161 |
+
A private method to mix audio tracks and save them to disk.
|
162 |
+
|
163 |
+
:param insts: A tuple of instrument labels to mix
|
164 |
+
:type insts: Tuple[str]
|
165 |
+
:param save_dir: The directory to save the mixed tracks
|
166 |
+
:type save_dir: str
|
167 |
+
:param sync: The column name used to synchronize the audio tracks during mixing
|
168 |
+
:type sync: str
|
169 |
+
|
170 |
+
:raises None
|
171 |
+
|
172 |
+
:return: None
|
173 |
+
:rtype: None
|
174 |
+
"""
|
175 |
+
|
176 |
+
save_dir = self._create_save_dir(insts, save_dir)
|
177 |
+
|
178 |
+
insts_files_list = [self._get_filepaths(inst) for inst in insts]
|
179 |
+
|
180 |
+
max_length = max([inst_files.shape[0] for inst_files in insts_files_list])
|
181 |
+
for i, inst_files in enumerate(insts_files_list):
|
182 |
+
if inst_files.shape[0] < max_length:
|
183 |
+
diff = max_length - inst_files.shape[0]
|
184 |
+
inst_files = np.pad(inst_files, (0, diff), mode="symmetric")
|
185 |
+
insts_files_list[i] = [Path(x) for x in inst_files]
|
186 |
+
|
187 |
+
self._mix_files_and_save(insts_files_list, save_dir, sync)
|
188 |
+
|
189 |
+
def _get_filepaths(self, inst: str):
|
190 |
+
"""
|
191 |
+
A private method to retrieve file paths of audio tracks for a given instrument label.
|
192 |
+
|
193 |
+
:param inst: The label of the instrument for which to retrieve the file paths
|
194 |
+
:type inst: str
|
195 |
+
|
196 |
+
:raises KeyError: Raised when the instrument label is not found in the metadata.
|
197 |
+
|
198 |
+
:return: A numpy array of file paths corresponding to the instrument label.
|
199 |
+
:rtype: numpy.ndarray
|
200 |
+
"""
|
201 |
+
|
202 |
+
metadata = self.metadata.loc[self.metadata.inst == inst]
|
203 |
+
|
204 |
+
if metadata.empty:
|
205 |
+
raise KeyError("Instrument not found. Please regenerate metadata!")
|
206 |
+
|
207 |
+
files = metadata.path.to_numpy()
|
208 |
+
|
209 |
+
return files
|
210 |
+
|
211 |
+
def _mix_files_and_save(self, insts_files_list: List[List[Path]], save_dir: str, sync: str):
|
212 |
+
"""
|
213 |
+
A private method to mix audio files, synchronize them using a given column name in the metadata,
|
214 |
+
and save the mixed file to disk.
|
215 |
+
|
216 |
+
:param insts_files_list: A list of lists of file paths corresponding to each instrument label
|
217 |
+
:type insts_files_list: List[List[Path]]
|
218 |
+
:param save_dir: The directory to save the mixed tracks
|
219 |
+
:type save_dir: str
|
220 |
+
:param sync: The column name used to synchronize the audio tracks during mixing
|
221 |
+
:type sync: str
|
222 |
+
|
223 |
+
:raises None
|
224 |
+
|
225 |
+
:return: None
|
226 |
+
:rtype: None
|
227 |
+
"""
|
228 |
+
|
229 |
+
for i in range(len(insts_files_list[0])):
|
230 |
+
files_to_sync = [inst_files[i] for inst_files in insts_files_list]
|
231 |
+
new_name = f"{'-'.join([file.stem for file in files_to_sync])}.wav"
|
232 |
+
|
233 |
+
synced_file = self._sync_and_mix(files_to_sync, sync)
|
234 |
+
sf.write(os.path.join(save_dir, new_name), synced_file, samplerate=self.sample_rate)
|
235 |
+
|
236 |
+
def _sync_and_mix(self, files_to_sync: List[Path], sync: str):
|
237 |
+
"""
|
238 |
+
Synchronize and mix audio files.
|
239 |
+
|
240 |
+
:param files_to_sync: A list of file paths to synchronize and mix.
|
241 |
+
:type files_to_sync: List[Path]
|
242 |
+
:param sync: The type of synchronization to use. One of ['bpm', 'pitch', None].
|
243 |
+
:type sync: str, optional
|
244 |
+
:raises KeyError: If any file in files_to_sync is not found in metadata.
|
245 |
+
:return: The synchronized and mixed audio signal.
|
246 |
+
:rtype: numpy.ndarray
|
247 |
+
"""
|
248 |
+
|
249 |
+
cols = ["pitch", "bpm", "onset"]
|
250 |
+
files_metadata_df = self.metadata.loc[
|
251 |
+
self.metadata.path.isin([str(file_path) for file_path in files_to_sync])
|
252 |
+
].set_index("path")
|
253 |
+
|
254 |
+
num_files = files_metadata_df.shape[0]
|
255 |
+
if num_files != len(files_to_sync):
|
256 |
+
raise KeyError("File not found in metadata. Please regenerate")
|
257 |
+
|
258 |
+
if sync is not None:
|
259 |
+
mean_features = files_metadata_df[cols].mean().to_dict()
|
260 |
+
|
261 |
+
metadata_dict = files_metadata_df.to_dict("index")
|
262 |
+
|
263 |
+
for i, (file_to_sync_path, features) in enumerate(metadata_dict.items()):
|
264 |
+
file_to_sync, sr_sync = librosa.load(file_to_sync_path, sr=None)
|
265 |
+
|
266 |
+
if sr_sync != 44100:
|
267 |
+
file_to_sync = librosa.resample(y=file_to_sync, orig_sr=sr_sync, target_sr=44100)  # bring to 44.1 kHz; the final mix is resampled to self.sample_rate below
|
268 |
+
|
269 |
+
if sync == "bpm":
|
270 |
+
file_to_sync = sync_bpm(file_to_sync, sr_sync, bpm_base=mean_features["bpm"], bpm=features["bpm"])
|
271 |
+
|
272 |
+
if sync == "pitch":
|
273 |
+
file_to_sync = sync_pitch(
|
274 |
+
file_to_sync, sr_sync, pitch_base=mean_features["pitch"], pitch=features["pitch"]
|
275 |
+
)
|
276 |
+
|
277 |
+
if sync is not None:
|
278 |
+
file_to_sync = sync_onset(
|
279 |
+
file_to_sync, sr_sync, onset_base=mean_features["onset"], onset=features["onset"]
|
280 |
+
)
|
281 |
+
|
282 |
+
file_to_sync = librosa.util.normalize(file_to_sync)
|
283 |
+
|
284 |
+
if i == 0:
|
285 |
+
mixed_sound = np.zeros_like(file_to_sync)
|
286 |
+
|
287 |
+
if mixed_sound.shape[0] > file_to_sync.shape[0]:
|
288 |
+
file_to_sync = np.resize(file_to_sync, mixed_sound.shape)
|
289 |
+
else:
|
290 |
+
mixed_sound = np.resize(mixed_sound, file_to_sync.shape)
|
291 |
+
|
292 |
+
mixed_sound += file_to_sync
|
293 |
+
|
294 |
+
mixed_sound /= num_files
|
295 |
+
|
296 |
+
return librosa.resample(y=mixed_sound, orig_sr=44100, target_sr=self.sample_rate)
|
297 |
+
|
298 |
+
def _create_save_dir(self, insts: Union[Tuple[str], List[str]], save_dir: str):
|
299 |
+
"""
|
300 |
+
Create and return a directory to save instrument-specific files.
|
301 |
+
|
302 |
+
:param insts: A tuple or list of instrument names.
|
303 |
+
:type insts: Union[Tuple[str], List[str]]
|
304 |
+
:param save_dir: The path to the directory where the new directory will be created.
|
305 |
+
:type save_dir: str
|
306 |
+
:return: The path to the newly created directory.
|
307 |
+
:rtype: str
|
308 |
+
"""
|
309 |
+
|
310 |
+
new_dir_name = "-".join(insts)
|
311 |
+
new_dir_path = os.path.join(save_dir, new_dir_name)
|
312 |
+
os.makedirs(new_dir_path, exist_ok=True)
|
313 |
+
return new_dir_path
|
314 |
+
|
315 |
+
@classmethod
|
316 |
+
def from_metadata(cls, metadata_path: str, **kwargs):
|
317 |
+
"""
|
318 |
+
Create a new instance of the class from a metadata file.
|
319 |
+
|
320 |
+
:param metadata_path: The path to the metadata file.
|
321 |
+
:type metadata_path: str
|
322 |
+
:param **kwargs: Additional keyword arguments to pass to the class constructor.
|
323 |
+
:return: A new instance of the class.
|
324 |
+
:rtype: cls
|
325 |
+
"""
|
326 |
+
|
327 |
+
metadata = pd.read_csv(metadata_path)
|
328 |
+
return cls(metadata, **kwargs)
|
329 |
+
|
330 |
+
|
331 |
+
if __name__ == "__main__":
|
332 |
+
data_dir = "/home/kpintaric/lumen-irmas/data/raw/IRMAS_Training_Data"
|
333 |
+
metadata_path = "/home/kpintaric/lumen-irmas/data/metadata_train.csv"
|
334 |
+
preprocess = IRMASPreprocessor(metadata=metadata_path, data_dir=data_dir)
|
335 |
+
preprocess.preprocess_and_mix(save_dir="data", sync="pitch", ordered=False, num_track_to_mix=3)
|
336 |
+
a = 1
|
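Beyond the __main__ example above, the surrounding helpers suggest a three-step metadata workflow; the sketch below spells it out. The directory names, save locations, and random seed are placeholder assumptions, not values fixed by this commit.

# Hedged sketch of the metadata workflow; all paths are placeholders.
from src.modeling.preprocess import IRMASPreprocessor, create_test_split, generate_metadata

# 1. Scan the raw training data and extract per-file music features.
generate_metadata(
    data_dir="data/raw/IRMAS_Training_Data",
    save_path="data",
    subset="train",
    extract_music_features=True,
)

# 2. Hold out a group-stratified set of songs for testing.
create_test_split(metadata_path="data/metadata_train.csv", txt_save_path="data", random_state=42)

# 3. Reuse the saved metadata to drive the mixing step.
preprocessor = IRMASPreprocessor.from_metadata(
    "data/metadata_train.csv", data_dir="data/raw/IRMAS_Training_Data"
)
preprocessor.preprocess_and_mix(save_dir="data/mixed", sync="pitch", ordered=False, num_track_to_mix=2)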
src/modeling/transforms.py
ADDED
@@ -0,0 +1,398 @@
|
1 |
+
import os
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
from torchaudio.transforms import FrequencyMasking, TimeMasking
|
10 |
+
from torchvision.transforms import Compose
|
11 |
+
from transformers import ASTFeatureExtractor
|
12 |
+
|
13 |
+
|
14 |
+
class Transform(ABC):
|
15 |
+
"""Abstract base class for audio transformations."""
|
16 |
+
|
17 |
+
@abstractmethod
|
18 |
+
def __call__(self):
|
19 |
+
"""
|
20 |
+
Abstract method to apply the transformation.
|
21 |
+
|
22 |
+
:raises NotImplementedError: If the subclass does not implement this method.
|
23 |
+
|
24 |
+
"""
|
25 |
+
pass
|
26 |
+
|
27 |
+
|
28 |
+
class Preprocess(ABC):
|
29 |
+
"""Abstract base class for preprocessing data.
|
30 |
+
|
31 |
+
This class defines the interface for preprocessing data. Subclasses must implement the call method.
|
32 |
+
|
33 |
+
"""
|
34 |
+
|
35 |
+
@abstractmethod
|
36 |
+
def __call__(self):
|
37 |
+
"""Process the data.
|
38 |
+
|
39 |
+
This method must be implemented by subclasses.
|
40 |
+
|
41 |
+
:raises NotImplementedError: Subclasses must implement this method.
|
42 |
+
|
43 |
+
"""
|
44 |
+
pass
|
45 |
+
|
46 |
+
|
47 |
+
class OneHotEncode(Transform):
|
48 |
+
"""Transform labels to one-hot encoded tensor.
|
49 |
+
|
50 |
+
This class is a transform that takes a list of labels and returns a one-hot encoded tensor.
|
51 |
+
The labels are converted to a tensor with one-hot encoding using the specified classes.
|
52 |
+
|
53 |
+
:param c: A list of classes to be used for one-hot encoding.
|
54 |
+
:type c: list
|
55 |
+
:return: A one-hot encoded tensor.
|
56 |
+
:rtype: torch.Tensor
|
57 |
+
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, c: list):
|
61 |
+
self.c = c
|
62 |
+
|
63 |
+
def __call__(self, labels):
|
64 |
+
"""
|
65 |
+
Transform labels to one-hot encoded tensor.
|
66 |
+
|
67 |
+
:param labels: A list of labels to be encoded.
|
68 |
+
:type labels: list
|
69 |
+
:return: A one-hot encoded tensor.
|
70 |
+
:rtype: torch.Tensor
|
71 |
+
|
72 |
+
"""
|
73 |
+
|
74 |
+
target = torch.zeros(len(self.c), dtype=torch.float)
|
75 |
+
for label in labels:
|
76 |
+
idx = self.c.index(label)
|
77 |
+
target[idx] = 1
|
78 |
+
return target
|
79 |
+
|
80 |
+
|
81 |
+
class ParentMultilabel(Transform):
|
82 |
+
"""
|
83 |
+
A transform that extracts a list of labels from the parent directory name of a file path.
|
84 |
+
|
85 |
+
:param sep: The separator used to split the parent directory name into labels. Defaults to " ".
|
86 |
+
:type sep: str
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(self, sep=" "):
|
90 |
+
self.sep = sep
|
91 |
+
|
92 |
+
def __call__(self, path):
|
93 |
+
"""
|
94 |
+
Extract a list of labels from the parent directory name of a file path.
|
95 |
+
|
96 |
+
:param path: The file path from which to extract labels.
|
97 |
+
:type path: str
|
98 |
+
:return: A list of labels extracted from the parent directory name of the input file path.
|
99 |
+
:rtype: List[str]
|
100 |
+
"""
|
101 |
+
|
102 |
+
label = path.split(os.path.sep)[-2].split(self.sep)
|
103 |
+
return label
|
104 |
+
|
105 |
+
|
106 |
+
class LabelsFromTxt(Transform):
|
107 |
+
"""
|
108 |
+
Extract multilabel annotations from the .txt file that accompanies an audio file.
|
109 |
+
|
110 |
+
This class is a transform that loads instrument labels from the .txt annotation file paired with a .wav file.
|
111 |
+
The labels are parsed with numpy.loadtxt, optionally split on a delimiter.
|
112 |
+
|
113 |
+
:param delimiter: The delimiter passed to numpy.loadtxt when parsing the label file. Defaults to None.
|
114 |
+
:type delimiter: str, optional
|
115 |
+
|
116 |
+
"""
|
117 |
+
|
118 |
+
def __init__(self, delimiter=None):
|
119 |
+
self.delimiter = delimiter
|
120 |
+
|
121 |
+
def __call__(self, path):
|
122 |
+
"""
|
123 |
+
Extract labels from the .txt annotation file that shares a name with the given .wav path.
|
124 |
+
|
125 |
+
:param path: The path of the .wav file whose .txt annotation file should be read.
|
126 |
+
:type path: str
|
127 |
+
:return: An array of labels loaded from the annotation file.
|
128 |
+
:rtype: numpy.ndarray
|
129 |
+
|
130 |
+
"""
|
131 |
+
|
132 |
+
path = path.replace("wav", "txt")
|
133 |
+
label = np.loadtxt(path, dtype=str, ndmin=1, delimiter=self.delimiter)
|
134 |
+
return label
|
135 |
+
|
136 |
+
|
137 |
+
class PreprocessPipeline(Preprocess):
|
138 |
+
"""A preprocessing pipeline for audio data.
|
139 |
+
|
140 |
+
This class is a preprocessing pipeline for audio data.
|
141 |
+
The pipeline includes resampling to a target sampling rate, mixing down stereo to mono,
|
142 |
+
and loading audio from a file.
|
143 |
+
|
144 |
+
:param target_sr: The target sampling rate to resample to.
|
145 |
+
:type target_sr: int
|
146 |
+
"""
|
147 |
+
|
148 |
+
def __init__(self, target_sr):
|
149 |
+
self.target_sr = target_sr
|
150 |
+
|
151 |
+
def __call__(self, path):
|
152 |
+
"""
|
153 |
+
Preprocess audio data using a pipeline.
|
154 |
+
|
155 |
+
:param path: The path to the audio file to load.
|
156 |
+
:type path: str
|
157 |
+
:return: A NumPy array of preprocessed audio data.
|
158 |
+
:rtype: numpy.ndarray
|
159 |
+
|
160 |
+
"""
|
161 |
+
|
162 |
+
signal, sr = torchaudio.load(path)
|
163 |
+
signal = self._resample(signal, sr)
|
164 |
+
signal = self._mix_down(signal)
|
165 |
+
return signal.numpy()
|
166 |
+
|
167 |
+
def _mix_down(self, signal):
|
168 |
+
"""
|
169 |
+
Mix down stereo to mono.
|
170 |
+
|
171 |
+
:param signal: The audio signal to mix down.
|
172 |
+
:type signal: torch.Tensor
|
173 |
+
:return: The mixed down audio signal.
|
174 |
+
:rtype: torch.Tensor
|
175 |
+
|
176 |
+
"""
|
177 |
+
|
178 |
+
if signal.shape[0] > 1:
|
179 |
+
signal = torch.mean(signal, dim=0, keepdim=True)
|
180 |
+
return signal
|
181 |
+
|
182 |
+
def _resample(self, signal, input_sr):
|
183 |
+
"""
|
184 |
+
Resample audio signal to a target sampling rate.
|
185 |
+
|
186 |
+
:param signal: The audio signal to resample.
|
187 |
+
:type signal: torch.Tensor
|
188 |
+
:param input_sr: The current sampling rate of the audio signal.
|
189 |
+
:type input_sr: int
|
190 |
+
:return: The resampled audio signal.
|
191 |
+
:rtype: torch.Tensor
|
192 |
+
|
193 |
+
"""
|
194 |
+
|
195 |
+
if input_sr != self.target_sr:
|
196 |
+
resampler = torchaudio.transforms.Resample(input_sr, self.target_sr)
|
197 |
+
signal = resampler(signal)
|
198 |
+
return signal
|
199 |
+
|
200 |
+
|
201 |
+
class SpecToImage(Transform):
|
202 |
+
def __init__(self, mean=None, std=None, eps=1e-6):
|
203 |
+
self.mean = mean
|
204 |
+
self.std = std
|
205 |
+
self.eps = eps
|
206 |
+
|
207 |
+
def __call__(self, spec):
|
208 |
+
spec = torch.stack([spec, spec, spec], dim=-1)
|
209 |
+
|
210 |
+
mean = torch.mean(spec) if self.mean is None else self.mean
|
211 |
+
std = torch.std(spec) if self.std is None else self.std
|
212 |
+
spec_norm = (spec - mean) / std
|
213 |
+
|
214 |
+
spec_min, spec_max = torch.min(spec_norm), torch.max(spec_norm)
|
215 |
+
spec_scaled = 255 * (spec_norm - spec_min) / (spec_max - spec_min)
|
216 |
+
|
217 |
+
return spec_scaled.type(torch.uint8)
|
218 |
+
|
219 |
+
|
220 |
+
class MinMaxScale(Transform):
|
221 |
+
def __call__(self, spec):
|
222 |
+
spec_min, spec_max = torch.min(spec), torch.max(spec)
|
223 |
+
|
224 |
+
return (spec - spec_min) / (spec_max - spec_min)
|
225 |
+
|
226 |
+
|
227 |
+
class Normalize(Transform):
|
228 |
+
def __init__(self, mean, std):
|
229 |
+
self.mean = mean
|
230 |
+
self.std = std
|
231 |
+
|
232 |
+
def __call__(self, spec):
|
233 |
+
return (spec - self.mean) / self.std
|
234 |
+
|
235 |
+
|
236 |
+
class FeatureExtractor(Transform):
|
237 |
+
"""Extract features from audio signal using an AST feature extractor.
|
238 |
+
|
239 |
+
This class is a transform that extracts features from an audio signal using an AST feature extractor.
|
240 |
+
The features are returned as a PyTorch tensor.
|
241 |
+
|
242 |
+
:param sr: The sampling rate of the audio signal.
|
243 |
+
:type sr: int
|
244 |
+
"""
|
245 |
+
|
246 |
+
def __init__(self, sr):
|
247 |
+
self.transform = partial(ASTFeatureExtractor(), sampling_rate=sr, return_tensors="pt")
|
248 |
+
|
249 |
+
def __call__(self, signal):
|
250 |
+
"""
|
251 |
+
Extract features from audio signal using an AST feature extractor.
|
252 |
+
|
253 |
+
:param signal: The audio signal to extract features from.
|
254 |
+
:type signal: numpy.ndarray
|
255 |
+
:return: A tensor of extracted audio features.
|
256 |
+
:rtype: torch.Tensor
|
257 |
+
|
258 |
+
"""
|
259 |
+
|
260 |
+
return self.transform(signal.squeeze()).input_values.mT
|
261 |
+
|
262 |
+
|
263 |
+
class Preemphasis(Transform):
|
264 |
+
"""perform preemphasis on the input signal.
|
265 |
+
:param signal: The signal to filter.
|
266 |
+
:param coeff: The preemphasis coefficient. 0 is none, default 0.97.
|
267 |
+
:returns: the filtered signal.
|
268 |
+
"""
|
269 |
+
|
270 |
+
def __init__(self, coeff: float = 0.97):
|
271 |
+
self.coeff = coeff
|
272 |
+
|
273 |
+
def __call__(self, signal):
|
274 |
+
return torch.cat([signal[:, :1], signal[:, 1:] - self.coeff * signal[:, :-1]], dim=1)
|
275 |
+
|
276 |
+
|
277 |
+
class Spectrogram(Transform):
|
278 |
+
def __init__(self, sample_rate, n_mels, hop_length, n_fft):
|
279 |
+
self.transform = torchaudio.transforms.MelSpectrogram(
|
280 |
+
sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft, f_min=20, center=False
|
281 |
+
)
|
282 |
+
|
283 |
+
def __call__(self, signal):
|
284 |
+
return self.transform(signal)
|
285 |
+
|
286 |
+
|
287 |
+
class LogTransform(Transform):
|
288 |
+
def __call__(self, signal):
|
289 |
+
return torch.log(signal + 1e-8)
|
290 |
+
|
291 |
+
|
292 |
+
class PadCutToLength(Transform):
|
293 |
+
def __init__(self, max_length):
|
294 |
+
self.max_length = max_length
|
295 |
+
|
296 |
+
def __call__(self, spec):
|
297 |
+
seq_len = spec.shape[-1]
|
298 |
+
|
299 |
+
if seq_len > self.max_length:
|
300 |
+
return spec[..., : self.max_length]
|
301 |
+
if seq_len < self.max_length:
|
302 |
+
diff = self.max_length - seq_len
|
303 |
+
return F.pad(spec, (0, diff), mode="constant", value=0)
|
304 |
+
|
305 |
+
|
306 |
+
class CustomFeatureExtractor(Transform):
|
307 |
+
def __init__(self, sample_rate, n_mels, hop_length, n_fft, max_length, mean, std):
|
308 |
+
self.extract = Compose(
|
309 |
+
[
|
310 |
+
Preemphasis(),
|
311 |
+
Spectrogram(sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft),
|
312 |
+
LogTransform(),
|
313 |
+
PadCutToLength(max_length=max_length),
|
314 |
+
Normalize(mean=mean, std=std),
|
315 |
+
]
|
316 |
+
)
|
317 |
+
|
318 |
+
def __call__(self, x):
|
319 |
+
return self.extract(x)
|
320 |
+
|
321 |
+
|
322 |
+
class RepeatAudio(Transform):
|
323 |
+
"""A transform to repeat audio data.
|
324 |
+
|
325 |
+
This class is a transform that repeats audio data a random number of times up to a maximum specified value.
|
326 |
+
|
327 |
+
:param max_repeats: The maximum number of times to repeat the audio data.
|
328 |
+
:type max_repeats: int
|
329 |
+
"""
|
330 |
+
|
331 |
+
def __init__(self, max_repeats: int = 2):
|
332 |
+
self.max_repeats = max_repeats
|
333 |
+
|
334 |
+
def __call__(self, signal):
|
335 |
+
"""
|
336 |
+
Repeat audio data a random number of times up to a maximum specified value.
|
337 |
+
|
338 |
+
:param signal: The audio data to repeat.
|
339 |
+
:type signal: numpy.ndarray
|
340 |
+
:return: The repeated audio data.
|
341 |
+
:rtype: numpy.ndarray
|
342 |
+
|
343 |
+
"""
|
344 |
+
|
345 |
+
num_repeats = torch.randint(1, self.max_repeats, (1,)).item()
|
346 |
+
return np.tile(signal, reps=num_repeats)
|
347 |
+
|
348 |
+
|
349 |
+
class MaskFrequency(Transform):
|
350 |
+
"""A transform to mask frequency of a spectrogram.
|
351 |
+
|
352 |
+
This class is a transform that masks out a random number of consecutive frequencies from a spectrogram.
|
353 |
+
|
354 |
+
:param max_mask_length: The maximum number of consecutive frequencies to mask out from the spectrogram.
|
355 |
+
:type max_mask_length: int
|
356 |
+
"""
|
357 |
+
|
358 |
+
def __init__(self, max_mask_length: int = 0):
|
359 |
+
self.aug = FrequencyMasking(max_mask_length)
|
360 |
+
|
361 |
+
def __call__(self, spec):
|
362 |
+
"""
|
363 |
+
Mask out a random number of consecutive frequencies from a spectrogram.
|
364 |
+
|
365 |
+
:param spec: The input spectrogram.
|
366 |
+
:type spec: numpy.ndarray
|
367 |
+
:return: The spectrogram with masked frequencies.
|
368 |
+
:rtype: numpy.ndarray
|
369 |
+
|
370 |
+
"""
|
371 |
+
|
372 |
+
return self.aug(spec)
|
373 |
+
|
374 |
+
|
375 |
+
class MaskTime(Transform):
|
376 |
+
"""A transform to mask time of a spectrogram.
|
377 |
+
|
378 |
+
This class is a transform that masks out a random number of consecutive time steps from a spectrogram.
|
379 |
+
|
380 |
+
:param max_mask_length: The maximum number of consecutive time steps to mask out from the spectrogram.
|
381 |
+
:type max_mask_length: int
|
382 |
+
"""
|
383 |
+
|
384 |
+
def __init__(self, max_mask_length: int = 0):
|
385 |
+
self.aug = TimeMasking(max_mask_length)
|
386 |
+
|
387 |
+
def __call__(self, spec):
|
388 |
+
"""
|
389 |
+
Mask out a random number of consecutive time steps from a spectrogram.
|
390 |
+
|
391 |
+
:param spec: The input spectrogram.
|
392 |
+
:type spec: numpy.ndarray
|
393 |
+
:return: The spectrogram with masked time steps.
|
394 |
+
:rtype: numpy.ndarray
|
395 |
+
|
396 |
+
"""
|
397 |
+
|
398 |
+
return self.aug(spec)
|
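For orientation, the transforms above are presumably chained as follows when preparing AST inputs: load and resample a file, optionally augment the waveform, extract features, then apply SpecAugment-style masking. The file path, sample rate, and mask lengths below are illustrative assumptions.

# Hedged sketch: chaining the transforms above for a single .wav file.
from src.modeling.transforms import (
    FeatureExtractor,
    MaskFrequency,
    MaskTime,
    PreprocessPipeline,
    RepeatAudio,
)

sr = 16000
preprocess = PreprocessPipeline(target_sr=sr)  # load, resample, mix down to mono
extract = FeatureExtractor(sr=sr)              # AST log-mel features

signal = preprocess("src/api/test_files/test.wav")  # numpy array, shape [1, num_samples]
signal = RepeatAudio(max_repeats=3)(signal)         # waveform-level augmentation
features = extract(signal)                          # tensor of shape [1, num_mels, time]

# SpecAugment-style masking on the extracted features.
features = MaskFrequency(max_mask_length=20)(features)
features = MaskTime(max_mask_length=100)(features)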
src/modeling/utils.py
ADDED
@@ -0,0 +1,336 @@
|
1 |
+
from glob import glob
|
2 |
+
from pathlib import Path
|
3 |
+
from types import SimpleNamespace
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import numpy as np
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
CLASSES = ["tru", "sax", "vio", "gac", "org", "cla", "flu", "voi", "gel", "cel", "pia"]
|
11 |
+
|
12 |
+
|
13 |
+
def get_wav_files(base_path):
|
14 |
+
"""
|
15 |
+
Function to recursively get all the .wav files in a directory.
|
16 |
+
|
17 |
+
:param base_path: The base path of the directory to search.
|
18 |
+
:type base_path: str or pathlib.Path
|
19 |
+
|
20 |
+
:return: A list of paths to .wav files found in the directory.
|
21 |
+
:rtype: List[str]
|
22 |
+
"""
|
23 |
+
|
24 |
+
return glob(f"{base_path}/**/*.wav", recursive=True)
|
25 |
+
|
26 |
+
|
27 |
+
def parse_config(config_path):
|
28 |
+
"""
|
29 |
+
Parse a YAML configuration file and return the configuration as a SimpleNamespace object.
|
30 |
+
|
31 |
+
:param config_path: The path to the YAML configuration file.
|
32 |
+
:type config_path: str or pathlib.Path
|
33 |
+
|
34 |
+
:return: A SimpleNamespace object representing the configuration.
|
35 |
+
:rtype: types.SimpleNamespace
|
36 |
+
"""
|
37 |
+
with open(config_path) as file:
|
38 |
+
return SimpleNamespace(**yaml.safe_load(file))
|
39 |
+
|
40 |
+
|
41 |
+
def init_transforms(fn_dict, module):
|
42 |
+
"""
|
43 |
+
Initialize a list of transforms from a dictionary of function names and their parameters.
|
44 |
+
|
45 |
+
:param fn_dict: A dictionary where keys are the names of transform functions
|
46 |
+
and values are dictionaries of parameters.
|
47 |
+
:type fn_dict: Dict[str, Dict[str, Any]]
|
48 |
+
|
49 |
+
:param module: The module where the transform functions are defined.
|
50 |
+
:type module: module
|
51 |
+
|
52 |
+
:return: A list of transform functions.
|
53 |
+
:rtype: List[Callable]
|
54 |
+
"""
|
55 |
+
transforms = init_objs(fn_dict, module)
|
56 |
+
if transforms is not None:
|
57 |
+
transforms = ComposeTransforms(transforms)
|
58 |
+
return transforms
|
59 |
+
|
60 |
+
|
def init_objs(fn_dict, module):
    """
    Initialize a list of objects from a dictionary of object names and their parameters.

    :param fn_dict: A dictionary where keys are the names of object classes and values are dictionaries of parameters.
    :type fn_dict: Dict[str, Dict[str, Any]]

    :param module: The module where the object classes are defined.
    :type module: module

    :return: A list of objects, or None if fn_dict is None.
    :rtype: List[Any]
    """

    if fn_dict is None:
        return None

    transforms = []
    for transform in fn_dict.keys():
        # Use a default of None so a missing attribute raises NotImplementedError
        # below instead of an AttributeError here.
        fn = getattr(module, transform, None)
        if fn is None:
            raise NotImplementedError(
                "The attribute '{}' is not implemented in the module '{}'.".format(transform, module.__name__)
            )

        fn_args = fn_dict[transform]

        if fn_args is None:
            transforms.append(fn())
        else:
            transforms.append(fn(**fn_args))

    return transforms

def init_obj(fn_dict, module, *args, **kwargs):
    """
    Initialize an object by calling a function with the provided arguments.

    :param fn_dict: A dictionary that maps the function name to its arguments.
    :type fn_dict: dict or None
    :param module: The module containing the function.
    :type module: module
    :param args: The positional arguments for the function.
    :type args: tuple
    :param kwargs: The keyword arguments for the function.
    :type kwargs: dict
    :raises AssertionError: If a keyword argument is already specified in fn_dict.
    :return: The result of calling the function with the provided arguments.
    :rtype: Any
    """

    if fn_dict is None:
        return None

    name = list(fn_dict.keys())[0]

    # Use a default of None so a missing attribute raises NotImplementedError
    # below instead of an AttributeError here.
    fn = getattr(module, name, None)
    if fn is None:
        raise NotImplementedError(
            "The attribute '{}' is not implemented in the module '{}'.".format(name, module.__name__)
        )

    fn_args = fn_dict[name]

    if fn_args is not None:
        assert all(k not in fn_args for k in kwargs), "Keyword arguments must not duplicate keys already set in fn_dict"
        fn_args.update(kwargs)

        return fn(*args, **fn_args)
    else:
        return fn(*args, **kwargs)

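# Usage sketch for the config-driven factories above (illustrative only; the
# torch.optim target, the parameter values, and `model` are hypothetical
# examples, not settings prescribed by this repo):
#
#   import torch.optim as optim
#
#   optimizer_cfg = {"AdamW": {"lr": 1e-4, "weight_decay": 0.01}}
#   optimizer = init_obj(optimizer_cfg, optim, model.parameters())
#   # equivalent to: optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
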
class ComposeTransforms:
    """
    Composes a list of transforms to be applied in sequence to input data.

    :param transforms: A list of transforms to be applied.
    :type transforms: List[callable]
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, data, *args):
        for t in self.transforms:
            data = t(data, *args)
        return data

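# Usage sketch (illustrative only; the two lambdas are hypothetical stand-ins
# for the callables defined in the transforms module):
#
#   pipeline = ComposeTransforms([lambda x, *a: x * 2, lambda x, *a: x + 1])
#   pipeline(np.array([1.0, 2.0]))   # -> array([3., 5.]), applied left to right
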
def load_raw_file(path: Union[str, Path]):
    """
    Loads an audio file from disk and returns its raw waveform and sample rate.

    :param path: The path to the audio file to load.
    :type path: Union[str, Path]
    :return: A tuple containing the raw waveform and sample rate.
    :rtype: tuple
    """
    return librosa.load(path, sr=None, mono=False)

def get_onset(signal, sr):
    """
    Computes the time of the first detected onset in an audio signal.

    :param signal: The audio signal.
    :type signal: np.ndarray
    :param sr: The sample rate of the audio signal.
    :type sr: int
    :return: The time of the first detected onset in seconds.
    :rtype: float
    """
    onset = librosa.onset.onset_detect(y=signal, sr=sr, units="time")[0]
    return onset

def get_bpm(signal, sr):
    """
    Computes the estimated beats per minute (BPM) of an audio signal.

    :param signal: The audio signal.
    :type signal: np.ndarray
    :param sr: The sample rate of the audio signal.
    :type sr: int
    :return: The estimated BPM of the audio signal, or None if the BPM cannot be computed.
    :rtype: Union[float, None]
    """

    bpm, _ = librosa.beat.beat_track(y=signal, sr=sr)
    return bpm if bpm != 0 else None

def get_pitch(signal, sr):
    """
    Computes the estimated pitch of an audio signal.

    :param signal: The audio signal.
    :type signal: np.ndarray
    :param sr: The sample rate of the audio signal.
    :type sr: int
    :return: The mean natural-log pitch (in log-Hz) over voiced frames, or None if no pitch can be estimated.
    :rtype: Union[float, None]
    """

    eps = 1e-8
    fmin = librosa.note_to_hz("C2")
    fmax = librosa.note_to_hz("C7")

    pitch, _, _ = librosa.pyin(y=signal, sr=sr, fmin=fmin, fmax=fmax)

    if not np.isnan(pitch).all():
        mean_log_pitch = np.nanmean(np.log(pitch + eps))
    else:
        mean_log_pitch = None

    return mean_log_pitch

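# Note on the representation (a minimal sketch; the 440 Hz value is just an
# illustrative number, not something produced by this repo): the value returned
# above is a mean of natural logs, so np.exp recovers Hz, which is exactly what
# sync_pitch below relies on.
#
#   mean_log_pitch = np.log(440.0)      # a clip whose mean pitch is A4
#   hz = np.exp(mean_log_pitch)         # -> 440.0
#   semitones_vs_c4 = 12 * np.log2(hz / librosa.note_to_hz("C4"))   # -> ~9
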
def get_file_info(path: Union[str, Path], extract_music_features: bool):
    """
    Loads an audio file and computes some basic information about it,
    such as pitch, BPM, onset time, duration, sample rate, and number of channels.

    :param path: The path to the audio file.
    :type path: Union[str, Path]
    :param extract_music_features: Whether to extract music features such as pitch, BPM, and onset time.
    :type extract_music_features: bool
    :return: A dictionary containing information about the audio file.
    :rtype: dict
    """

    path = str(path) if isinstance(path, Path) else path

    signal, sr = load_raw_file(path)
    # librosa returns a 1-D array for mono files, so only read the channel
    # count from the first axis when the array is 2-D.
    channels = signal.shape[0] if signal.ndim > 1 else 1

    signal = librosa.to_mono(signal)
    duration = len(signal) / sr

    pitch, bpm, onset = None, None, None
    if extract_music_features:
        pitch = get_pitch(signal, sr)
        bpm = get_bpm(signal, sr)
        onset = get_onset(signal, sr)

    return {
        "path": path,
        "pitch": pitch,
        "bpm": bpm,
        "onset": onset,
        "sample_rate": sr,
        "duration": duration,
        "channels": channels,
    }

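# Usage sketch (illustrative only; the path points at the test clip bundled in
# this commit, and the printed keys mirror the dictionary returned above):
#
#   info = get_file_info("src/api/test_files/test.wav", extract_music_features=True)
#   print(info["duration"], info["sample_rate"], info["channels"])
#   print(info["pitch"], info["bpm"], info["onset"])
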
def sync_pitch(file_to_sync: np.ndarray, sr: int, pitch_base: float, pitch: float):
    """
    Shift the pitch of an audio file to match a reference pitch value.

    :param file_to_sync: The input audio file as a NumPy array.
    :type file_to_sync: np.ndarray
    :param sr: The sample rate of the input file.
    :type sr: int
    :param pitch_base: The mean log pitch of the reference signal to match.
    :type pitch_base: float
    :param pitch: The mean log pitch of the signal to synchronize.
    :type pitch: float
    :return: The synchronized audio file as a NumPy array.
    :rtype: np.ndarray
    """

    assert np.ndim(file_to_sync) == 1, "Input array has more than one dimension"

    if any(np.isnan(x) for x in [pitch_base, pitch]):
        return file_to_sync

    steps = np.round(12 * np.log2(np.exp(pitch_base) / np.exp(pitch)), 0)

    return librosa.effects.pitch_shift(y=file_to_sync, sr=sr, n_steps=steps)

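# Worked example of the semitone computation above (a minimal sketch; the two
# pitch values are hypothetical): because pitches are stored as natural logs,
# exp() recovers Hz and 12 * log2(ratio) converts the ratio to semitones.
#
#   pitch_base = np.log(440.0)   # reference clip centred on A4
#   pitch = np.log(392.0)        # clip to sync, centred on G4
#   steps = np.round(12 * np.log2(np.exp(pitch_base) / np.exp(pitch)), 0)
#   # steps -> 2.0, i.e. shift the clip up two semitones (G4 -> A4)
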
def sync_bpm(file_to_sync: np.ndarray, sr: int, bpm_base: float, bpm: float):
    """
    Stretch or compress the duration of an audio file to match a reference tempo.

    :param file_to_sync: The input audio file as a NumPy array.
    :type file_to_sync: np.ndarray
    :param sr: The sample rate of the input file.
    :type sr: int
    :param bpm_base: The tempo of the reference signal in BPM.
    :type bpm_base: float
    :param bpm: The tempo of the signal to synchronize in BPM.
    :type bpm: float
    :return: The synchronized audio file as a NumPy array.
    :rtype: np.ndarray
    """

    assert np.ndim(file_to_sync) == 1, "Input array has more than one dimension"

    if any(np.isnan(x) for x in [bpm_base, bpm]):
        return file_to_sync

    return librosa.effects.time_stretch(y=file_to_sync, rate=bpm_base / bpm)

def sync_onset(file_to_sync: np.ndarray, sr: int, onset_base: float, onset: float):
    """
    Sync the onset of an audio signal by adding or removing silence at the beginning.

    :param file_to_sync: The audio signal to synchronize.
    :type file_to_sync: np.ndarray
    :param sr: The sample rate of the audio signal.
    :type sr: int
    :param onset_base: The onset of the reference signal in seconds.
    :type onset_base: float
    :param onset: The onset of the signal to synchronize in seconds.
    :type onset: float
    :raises AssertionError: If the input array has more than one dimension.
    :return: The synchronized audio signal.
    :rtype: np.ndarray
    """

    assert np.ndim(file_to_sync) == 1, "Input array has more than one dimension"

    if any(np.isnan(x) for x in [onset_base, onset]):
        return file_to_sync

    diff = int(round(abs(onset_base * sr - onset * sr), 0))

    if onset_base > onset:
        return np.pad(file_to_sync, (diff, 0), mode="constant", constant_values=0)
    else:
        return file_to_sync[diff:]
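
# End-to-end usage sketch for the three sync helpers (illustrative only; the
# file names are hypothetical, and the order of the steps is one plausible way
# to combine them, following the *_base reference convention of the docstrings
# above):
#
#   base = get_file_info("base.wav", extract_music_features=True)
#   other = get_file_info("other.wav", extract_music_features=True)
#   signal, sr = load_raw_file(other["path"])
#   signal = librosa.to_mono(signal)
#   signal = sync_pitch(signal, sr, base["pitch"], other["pitch"])
#   signal = sync_bpm(signal, sr, base["bpm"], other["bpm"])
#   signal = sync_onset(signal, sr, base["onset"], other["onset"])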