berkaygkv54 committed
Commit 24510fe
1 Parent(s): a20c02a

llm integration

app.py CHANGED
@@ -1,79 +1,67 @@
 import streamlit as st
 from streamlit import session_state as session
-from src.config.configs import ProjectPaths
-import numpy as np
 from src.laion_clap.inference import AudioEncoder
-import pickle
-import torch
+from src.utils.spotify import SpotifyHandler, SpotifyAuthentication
 import pandas as pd
-import json
-import os
-import smtplib, ssl
 from dotenv import load_dotenv
-st.set_page_config(page_title="Curate me a playlist", layout="wide")
+from langchain.llms import CTransformers, Ollama
+from src.llm.chain import LLMChain
+from pymongo.mongo_client import MongoClient
+import os
 
+st.set_page_config(page_title="Curate me a playlist", layout="wide")
 load_dotenv()
 
-@st.cache_data
-def load_data():
-    vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))
-    with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl"), "rb") as reader:
-        song_names = pickle.load(reader)
-
-    with open(ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json"), "r") as reader:
-        youtube_data = json.load(reader)
-
-    df_youtube = pd.DataFrame(youtube_data)
-    df_youtube["id"] = df_youtube["artist_name"] + " - " + df_youtube["track_name"] + ".wav"
-    df_youtube.set_index("id", inplace=True)
-    return vectors, song_names, df_youtube
-
+def load_llm_pipeline():
+    ctransformers_config = {
+        "max_new_tokens": 3000,
+        "temperature": 0,
+        "top_k": 1,
+        "top_p": 1,
+        "context_length": 2800
+    }
+
+    llm = CTransformers(
+        model="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
+        model_file="mistral-7b-instruct-v0.1.Q5_K_M.gguf",
+        config=ctransformers_config
+    )
+    # llm = Ollama(temperature=0, model="mistral:7b-instruct-q8_0", top_k=1, top_p=1, num_ctx=2800)
+    chain = LLMChain(llm)
+    return chain
 
 @st.cache_resource
-def load_model():
-    recommender = AudioEncoder()
-    return recommender
-
-
-def send_curator(text):
-    port = int(os.getenv("PORT"))
-    print(port)
-    smtp_server = "smtp.gmail.com"
-    sender_email = os.getenv("EMAIL_ADDRESS")
-    receiver_email = os.getenv("EMAIL_RECEIVER")
-    password = os.getenv("EMAIL_PASSWORD")
-    from email.mime.multipart import MIMEMultipart
-    from email.mime.text import MIMEText
-
-    msg = MIMEMultipart("alternative")
-    msg["Subject"] = "Curate me a playlist submission"
-    part1 = MIMEText(body, "plain")
-    msg.attach(part1)
-    context = ssl.create_default_context()
-    with smtplib.SMTP_SSL(smtp_server, port, context=context) as server:
-        server.login(sender_email, password)
-        server.sendmail(sender_email, receiver_email, msg)
-
-    print("Email sent.")
-
-
-recommender = load_model()
-audio_vectors, song_names, df_youtube = load_data()
+def load_resources():
+    password = os.getenv("MONGODB_PASSWORD")
+    url = os.getenv("MONGODB_URL")
+    uri = f"mongodb+srv://berkaygkv:{password}@{url}/?retryWrites=true&w=majority"
+    client = MongoClient(uri)
+    db = client.spoti
+    mongo_db_collection = db.saved_tracks
+    recommender = AudioEncoder(mongo_db_collection)
+    recommender.load_existing_audio_vectors()
+    llm_pipeline = load_llm_pipeline()
+    return recommender, llm_pipeline
+
+
+recommender, llm_pipeline = load_resources()
 
 st.title("""Curate me a Playlist.""")
 session.text_input = st.text_input(label="Describe a playlist")
-session.slider_count = st.slider(label="Track counts", min_value=5, max_value=30, step=5)
-buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
+session.slider_count = st.slider(label="Track counts", min_value=5, max_value=35, step=5)
+buffer1, col1, col2, buffer2 = st.columns([1.45, 1, 1, 1])
 
 is_clicked = col1.button(label="Curate")
 if is_clicked:
-    text_embed = recommender.get_text_embedding(session.text_input)
-    with torch.no_grad():
-        ranking = torch.tensor(audio_vectors) @ torch.tensor(text_embed).t()
-    ranking = ranking[:, 0].reshape(-1, 1)
-    dataframe = pd.DataFrame(ranking, columns=[session.text_input], index=song_names).rename(columns={session.text_input: "score"})
-    dataframe = dataframe.merge(df_youtube[["link"]], left_index=True, right_index=True, how="left").nlargest(int(session.slider_count), "score")
-    # st.dataframe(dataframe, use_container_width=True)
+    output = llm_pipeline.process_user_description(session.text_input)
+    song_list = []
+    for _, song_desc in output:
+        print(song_desc)
+        ranking = recommender.list_top_k_songs(song_desc, k=15)
+        song_list += ranking
+
+    dataframe = pd.DataFrame(song_list).sort_values("score", ascending=False).drop_duplicates(subset=["track_id"]).drop(columns=["track_id"]).reset_index(drop=True)
+    dataframe = dataframe.iloc[:session.slider_count]
     st.data_editor(
        dataframe,
        column_config={
@@ -88,22 +76,7 @@ if is_clicked:
        use_container_width=True
    )
 
-    form = st.form("form")
-    form.write("You can submit the playlist you've curated")
-    sender = form.text_input("Name of the curator")
-    query = session.text_input
-    playlist = [f"{k}\n" for k in dataframe.index]
-    playlist_string = "\n".join(dataframe.index.tolist())
-    body = f"""\
-    Subject: Curate me a playlist submission
-
-    Curator --> {sender}
-    Query --> {session.text_input}
-
-    Playlist
-    {playlist_string}
-    """
-
-    print(body)
-    is_submit = form.form_submit_button("Submit", on_click=send_curator, args=([body]))
+
+    # with st.form(key="spotiform"):
+    #     st.form_submit_button(on_click=authenticate_spotify, args=(session.access_url, ))
+    #     st.markdown(session.access_url)
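
For orientation, a minimal sketch (not part of the commit) of the new flow in app.py with the Streamlit UI stripped away. It uses only functions introduced above; the query string is illustrative, and MONGODB_PASSWORD, MONGODB_URL and the CLAP checkpoint are assumed to be available.

# Sketch only: LLM-expanded playlist description -> CLAP ranking, in app.py's module scope.
recommender, llm_pipeline = load_resources()                        # CLAP encoder + MongoDB vectors + LLM chain

song_descriptions = llm_pipeline.process_user_description(
    "upbeat 80s synth-pop for a night drive"                        # illustrative user input
)
candidates = []
for _, song_desc in song_descriptions:                              # the four generated song descriptions
    candidates += recommender.list_top_k_songs(song_desc, k=15)     # [{"title", "score", "link", "track_id"}, ...]
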
src/laion_clap/inference.py CHANGED
@@ -1,41 +1,102 @@
-import numpy as np
+from tqdm import tqdm
 import librosa
 import torch
 from src import laion_clap
-from glob import glob
-import pandas as pd
+import json
+import jmespath
 from ..config.configs import ProjectPaths
-import pickle
 
 
 class AudioEncoder(laion_clap.CLAP_Module):
-    def __init__(self) -> None:
-        super().__init__(enable_fusion=False, amodel='HTSAT-base')
+    def __init__(self, collection=None) -> None:
+        super().__init__(enable_fusion=False, amodel="HTSAT-base")
+        self.music_data = None
         self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)
+        self.collection = collection
+
+    # def _get_track_data(self):
+    #     with open(ProjectPaths.DATA_DIR.joinpath("json", "final_track_data.json"), "r") as reader:
+    #         track_data = json.load(reader)
+    #     return track_data
+
+    def _get_track_data(self):
+        data = self.collection.find({})
+        return data
+
+
+    def update_collection_item(self, track_id, vector):
+        self.collection.update_one({"track_id": track_id}, {"$set": {"embedding": vector}})
+
 
     def extract_audio_representaion(self, file_name):
        audio_data, _ = librosa.load(file_name, sr=48000)
        audio_data = audio_data.reshape(1, -1)
+        audio_data = torch.from_numpy(audio_data)
        with torch.no_grad():
-            audio_embed = self.get_audio_embedding_from_data(x=audio_data, use_tensor=False)
+            audio_embed = self.get_audio_embedding_from_data(
+                x=audio_data, use_tensor=True
+            )
        return audio_embed
 
     def extract_bulk_audio_representaions(self, save=False):
-        music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav")))
-        song_names = [k.split("/")[-1] for k in music_files]
-        music_data = np.zeros((len(music_files), 512), dtype=np.float32)
-        for m in range(music_data.shape[0]):
-            music_data[m] = self.extract_audio_representaion(music_files[m])
-
-        if not save:
-            return music_data, song_names
-
-        else:
-            np.save(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))
-            with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl", "rb")) as writer:
-                pickle.dump(song_names, writer)
+        track_data = self._get_track_data()
+        processed_data = []
+        idx = 0
+        for track in tqdm(track_data):
+            if track["youtube_data"]["file_path"] and track["youtube_data"]["link"] not in processed_data:
+                tensor = self.extract_audio_representaion(track["youtube_data"]["file_path"])
+                self.update_collection_item(track["track_id"], tensor.tolist())
+                idx += 1
+
+
+    # def load_existing_audio_vectors(self):
+    #     self.music_data = torch.load(
+    #         ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.pt")
+    #     )
+    #     with open(
+    #         ProjectPaths.DATA_DIR.joinpath("vectors", "final_track_data_w_links.json"),
+    #         "r",
+    #     ) as rd:
+    #         self.track_data = json.load(rd)
+
+    def load_existing_audio_vectors(self):
+        # embedding_result = list(self.collection.find({}, {"embedding": 1}))
+        # tracking_result = list(self.collection.find({}, {"embedding": 0}))
+        arrays = []
+        track_data = []
+        for idx, track in enumerate(self.collection.find({})):
+            if not track.get("embedding"):
+                continue
+            data = track.copy()
+            data.pop("embedding")
+            data.update({"vector_idx": idx})
+            arrays.append(track["embedding"][0])
+            track_data.append(data)
+
+        self.music_data = torch.tensor(arrays)
+        self.track_data = track_data.copy()
+
 
     def extract_text_representation(self, text):
        text_data = [text]
        text_embed = self.get_text_embedding(text_data)
        return text_embed
+
+    def list_top_k_songs(self, text, k=10):
+        assert self.music_data is not None
+        with torch.no_grad():
+            text_embed = self.get_text_embedding(text, use_tensor=True)
+
+        dot_product = self.music_data @ text_embed.T
+        top_10 = torch.topk(dot_product.flatten(), k)
+        indices = top_10.indices.tolist()
+        final_result = []
+        for k, i in enumerate(indices):
+            piece = {
+                "title": self.track_data[i]["youtube_data"]["title"],
+                "score": round(top_10.values[k].item(), 2),
+                "link": self.track_data[i]["youtube_data"]["link"],
+                "track_id": self.track_data[i]["track_id"],
+            }
+            final_result.append(piece)
+        return final_result
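
A small self-contained sketch (not part of the commit) of the ranking that list_top_k_songs performs, on toy tensors, assuming the embeddings are unit-normalised so the dot product behaves like cosine similarity:

import torch

music_data = torch.nn.functional.normalize(torch.randn(100, 512), dim=-1)   # stand-in for stored audio embeddings
text_embed = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)     # stand-in for a CLAP text embedding

scores = music_data @ text_embed.T                 # (100, 1) similarity scores
top = torch.topk(scores.flatten(), k=10)           # indices and scores of the 10 best matches
print(top.indices.tolist(), top.values.tolist())
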
src/llm/chain.py ADDED
@@ -0,0 +1,51 @@
+from langchain.prompts import ChatPromptTemplate, PromptTemplate
+from langchain.schema.runnable import RunnableLambda
+from operator import itemgetter
+from langchain.output_parsers import PydanticOutputParser
+from .output_parser import SongDescriptions
+from langchain.llms.base import LLM
+import json
+
+
+class LLMChain:
+    def __init__(self, llm_model: LLM) -> None:
+        self.llm_model = llm_model
+        self.parser = PydanticOutputParser(pydantic_object=SongDescriptions)
+        self.full_chain = self._create_llm_chain()
+
+
+    def _get_output_format(self, _):
+        return self.parser.get_format_instructions()
+
+    def _create_llm_chain(self):
+        prompt_response = ChatPromptTemplate.from_messages([
+            ("system", "You are an AI assistant, helping the user to turn a music playlist text description into four separate song descriptions that are probably contained in the playlist. Try to be specific with descriptions. Make sure all 4 song descriptions are similar.\n"),
+            ("system", "{format_instructions}\n"),
+            ("human", "Playlist description: {description}.\n"),
+            # ("human", "Song descriptions:"),
+        ])
+        # prompt = PromptTemplate(
+        #     template="You are an AI assistant, helping the user to turn a music playlist text description into three separate generic song descriptions that are probably contained in the playlist.\n{format_instructions}\n{description}\n",
+        #     input_variables=["description"],
+        #     partial_variables={"format_instructions": self.parser.get_format_instructions()},
+        # )
+
+
+        full_chain = (
+            {
+                "format_instructions": RunnableLambda(self._get_output_format),
+                "description": itemgetter("description"),
+            }
+            | prompt_response
+            | self.llm_model
+        )
+        return full_chain
+
+    def process_user_description(self, user_input):
+        output = self.full_chain.invoke(
+            {
+                "description": user_input
+            }
+        ).replace("\\", '')
+        return self.parser.parse(output)
+
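
The chain fills the {format_instructions} slot of the system prompt through RunnableLambda(self._get_output_format). A minimal sketch (not part of the commit) of what that slot receives, using the classes added in this commit:

from langchain.output_parsers import PydanticOutputParser
from src.llm.output_parser import SongDescriptions

parser = PydanticOutputParser(pydantic_object=SongDescriptions)
# Prints the JSON-schema formatting instructions derived from SongDescriptions,
# i.e. the text the Mistral model sees in its system prompt.
print(parser.get_format_instructions())
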
src/llm/output_parser.py ADDED
@@ -0,0 +1,8 @@
+from pydantic import BaseModel, Field
+
+
+class SongDescriptions(BaseModel):
+    song_description_1: str = Field(description="description of the first song")
+    song_description_2: str = Field(description="description of the second song")
+    song_description_3: str = Field(description="description of the third song")
+    song_description_4: str = Field(description="description of the fourth song")
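
For reference, a sketch (not part of the commit) of a conforming completion and how it is consumed; the JSON values are illustrative, and the (field, value) iteration is what the `for _, song_desc in output` loop in app.py relies on:

from langchain.output_parsers import PydanticOutputParser
from src.llm.output_parser import SongDescriptions

parser = PydanticOutputParser(pydantic_object=SongDescriptions)
raw_completion = (
    '{"song_description_1": "slow acoustic ballad with soft vocals", '
    '"song_description_2": "mellow piano piece, melancholic mood", '
    '"song_description_3": "dreamy indie folk with warm strings", '
    '"song_description_4": "quiet fingerpicked guitar, rainy-day feel"}'
)
songs = parser.parse(raw_completion)         # -> SongDescriptions instance
for field_name, description in songs:        # pydantic models iterate as (field, value) pairs
    print(field_name, "->", description)
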
src/utils/__init__.py ADDED
File without changes