berkaygkv54's picture
drop columns and duplicates
0fbe80a
raw
history blame
4.72 kB
import streamlit as st
from streamlit import session_state as session
from src.laion_clap.inference import AudioEncoder
# from src.utils.spotify import SpotifyHandler, SpotifyAuthentication
import pandas as pd
from dotenv import load_dotenv
from langchain.llms import CTransformers, Ollama
from src.llm.chain import LLMChain
from pymongo.mongo_client import MongoClient
import os
# Read secrets (MongoDB credentials, LLM weights file name) from a local .env
# file into the process environment; this is not a Streamlit call, so
# set_page_config below is still the first Streamlit command.
load_dotenv()
# Wide layout plus the browser-tab title for the app.
st.set_page_config(layout="wide", page_title="Curate me a playlist")
def load_llm_pipeline():
    """Load the Mistral 7B-Instruct GGUF model and wrap it in the project's LLMChain.

    The weights come from the ``TheBloke/Mistral-7B-Instruct-v0.1-GGUF``
    repository via ``CTransformers``. The concrete weights file is read from
    the ``LLM_VERSION`` environment variable; when it is unset or empty, the
    5-bit ``Q5_K_M`` file is used instead of passing ``model_file=None``.

    Returns:
        LLMChain: the chain whose ``process_user_description`` is called by
        ``output_songs``.
    """
    ctransformers_config = {
        # NOTE(review): max_new_tokens (3000) exceeds context_length (2800);
        # generation is presumably capped by the context window — confirm.
        "max_new_tokens": 3000,
        # temperature 0 with top_k=1 means effectively greedy decoding, so the
        # same description should produce the same output.
        "temperature": 0,
        "top_k": 1,
        "top_p": 1,
        "context_length": 2800,
    }
    # Fall back to the Q5_K_M weights the app description advertises.
    model_file = os.getenv("LLM_VERSION") or "mistral-7b-instruct-v0.1.Q5_K_M.gguf"
    llm = CTransformers(
        model="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
        model_file=model_file,
        config=ctransformers_config,
    )
    return LLMChain(llm)
@st.cache_resource
def load_resources():
    """Connect to MongoDB and build the recommender and LLM pipeline.

    Wrapped in ``st.cache_resource`` so reruns of the script reuse the same
    objects instead of reconnecting and reloading the models.

    Environment variables:
        MONGODB_PASSWORD: password of the ``berkaygkv`` database user.
        MONGODB_URL: host part of the ``mongodb+srv`` connection URI.

    Returns:
        tuple: ``(recommender, llm_pipeline)`` — an ``AudioEncoder`` bound to
        the ``spoti.saved_tracks`` collection with its stored audio vectors
        loaded, and the chain returned by ``load_llm_pipeline``.

    Raises:
        RuntimeError: if MONGODB_PASSWORD or MONGODB_URL is missing or empty.
    """
    from urllib.parse import quote_plus

    password = os.getenv("MONGODB_PASSWORD")
    url = os.getenv("MONGODB_URL")
    if not password or not url:
        # Previously this silently built a URI containing "None".
        raise RuntimeError("MONGODB_PASSWORD and MONGODB_URL must be set (e.g. in .env).")
    # Credentials must be percent-escaped: '@', ':', '/' etc. in the password
    # would otherwise break URI parsing.
    uri = (
        f"mongodb+srv://berkaygkv:{quote_plus(password)}@{url}"
        "/?retryWrites=true&w=majority"
    )
    client = MongoClient(uri)
    mongo_db_collection = client.spoti.saved_tracks
    recommender = AudioEncoder(mongo_db_collection)
    recommender.load_existing_audio_vectors()
    llm_pipeline = load_llm_pipeline()
    return recommender, llm_pipeline
@st.cache_data
def output_songs(text):
    """Turn a free-text playlist description into a ranked table of tracks.

    The LLM pipeline turns ``text`` into song descriptions. Each description
    is scored against the stored audio vectors and its top 15 tracks are
    collected. A track retrieved for several descriptions is kept once, at its
    highest score.

    Cached with ``st.cache_data`` rather than ``st.cache_resource``: the result
    is per-prompt data, and ``cache_data`` hands each caller its own copy, so
    later modifications cannot corrupt the cached DataFrame.

    Args:
        text: the user's playlist description.

    Returns:
        A DataFrame sorted by descending ``score`` with unique ``track_id``
        values and a fresh 0..n-1 index. Returns ``None`` when the pipeline
        produced nothing or no track was retrieved.
    """
    output = llm_pipeline.process_user_description(text)
    if not output:
        return None

    song_list = []
    for _, song_desc in output:
        print(song_desc)
        song_list.extend(recommender.list_top_k_songs(song_desc, k=15))

    if not song_list:
        # An empty DataFrame has no "score" column; sort_values would raise KeyError.
        return None

    # Sorting first makes drop_duplicates keep each track's best-scoring row.
    return (
        pd.DataFrame(song_list)
        .sort_values("score", ascending=False)
        .drop_duplicates(subset=["track_id"])
        .reset_index(drop=True)
    )
# load_resources is wrapped in st.cache_resource, so reruns reuse these objects;
# output_songs reads both names as module globals.
recommender, llm_pipeline = load_resources()

st.title("""Curate me a Playlist.""")
st.info("""
Hey there, introducing the Music Playlist Curator AI! It's designed to craft playlists based on your descriptions.
Here's the breakdown: we've got a 5-bit quantized version of the Mistral 7B-Instruct LLM running on the CPU to handle user inputs,
and a Contrastive Learning model from the amazing [LAION AI](https://github.com/LAION-AI/CLAP) team for Audio-Text joint embeddings, scoring song similarity.
The songs are pulled from my personal Spotify Liked Songs through the API. Using an automated data extraction pipeline,
I queried each song on my list on YouTube, downloaded it,
extracted audio features, and stored them on MongoDB.
TODOs:
- [ ] Make playlists from users' own Spotify tracks,
- [ ] Display a leaderboard of the best curated playlists,
- [ ] Generate the playlist on Spotify directly
""")
st.success("The pipeline runs on a CPU, so processing might take a few minutes.")
st.warning("""
A caveat: because the audio data is retrieved from YouTube,
there's a chance some songs might not be top-notch quality or could be live versions, impacting the audio features' quality.
Another caveat: I've given it a spin with some Turkish descriptions, had some wins and some misses. I might wanna upgrade to a GPU-powered environment
to enhance LLM capacity in the future.
Give it a shot and see how it goes! 🎶
""")
# st.success("""
# """)
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="How many tracks", min_value=5, max_value=35, step=5)
# The outer columns are only padding so the button sits near the centre.
_, col1, _ = st.columns([1.45, 1, 1])
is_clicked = col1.button(label="Curate")
if is_clicked:
    if not session.text_input.strip():
        # Avoid a slow LLM round-trip on an empty prompt.
        st.warning("Please describe the playlist you want first.")
    else:
        with st.spinner("Curating your playlist..."):
            dataframe = output_songs(session.text_input)
        if isinstance(dataframe, pd.DataFrame):
            # Deduplicate before truncating so up to slider_count *unique*
            # tracks are shown. The non-mutating chain avoids in-place ops on
            # an iloc slice (pandas SettingWithCopy pitfalls).
            dataframe = (
                dataframe.drop_duplicates(subset=["track_id"])
                .head(session.slider_count)
                .drop(columns=["track_id", "score"])
            )
            st.data_editor(
                dataframe,
                column_config={
                    "link": st.column_config.LinkColumn(
                        "link",
                    )
                },
                hide_index=False,
                use_container_width=True,
            )
        else:
            st.warning("User prompt could not be processed.")
# with st.form(key="spotiform"):
# st.form_submit_button(on_click=authenticate_spotify, args=(session.access_url, ))
# st.markdown(session.access_url)