File size: 4,278 Bytes
e562a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36b8886
 
e562a59
 
 
 
 
8bc75c1
 
 
 
 
 
 
 
 
 
e562a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36b8886
e562a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gc
import hashlib
import os
from glob import glob
from pathlib import Path

import librosa
import torch
from diskcache import Cache
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
from transformers import ClapModel, ClapProcessor

from s3_utils import s3_auth, upload_file_to_bucket
from dotenv import load_dotenv
load_dotenv()

# PARAMETERS #######################################################################################
# Local folder used for the on-disk embedding cache.
CACHE_FOLDER = '/home/nahia/audio'
# Root folder that is scanned for the .wav files to index.
KAGGLE_DB_PATH = '/home/nahia/Documents/audio/actor/Actor_01'
# AWS credentials are mandatory: a missing variable raises KeyError right here.
AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
S3_BUCKET = "synthia-research"
S3_FOLDER = "huggingface_spaces_demo"
AWS_REGION = "eu-west-3"

# Qdrant endpoint and API key, read from the environment. The .env file has
# already been loaded by the load_dotenv() call that follows the imports at
# the top of this file, so it is not imported/called a second time here.
# NOTE(review): os.getenv yields None when a variable is unset, so a missing
# Qdrant setting is not reported at this point -- confirm this is intended.
QDRANT_URL = os.getenv('QDRANT_URL')
QDRANT_KEY = os.getenv('QDRANT_KEY')


# Authenticated S3 handle passed to upload_file_to_bucket further down.
s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)


# Functions utils ##################################################################################
def get_md5(fpath):
    """Return the hexadecimal MD5 digest of the file located at ``fpath``.

    The file is read in binary mode, 8192 bytes at a time, so that large
    files never have to be held in memory as a whole.
    """
    digest = hashlib.md5()
    with open(fpath, "rb") as stream:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()


def get_audio_embedding(model, audio_file, cache):
    """Return the audio embedding of ``audio_file``, computing it only once.

    The cache key is the model identifier (``model.config._name_or_path``)
    followed by the MD5 digest of the file contents, so the same audio
    embedded with a different model gets its own entry.

    Args:
        model: Object exposing ``config._name_or_path`` and
            ``get_audio_features`` (the script passes a ``ClapModel``).
        audio_file: Path of the audio file to embed.
        cache: Store supporting ``in``, item lookup and item assignment
            (the script passes a ``diskcache.Cache``).

    Returns:
        The first element of ``model.get_audio_features(...)`` for this
        file, either freshly computed or read back from ``cache``.

    Note:
        Relies on the module-level ``processor`` object, which must be
        defined before the first call.
    """
    file_key = f"{model.config._name_or_path}" + get_md5(audio_file)
    if file_key in cache:
        # Already embedded with this model: reuse the stored value.
        return cache[file_key]

    # Audio is resampled to 48 kHz on load -- presumably the rate the CLAP
    # processor expects; confirm against the processor configuration.
    y, sr = librosa.load(audio_file, sr=48000)
    inputs = processor(audios=y, sampling_rate=sr, return_tensors="pt")
    # Pure inference: disable autograd so no graph is built and the tensor
    # that gets cached/returned does not carry gradient state.
    with torch.no_grad():
        embedding = model.get_audio_features(**inputs)[0]
    # Release temporaries before moving on to the next file.
    gc.collect()
    torch.cuda.empty_cache()
    cache[file_key] = embedding
    return embedding



# ################## Loading the CLAP model ###################
# Pretrained CLAP weights plus the matching input processor.
print("[INFO] Loading the model...")
model_name = "laion/larger_clap_general"
model = ClapModel.from_pretrained(model_name)
processor = ClapProcessor.from_pretrained(model_name)

# On-disk embedding cache: make sure its folder exists, then open it.
os.makedirs(CACHE_FOLDER, exist_ok=True)
cache = Cache(CACHE_FOLDER)

# Creating a qdrant collection #####################################################################
client = QdrantClient(QDRANT_URL, api_key=QDRANT_KEY)
print("[INFO] Client created...")

print("[INFO] Creating qdrant data collection...")
collection_missing = not client.collection_exists("demo_spaces_db")
if collection_missing:
    # Vector size follows the model's projection dimension; similarity is cosine.
    vector_params = models.VectorParams(
        size=model.config.projection_dim,
        distance=models.Distance.COSINE,
    )
    client.create_collection(
        collection_name="demo_spaces_db",
        vectors_config=vector_params,
    )

# Embed the audio files, push each one to S3 and index it in Qdrant.
# The list is sorted because glob() returns files in arbitrary order; sorting
# keeps the point id assigned to each file reproducible from one run to the next.
audio_files = sorted(glob(os.path.join(KAGGLE_DB_PATH, '*/*.wav')))
# chunk_size: number of files sent to Qdrant per upload_points call.
# idx: running counter used as the Qdrant point id.
chunk_size, idx = 1, 0

# Use tqdm for a progress bar
print("Uploading on DB + S3")
try:
    for i in tqdm(range(0, len(audio_files), chunk_size),
                  desc="[INFO] Uploading data records to data collection..."):
        chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
        records = []
        for audio_file in chunk:
            embedding = get_audio_embedding(model, audio_file, cache)
            file_name = Path(audio_file).name
            s3key = f'{S3_FOLDER}/{file_name}'
            # The context manager closes the handle once the upload call returns
            # (assumes upload_file_to_bucket is done with the file by then).
            with open(audio_file, 'rb') as file_obj:
                upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key)
            records.append(
                models.PointStruct(
                    id=idx,
                    # Qdrant vectors are plain lists of floats, not tensors.
                    vector=embedding.tolist(),
                    payload={
                        "audio_path": audio_file,
                        "audio_s3url": f"https://{S3_BUCKET}.s3.amazonaws.com/{s3key}",
                        "style": file_name,
                    },
                )
            )
            # tqdm.write prints the message without breaking the progress bar.
            tqdm.write(f"Uploaded s3 file : {idx}")
            idx += 1
        client.upload_points(
            collection_name="demo_spaces_db",
            points=records
        )
    print("[INFO] Successfully uploaded data records to data collection!")
finally:
    # Close the cache even when an embedding or upload step fails.
    cache.close()