File size: 3,732 Bytes
24510fe
19759e2
 
 
 
 
 
 
24510fe
 
 
19759e2
24510fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19759e2
 
 
 
24510fe
19759e2
24510fe
 
 
19759e2
 
 
24510fe
 
 
 
 
 
 
 
 
19759e2
24510fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a350907
 
24510fe
 
 
 
 
 
 
a350907
24510fe
 
 
19759e2
 
 
 
 
a20c02a
24510fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from tqdm import tqdm
import librosa
import torch
from src import laion_clap
from ..config.configs import ProjectPaths


class AudioEncoder(laion_clap.CLAP_Module):
    """CLAP-based encoder that embeds audio tracks and text queries into a
    shared vector space, persisting and loading track embeddings through a
    MongoDB collection.

    Typical usage: call ``extract_bulk_audio_representaions`` once to embed
    all downloaded tracks, then ``load_existing_audio_vectors`` followed by
    ``list_top_k_songs`` to rank tracks against a text query.
    """

    def __init__(self, collection=None) -> None:
        # Fusion disabled and HTSAT-base audio tower must match the checkpoint
        # loaded below from ProjectPaths.MODEL_PATH.
        super().__init__(enable_fusion=False, amodel="HTSAT-base")
        # Stacked audio-embedding matrix; populated by load_existing_audio_vectors().
        self.music_data = None
        self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)
        # MongoDB collection of track documents; may be None for text-only use.
        self.collection = collection

    def _get_track_data(self):
        """Return a cursor over every track document in the collection."""
        return self.collection.find({})

    def update_collection_item(self, track_id, vector):
        """Persist ``vector`` under the ``embedding`` field of the given track."""
        self.collection.update_one({"track_id": track_id}, {"$set": {"embedding": vector}})

    def extract_audio_representaion(self, file_name):
        """Load an audio file and return its CLAP audio embedding tensor.

        NOTE: the method name keeps its historical misspelling
        ("representaion") for backward compatibility with existing callers.
        """
        # CLAP expects 48 kHz input shaped (batch, samples).
        audio_data, _ = librosa.load(file_name, sr=48000)
        audio_tensor = torch.from_numpy(audio_data.reshape(1, -1))
        with torch.no_grad():
            audio_embed = self.get_audio_embedding_from_data(
                x=audio_tensor, use_tensor=True
            )
        return audio_embed

    def extract_bulk_audio_representaions(self, save=False):
        """Embed every downloaded track and write each vector back to MongoDB.

        Args:
            save: unused; kept for backward compatibility with callers.
        """
        track_data = self._get_track_data()
        processed_links = set()  # set for O(1) membership tests
        for track in tqdm(track_data):
            link = track["youtube_data"]["link"]
            if track["youtube_data"]["file_path"] and link not in processed_links:
                tensor = self.extract_audio_representaion(track["youtube_data"]["file_path"])
                self.update_collection_item(track["track_id"], tensor.tolist())
                # BUG FIX: the original never recorded processed links, so the
                # dedup check above was a no-op and duplicate links could be
                # re-embedded. Record the link after a successful update.
                processed_links.add(link)

    def load_existing_audio_vectors(self):
        """Load all stored embeddings into ``self.music_data`` and keep the
        matching per-track metadata (with a ``vector_idx`` back-reference into
        the embedding matrix) in ``self.track_data``.

        Tracks without an ``embedding`` field are skipped.
        """
        arrays = []
        track_data = []
        vector_idx = 0
        for track in self.collection.find({}):
            embedding = track.get("embedding")
            if not embedding:
                continue
            meta = track.copy()
            meta.pop("embedding")
            meta["vector_idx"] = vector_idx
            # Embeddings are stored as a nested list (shape [1, dim]) by
            # update_collection_item; take the inner vector.
            arrays.append(embedding[0])
            track_data.append(meta)
            vector_idx += 1

        self.music_data = torch.tensor(arrays)
        self.track_data = track_data

    def extract_text_representation(self, text):
        """Return the CLAP embedding for a single text string."""
        return self.get_text_embedding([text])

    def list_top_k_songs(self, text, k=10):
        """Return the ``k`` tracks whose audio embeddings best match ``text``.

        Requires ``load_existing_audio_vectors()`` to have been called first.
        Returns a list of dicts with ``title``, ``score``, ``link`` and
        ``track_id`` keys, ordered best-first.
        """
        assert self.music_data is not None
        with torch.no_grad():
            # NOTE(review): `text` is passed through unwrapped here, unlike
            # extract_text_representation which wraps it in a list —
            # presumably callers pass a list of strings; confirm before
            # changing.
            text_embed = self.get_text_embedding(text, use_tensor=True)

        # Rank by dot-product similarity between audio matrix and text vector.
        scores = self.music_data @ text_embed.T
        top_k = torch.topk(scores.flatten(), k)
        indices = top_k.indices.tolist()
        final_result = []
        # `rank` (not `k`) avoids shadowing the parameter, which the original
        # overwrote inside this loop.
        for rank, i in enumerate(indices):
            piece = {
                "title": self.track_data[i]["youtube_data"]["title"],
                "score": round(top_k.values[rank].item(), 2),
                "link": self.track_data[i]["youtube_data"]["link"],
                "track_id": self.track_data[i]["track_id"],
            }
            final_result.append(piece)
        return final_result