import pickle
from glob import glob

import librosa
import numpy as np
import torch

from src import laion_clap
from ..config.configs import ProjectPaths


class AudioEncoder(laion_clap.CLAP_Module):
    """CLAP-based encoder that maps audio files and free text into a shared 512-d embedding space."""

    def __init__(self) -> None:
        super().__init__(enable_fusion=False, amodel="HTSAT-base")
        self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)

    def extract_audio_representaion(self, file_name):
        # CLAP expects 48 kHz audio shaped (batch, samples).
        audio_data, _ = librosa.load(file_name, sr=48000)
        audio_data = audio_data.reshape(1, -1)
        with torch.no_grad():
            audio_embed = self.get_audio_embedding_from_data(x=audio_data, use_tensor=False)
        return audio_embed

    def extract_bulk_audio_representaions(self, save=False):
        # Embed every .wav file under DATA_DIR/audio into one (n_songs, 512) matrix.
        music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav")))
        song_names = [k.split("/")[-1] for k in music_files]
        music_data = np.zeros((len(music_files), 512), dtype=np.float32)
        for m in range(music_data.shape[0]):
            music_data[m] = self.extract_audio_representaion(music_files[m])

        if not save:
            return music_data, song_names

        # Persist the embeddings and the matching file names side by side.
        np.save(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"), music_data)
        with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl"), "wb") as writer:
            pickle.dump(song_names, writer)

    def extract_text_representation(self, text):
        # get_text_embedding expects a batch, so wrap the single query in a list.
        text_data = [text]
        text_embed = self.get_text_embedding(text_data)
        return text_embed
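

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how AudioEncoder might be used for text-to-audio
# retrieval: embed the .wav library once, embed a free-text query, and rank
# songs by cosine similarity. The query string and the ranking logic below
# are hypothetical, and the block assumes the package layout lets this module
# run as an entry point (e.g. via `python -m`) with a valid checkpoint at
# ProjectPaths.MODEL_PATH and audio files under ProjectPaths.DATA_DIR/audio.
if __name__ == "__main__":
    encoder = AudioEncoder()

    # (n_songs, 512) embedding matrix plus the matching file names.
    music_data, song_names = encoder.extract_bulk_audio_representaions(save=False)

    # (1, 512) embedding of the text query.
    query_embed = encoder.extract_text_representation("calm acoustic guitar")

    # Cosine similarity between the query and every song embedding.
    music_norm = music_data / np.linalg.norm(music_data, axis=1, keepdims=True)
    query_norm = query_embed / np.linalg.norm(query_embed)
    scores = music_norm @ query_norm.reshape(-1)

    # Print the five best-matching file names with their scores.
    for idx in np.argsort(scores)[::-1][:5]:
        print(song_names[idx], float(scores[idx]))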