leadr64 commited on
Commit
e562a59
1 Parent(s): f030819

Ajouter le script Gradio et les dépendances

Browse files
Files changed (3) hide show
  1. app.py +45 -0
  2. database.py +118 -0
  3. s3_utils.py +66 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import laion_clap
from qdrant_client import QdrantClient
import os

# Use environment variables for configuration (defaults target a local Qdrant).
QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
QDRANT_PORT = int(os.getenv('QDRANT_PORT', 6333))

# Connect to Qdrant
client = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
print("[INFO] Client created...")

# Load the CLAP model used to embed text queries.
print("[INFO] Loading the model...")
# NOTE(review): model_name is never passed to CLAP_Module/load_ckpt(), so it
# has no effect here — load_ckpt() downloads its own default checkpoint.
# Confirm which checkpoint is actually intended.
model_name = "laion/larger_clap_music"
model = laion_clap.CLAP_Module(enable_fusion=False)
model.load_ckpt()  # download the default pretrained checkpoint

# Gradio interface: number of audio result slots shown in the UI.
max_results = 10
22
+
23
def sound_search(query):
    """Embed *query* with CLAP and return up to ``max_results`` audio clips.

    Always returns exactly ``max_results`` values so the return list matches
    the ``gr.Audio`` output components wired to this callback; unused slots
    are ``None`` (which clears the component).

    :param query: free-text description of the sound to look for
    :return: list of ``max_results`` entries — ``gr.Audio`` for each hit,
        ``None`` for the remaining slots
    """
    # Trick: the model can't accept a singleton batch, so pad with an
    # empty string and keep only the first embedding.
    text_embed = model.get_text_embedding([query, ''])[0]
    # NOTE(review): database.py creates/fills collection "demo_spaces_db",
    # but this searches "demo_db7" — confirm which collection is live.
    hits = client.search(
        collection_name="demo_db7",
        query_vector=text_embed,
        limit=max_results,
    )
    results = [
        gr.Audio(
            hit.payload['audio_path'],
            label=f"style: {hit.payload['style']} -- score: {hit.score}")
        for hit in hits
    ]
    # Bug fix: Gradio requires one return value per output component; when
    # fewer than max_results hits come back, pad with None.
    results.extend([None] * (max_results - len(results)))
    return results
36
+
37
# Build the Gradio UI: one text input wired to `max_results` audio outputs.
with gr.Blocks() as demo:
    gr.Markdown(
        """# Sound search database """
    )
    inp = gr.Textbox(placeholder="What sound are you looking for ?")
    # Needed to get distinct output objects — one gr.Audio component per
    # result slot, all updated by each sound_search call.
    out = [gr.Audio(label=f"{x}") for x in range(max_results)]
    # NOTE(review): .change fires on every keystroke, re-running the search
    # per character typed — .submit (Enter key) may be the intended trigger.
    inp.change(sound_search, inp, out)

demo.launch()
database.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import hashlib
3
+ import os
4
+ from glob import glob
5
+ from pathlib import Path
6
+
7
+ import librosa
8
+ import torch
9
+ from diskcache import Cache
10
+ from qdrant_client import QdrantClient
11
+ from qdrant_client.http import models
12
+ from tqdm import tqdm
13
+ from transformers import ClapModel, ClapProcessor
14
+
15
+ from s3_utils import s3_auth, upload_file_to_bucket
16
+ from dotenv import load_dotenv
17
+ load_dotenv()
18
+
19
# PARAMETERS #######################################################################################
# Machine-specific local paths — override before running on another host.
CACHE_FOLDER = '/home/arthur/data/music/demo_audio_search/audio_embeddings_cache_individual/'
KAGGLE_DB_PATH = '/home/arthur/data/kaggle/park-spring-2023-music-genre-recognition/train/train'
# Required environment variables — raises KeyError immediately if missing
# (fail fast rather than failing mid-upload).
AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
S3_BUCKET = "synthia-research"
S3_FOLDER = "huggingface_spaces_demo"
AWS_REGION = "eu-west-3"

# Authenticated S3 client shared by the upload loop below.
s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
29
+
30
+
31
+ # Functions utils ##################################################################################
32
def get_md5(fpath):
    """Return the hex MD5 digest of the file at *fpath*, read in 8 KiB chunks."""
    digest = hashlib.md5()
    with open(fpath, "rb") as handle:
        for block in iter(lambda: handle.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()
38
+
39
+
40
def get_audio_embedding(model, audio_file, cache, audio_processor=None):
    """Return the CLAP embedding for *audio_file*, memoized in *cache*.

    :param model: a ``transformers.ClapModel``
    :param audio_file: path to the audio file to embed
    :param cache: ``diskcache.Cache`` keyed by model name + file content MD5
    :param audio_processor: optional ``ClapProcessor``; defaults to the
        module-level ``processor`` for backward compatibility (the original
        relied on that global implicitly)
    :return: the (possibly cached) embedding tensor
    """
    if audio_processor is None:
        audio_processor = processor  # module-level ClapProcessor defined below
    # Key on model name + content hash so moved/renamed files still hit the cache.
    file_key = f"{model.config._name_or_path}" + get_md5(audio_file)
    if file_key in cache:
        # Embedding for this exact content already computed — reuse it.
        embedding = cache[file_key]
    else:
        # CLAP expects 48 kHz input.
        y, sr = librosa.load(audio_file, sr=48000)
        inputs = audio_processor(audios=y, sampling_rate=sr, return_tensors="pt")
        # Inference only: no_grad avoids retaining the autograd graph.
        with torch.no_grad():
            embedding = model.get_audio_features(**inputs)[0]
        # Free transient buffers between files to keep memory flat.
        gc.collect()
        torch.cuda.empty_cache()
        cache[file_key] = embedding
    return embedding
55
+
56
+
57
+
58
# ################## Loading the CLAP model ###################
# Loading the model used to embed the audio files.
print("[INFO] Loading the model...")
# NOTE(review): app.py queries via laion_clap's default checkpoint
# ("laion/larger_clap_music" is named there), while this indexer embeds with
# "laion/larger_clap_general" — confirm that query and index embeddings come
# from the same model, otherwise search scores are meaningless.
model_name = "laion/larger_clap_general"
model = ClapModel.from_pretrained(model_name)
processor = ClapProcessor.from_pretrained(model_name)  # used by get_audio_embedding

# Initialize the on-disk embedding cache.
os.makedirs(CACHE_FOLDER, exist_ok=True)
cache = Cache(CACHE_FOLDER)

# Creating a qdrant collection #####################################################################
# Both env vars are required — KeyError if missing.
client = QdrantClient(os.environ['QDRANT_URL'], api_key=os.environ['QDRANT_KEY'])
print("[INFO] Client created...")

print("[INFO] Creating qdrant data collection...")
# Idempotent: only create the collection on the first run.
if not client.collection_exists("demo_spaces_db"):
    client.create_collection(
        collection_name="demo_spaces_db",
        vectors_config=models.VectorParams(
            size=model.config.projection_dim,
            distance=models.Distance.COSINE
        ),
    )
82
+
83
# Embed the audio files and push them to Qdrant + S3.
audio_files = glob(os.path.join(KAGGLE_DB_PATH, '*/*.wav'))
chunk_size, idx = 1, 0

# Use tqdm for a progress bar
print("Uploading on DB + S3")
for i in tqdm(range(0, len(audio_files), chunk_size),
              desc="[INFO] Uploading data records to data collection..."):
    chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
    records = []
    for audio_file in chunk:
        embedding = get_audio_embedding(model, audio_file, cache)
        s3key = f'{S3_FOLDER}/{Path(audio_file).name}'
        # Bug fix: the file handle was opened and never closed (leaked one
        # descriptor per file); `with` closes it right after the upload.
        with open(audio_file, 'rb') as file_obj:
            upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key)
        records.append(
            models.PointStruct(
                # NOTE(review): embedding is a torch tensor; qdrant clients
                # generally expect a plain list of floats — verify whether
                # embedding.tolist() is needed here.
                id=idx, vector=embedding,
                payload={
                    "audio_path": audio_file,
                    "audio_s3url": f"https://{S3_BUCKET}.s3.amazonaws.com/{s3key}",
                    # NOTE(review): this stores the file *name*; given the
                    # '<style>/<file>.wav' layout, the genre is probably the
                    # parent directory (split('/')[-2]) — confirm before
                    # relying on the 'style' field.
                    "style": audio_file.split('/')[-1]}
            )
        )
        idx += 1
    client.upload_points(
        collection_name="demo_spaces_db",
        points=records
    )
print("[INFO] Successfully uploaded data records to data collection!")


# It's a good practice to close the cache when done.
cache.close()
s3_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from enum import Enum
3
+
4
+ import boto3
5
+ from botocore.client import BaseClient
6
+
7
+
8
+ # S3 HANDLING ######################################################################################
9
def get_md5(fpath):
    """Compute the MD5 checksum of *fpath* and return it as a hex string."""
    checksum = hashlib.md5()
    with open(fpath, "rb") as f:
        data = f.read(8192)
        while data:
            checksum.update(data)
            data = f.read(8192)
    return checksum.hexdigest()
15
+
16
+
17
def upload_file_to_bucket(s3_client, file_obj, bucket, s3key):
    """Upload a file-like object to an S3 bucket, publicly readable.

    :param s3_client: boto3 S3 client used to perform the upload
    :param file_obj: open binary file-like object to upload
    :param bucket: name of the destination bucket
    :param s3key: full object key (path) inside the bucket
    :return: the return value of ``upload_fileobj`` (``None`` on success;
        failures raise rather than returning False)
    """
    # Bug fix: ContentType must be a bare MIME type. The previous value
    # "Content-Type: audio/mpeg" embedded the header *name* inside the
    # header *value*, producing a malformed Content-Type on every object.
    return s3_client.upload_fileobj(
        file_obj, bucket, s3key,
        ExtraArgs={"ACL": "public-read", "ContentType": "audio/mpeg"}
    )
30
+
31
+
32
def s3_auth(aws_access_key_id, aws_secret_access_key, region_name) -> BaseClient:
    """Return a boto3 S3 client authenticated with the given credentials."""
    return boto3.client(
        service_name='s3',
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
40
+
41
+
42
def get_list_of_buckets(s3: BaseClient):
    """Return an ``Enum`` whose members are the bucket names visible to *s3*.

    :param s3: an authenticated boto3 S3 client
    :return: an ``Enum`` class named ``BucketName`` with one member per
        bucket; member name and value are both the bucket name
    """
    response = s3.list_buckets()
    # Bug fix: the original loop variable shadowed the accumulator dict and
    # read response['Name'] (which does not exist) instead of each bucket's
    # 'Name' — it raised KeyError on any non-empty account.
    names = {bucket['Name']: bucket['Name'] for bucket in response['Buckets']}
    BucketName = Enum('BucketName', names)
    return BucketName
51
+
52
+
53
if __name__ == '__main__':
    # Manual smoke test: authenticate against real S3 and list buckets.
    import os

    # Required env vars — KeyError if missing.
    AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
    AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
    S3_BUCKET = "synthia-research"
    S3_FOLDER = "huggingface_spaces_demo"
    AWS_REGION = "eu-west-3"

    s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
    print(s3.list_buckets())

    # Example key for a manual upload test; the upload call is left commented
    # out (it would need an open file_obj, which is not defined here).
    s3key = f'{S3_FOLDER}/015.WAV'
    #print(upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key))