from tqdm import tqdm
import librosa
import torch

from src import laion_clap
from ..config.configs import ProjectPaths


class AudioEncoder(laion_clap.CLAP_Module):
    """CLAP-based encoder that embeds audio tracks and text queries and
    reads/writes those embeddings through a MongoDB-style collection."""

    def __init__(self, collection=None) -> None:
        super().__init__(enable_fusion=False, amodel="HTSAT-base")
        self.music_data = None
        self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)
        self.collection = collection

    def _get_track_data(self):
        return self.collection.find({})

    def update_collection_item(self, track_id, vector):
        self.collection.update_one({"track_id": track_id}, {"$set": {"embedding": vector}})

    def extract_audio_representation(self, file_name):
        # CLAP expects 48 kHz audio shaped (batch, samples).
        audio_data, _ = librosa.load(file_name, sr=48000)
        audio_data = torch.from_numpy(audio_data.reshape(1, -1))
        with torch.no_grad():
            audio_embed = self.get_audio_embedding_from_data(
                x=audio_data, use_tensor=True
            )
        return audio_embed

    def extract_bulk_audio_representations(self, save=False):
        # NOTE: `save` is currently unused.
        track_data = self._get_track_data()
        processed_links = []
        for track in tqdm(track_data):
            file_path = track["youtube_data"]["file_path"]
            link = track["youtube_data"]["link"]
            if file_path and link not in processed_links:
                tensor = self.extract_audio_representation(file_path)
                self.update_collection_item(track["track_id"], tensor.tolist())
                # Record the link so duplicate documents are embedded only once.
                processed_links.append(link)

    def load_existing_audio_vectors(self):
        # Stack every stored embedding into one matrix and keep the matching
        # track metadata, tagging each document with its row index.
        arrays = []
        track_data = []
        idx = 0
        for track in self.collection.find({}):
            if not track.get("embedding"):
                continue
            data = track.copy()
            data.pop("embedding")
            data.update({"vector_idx": idx})
            # Embeddings are stored as a (1, dim) nested list; take the inner vector.
            arrays.append(track["embedding"][0])
            track_data.append(data)
            idx += 1
        self.music_data = torch.tensor(arrays)
        self.track_data = track_data.copy()

    def extract_text_representation(self, text):
        # get_text_embedding expects a list of strings.
        return self.get_text_embedding([text])

    def list_top_k_songs(self, text, k=10):
        assert self.music_data is not None, "call load_existing_audio_vectors() first"
        with torch.no_grad():
            text_embed = self.get_text_embedding([text], use_tensor=True)
        # Dot-product similarity between every track embedding and the query.
        dot_product = self.music_data @ text_embed.T
        top_k = torch.topk(dot_product.flatten(), k)
        final_result = []
        for rank, i in enumerate(top_k.indices.tolist()):
            final_result.append(
                {
                    "title": self.track_data[i]["youtube_data"]["title"],
                    "score": round(top_k.values[rank].item(), 2),
                    "link": self.track_data[i]["youtube_data"]["link"],
                    "track_id": self.track_data[i]["track_id"],
                }
            )
        return final_result
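
# --- Usage sketch ---
# A minimal example of the intended flow: embed any un-embedded tracks, load the
# stored vectors, then run a text query. The MongoDB URI and the database and
# collection names ("music_db", "tracks") are hypothetical assumptions for
# illustration; only the document shape matches what the methods above expect.
if __name__ == "__main__":
    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017")  # assumed local instance
    collection = client["music_db"]["tracks"]

    encoder = AudioEncoder(collection=collection)
    # Embed audio files referenced by the collection's documents.
    encoder.extract_bulk_audio_representations()
    # Pull the stored embeddings into memory and query them with free text.
    encoder.load_existing_audio_vectors()
    for hit in encoder.list_top_k_songs("melancholic piano with soft vocals", k=5):
        print(hit["score"], hit["title"], hit["link"])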