# NOTE(review): the lines below were a Hugging Face Spaces status banner
# ("Spaces: Sleeping") captured by the page extraction — not part of the code.
import librosa
import torch
from tqdm import tqdm

from src import laion_clap

from ..config.configs import ProjectPaths
class AudioEncoder(laion_clap.CLAP_Module):
    """CLAP-based audio/text encoder backed by a MongoDB-style collection.

    Wraps ``laion_clap.CLAP_Module`` (HTSAT-base, fusion disabled) and uses
    ``self.collection`` (an object exposing ``find``/``update_one``, e.g. a
    pymongo collection) to persist and reload per-track audio embeddings.
    """

    def __init__(self, collection=None) -> None:
        super().__init__(enable_fusion=False, amodel="HTSAT-base")
        # (num_tracks, dim) tensor; populated by load_existing_audio_vectors().
        self.music_data = None
        self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)
        # May be None when only single-file embedding extraction is needed.
        self.collection = collection

    def _get_track_data(self):
        """Return a cursor over every track document in the collection."""
        return self.collection.find({})

    def update_collection_item(self, track_id, vector):
        """Persist *vector* as the ``embedding`` field of the given track."""
        self.collection.update_one(
            {"track_id": track_id}, {"$set": {"embedding": vector}}
        )

    def extract_audio_representaion(self, file_name):
        """Embed one audio file; return a (1, dim) torch tensor.

        Name keeps its historical misspelling for backward compatibility;
        prefer the ``extract_audio_representation`` alias in new code.
        """
        # CLAP audio towers are trained on 48 kHz input.
        audio_data, _ = librosa.load(file_name, sr=48000)
        audio_tensor = torch.from_numpy(audio_data.reshape(1, -1))
        with torch.no_grad():
            audio_embed = self.get_audio_embedding_from_data(
                x=audio_tensor, use_tensor=True
            )
        return audio_embed

    # Correctly spelled alias; the misspelled name above is kept for callers.
    extract_audio_representation = extract_audio_representaion

    def extract_bulk_audio_representaions(self, save=False):
        """Embed every downloaded track and store the vectors in the collection.

        Skips tracks without a downloaded file and de-duplicates by YouTube
        link. ``save`` is accepted for backward compatibility and unused.
        Returns the number of tracks embedded.
        """
        seen_links = set()
        processed = 0
        for track in tqdm(self._get_track_data()):
            yt = track["youtube_data"]
            if yt["file_path"] and yt["link"] not in seen_links:
                # BUG FIX: the original never recorded processed links, so the
                # duplicate check was dead code and duplicates were re-embedded.
                seen_links.add(yt["link"])
                tensor = self.extract_audio_representaion(yt["file_path"])
                self.update_collection_item(track["track_id"], tensor.tolist())
                processed += 1
        return processed

    # Correctly spelled alias for the misspelled bulk method.
    extract_bulk_audio_representations = extract_bulk_audio_representaions

    def load_existing_audio_vectors(self):
        """Load stored embeddings into ``self.music_data`` / ``self.track_data``.

        Tracks without a (non-empty) ``embedding`` field are skipped. Each kept
        track's metadata gets a ``vector_idx`` equal to its row index in
        ``self.music_data``.
        """
        arrays = []
        track_meta = []
        for track in self.collection.find({}):
            embedding = track.get("embedding")
            if not embedding:
                continue
            meta = {key: val for key, val in track.items() if key != "embedding"}
            meta["vector_idx"] = len(arrays)
            # Embeddings are stored as tolist() of a (1, dim) tensor, i.e.
            # a nested [[...]] list — take the inner vector.
            arrays.append(embedding[0])
            track_meta.append(meta)
        self.music_data = torch.tensor(arrays)
        self.track_data = track_meta

    def extract_text_representation(self, text):
        """Embed a single text query; return its CLAP text embedding."""
        return self.get_text_embedding([text])

    def list_top_k_songs(self, text, k=10):
        """Return the ``k`` tracks most similar to *text*.

        Requires ``load_existing_audio_vectors()`` to have been called first.
        Each result dict carries ``title``, ``score`` (rounded dot product),
        ``link`` and ``track_id``.
        """
        assert self.music_data is not None, "call load_existing_audio_vectors() first"
        with torch.no_grad():
            text_embed = self.get_text_embedding(text, use_tensor=True)
            scores = (self.music_data @ text_embed.T).flatten()
        # Clamp so torch.topk does not raise when fewer tracks are loaded.
        top = torch.topk(scores, min(k, scores.numel()))
        final_result = []
        # NOTE: the original shadowed the parameter ``k`` with the loop counter.
        for rank, i in enumerate(top.indices.tolist()):
            yt = self.track_data[i]["youtube_data"]
            final_result.append(
                {
                    "title": yt["title"],
                    "score": round(top.values[rank].item(), 2),
                    "link": yt["link"],
                    "track_id": self.track_data[i]["track_id"],
                }
            )
        return final_result