from tqdm import tqdm
import librosa
import torch
from src import laion_clap
from ..config.configs import ProjectPaths
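
# Audio/text embedding utilities built on LAION-CLAP: the AudioEncoder below
# embeds downloaded tracks, stores the vectors in a MongoDB-style collection,
# and ranks tracks against free-text queries.
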
class AudioEncoder(laion_clap.CLAP_Module):
    def __init__(self, collection=None) -> None:
        super().__init__(enable_fusion=False, amodel="HTSAT-base")
        self.music_data = None
        self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)
        self.collection = collection

    # Legacy JSON-file loader, kept for reference:
    # def _get_track_data(self):
    #     with open(ProjectPaths.DATA_DIR.joinpath("json", "final_track_data.json"), "r") as reader:
    #         track_data = json.load(reader)
    #     return track_data

    def _get_track_data(self):
        # Fetch every track document from the collection.
        data = self.collection.find({})
        return data

    def update_collection_item(self, track_id, vector):
        # Persist a computed embedding back onto its track document.
        self.collection.update_one({"track_id": track_id}, {"$set": {"embedding": vector}})

    def extract_audio_representaion(self, file_name):
        # Load the audio at CLAP's expected 48 kHz sample rate and embed it.
        audio_data, _ = librosa.load(file_name, sr=48000)
        audio_data = audio_data.reshape(1, -1)
        audio_data = torch.from_numpy(audio_data)
        with torch.no_grad():
            audio_embed = self.get_audio_embedding_from_data(
                x=audio_data, use_tensor=True
            )
        return audio_embed

    def extract_bulk_audio_representaions(self, save=False):
        # Embed every downloaded track once and write the vector back to the collection.
        track_data = self._get_track_data()
        processed_links = []
        idx = 0
        for track in tqdm(track_data):
            if track["youtube_data"]["file_path"] and track["youtube_data"]["link"] not in processed_links:
                tensor = self.extract_audio_representaion(track["youtube_data"]["file_path"])
                self.update_collection_item(track["track_id"], tensor.tolist())
                processed_links.append(track["youtube_data"]["link"])
                idx += 1

    # Legacy file-based loader, kept for reference:
    # def load_existing_audio_vectors(self):
    #     self.music_data = torch.load(
    #         ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.pt")
    #     )
    #     with open(
    #         ProjectPaths.DATA_DIR.joinpath("vectors", "final_track_data_w_links.json"),
    #         "r",
    #     ) as rd:
    #         self.track_data = json.load(rd)

    def load_existing_audio_vectors(self):
        # Rebuild the in-memory embedding matrix and track metadata, skipping
        # documents that do not have an embedding yet.
        # embedding_result = list(self.collection.find({}, {"embedding": 1}))
        # tracking_result = list(self.collection.find({}, {"embedding": 0}))
        arrays = []
        track_data = []
        idx = 0
        for track in self.collection.find({}):
            if not track.get("embedding"):
                continue
            data = track.copy()
            data.pop("embedding")
            data.update({"vector_idx": idx})
            arrays.append(track["embedding"][0])
            track_data.append(data)
            idx += 1
        self.music_data = torch.tensor(arrays)
        self.track_data = track_data.copy()

    def extract_text_representation(self, text):
        # Wrap the query in a list, as CLAP expects a batch of strings.
        text_data = [text]
        text_embed = self.get_text_embedding(text_data)
        return text_embed

    def list_top_k_songs(self, text, k=10):
        # Rank tracks by the dot product between the text embedding and the
        # preloaded audio matrix (call load_existing_audio_vectors first).
        assert self.music_data is not None
        with torch.no_grad():
            text_embed = self.get_text_embedding([text], use_tensor=True)
            dot_product = self.music_data @ text_embed.T
            top_k = torch.topk(dot_product.flatten(), k)
        indices = top_k.indices.tolist()
        final_result = []
        for rank, i in enumerate(indices):
            piece = {
                "title": self.track_data[i]["youtube_data"]["title"],
                "score": round(top_k.values[rank].item(), 2),
                "link": self.track_data[i]["youtube_data"]["link"],
                "track_id": self.track_data[i]["track_id"],
            }
            final_result.append(piece)
        return final_result
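

# Usage sketch, not part of the module itself: it assumes a MongoDB instance at
# "mongodb://localhost:27017" with a hypothetical "song_data.tracks" collection
# shaped like the documents used above (track_id, youtube_data.file_path/link/title,
# embedding), and that this module is executed in a context where its relative
# imports resolve (e.g. via python -m). The query string is illustrative only.
if __name__ == "__main__":
    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017")
    collection = client["song_data"]["tracks"]  # hypothetical db/collection names

    encoder = AudioEncoder(collection=collection)
    encoder.extract_bulk_audio_representaions()  # embed tracks that have audio files
    encoder.load_existing_audio_vectors()        # build the in-memory search index
    for hit in encoder.list_top_k_songs("calm acoustic guitar", k=5):
        print(hit["score"], hit["title"], hit["link"])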