# Hugging Face Spaces page banner captured together with this file: "Spaces: Sleeping".
import os
from urllib.parse import quote_plus

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from langchain.llms import CTransformers, Ollama
from pymongo.mongo_client import MongoClient
from streamlit import session_state as session

from src.laion_clap.inference import AudioEncoder
from src.llm.chain import LLMChain
# from src.utils.spotify import SpotifyHandler, SpotifyAuthentication
# Populate os.environ from a local .env file; the MONGODB_* and LLM_VERSION
# values are read later through os.getenv.
load_dotenv()
# Page-wide layout settings for the Streamlit app.
st.set_page_config(layout="wide", page_title="Curate me a playlist")
@st.cache_resource
def load_llm_pipeline():
    """Load the quantized Mistral-7B-Instruct GGUF model and wrap it in an LLMChain.

    Decorated with ``st.cache_resource``: Streamlit re-executes the whole script
    on every widget interaction, and without caching the model would be reloaded
    from disk on each rerun.

    The GGUF file name is read from the ``LLM_VERSION`` environment variable and
    falls back to ``mistral-7b-instruct-v0.1.Q5_K_M.gguf`` when it is unset.

    Returns:
        LLMChain: the project chain wrapping the CTransformers model.
    """
    ctransformers_config = {
        "max_new_tokens": 3000,
        # temperature 0 with top_k=1 keeps only the most likely token (greedy).
        "temperature": 0,
        "top_k": 1,
        "top_p": 1,
        # NOTE(review): max_new_tokens (3000) is larger than context_length
        # (2800), so the token budget can never actually be reached — confirm
        # which limit is intended.
        "context_length": 2800,
    }
    llm = CTransformers(
        model="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
        model_file=os.getenv("LLM_VERSION", "mistral-7b-instruct-v0.1.Q5_K_M.gguf"),
        config=ctransformers_config,
    )
    # Alternative local backend previously used instead of CTransformers:
    # llm = Ollama(temperature=0, model="mistral:7b-instruct-q8_0", top_k=1, top_p=1, num_ctx=2800)
    return LLMChain(llm)
@st.cache_resource
def load_resources():
    """Connect to MongoDB and build the song recommender and the LLM chain.

    Decorated with ``st.cache_resource``: Streamlit re-executes the whole script
    on every widget interaction, and without caching each rerun would open a new
    MongoClient, reload every stored audio vector and reload the LLM.

    Environment:
        MONGODB_PASSWORD: required.
        MONGODB_URL: required; the host part placed after ``@`` in the
            ``mongodb+srv://`` URI.
        MONGODB_USER: optional, defaults to ``"berkaygkv"``.

    Returns:
        tuple: ``(recommender, llm_pipeline)`` — an ``AudioEncoder`` built on the
        ``spoti.saved_tracks`` collection after ``load_existing_audio_vectors()``
        has run, and the chain returned by ``load_llm_pipeline()``.

    Raises:
        RuntimeError: if MONGODB_PASSWORD or MONGODB_URL is missing or empty.
    """
    password = os.getenv("MONGODB_PASSWORD")
    url = os.getenv("MONGODB_URL")
    if not password or not url:
        # Fail fast instead of building a URI containing the literal "None".
        raise RuntimeError(
            "MONGODB_PASSWORD and MONGODB_URL must be set (e.g. in a .env file)."
        )
    user = os.getenv("MONGODB_USER", "berkaygkv")
    # Credentials must be percent-escaped; a raw '@', ':' or '/' in the password
    # would otherwise corrupt the connection string.
    uri = (
        f"mongodb+srv://{quote_plus(user)}:{quote_plus(password)}@{url}"
        "/?retryWrites=true&w=majority"
    )
    client = MongoClient(uri)
    mongo_db_collection = client.spoti.saved_tracks

    recommender = AudioEncoder(mongo_db_collection)
    recommender.load_existing_audio_vectors()
    llm_pipeline = load_llm_pipeline()
    return recommender, llm_pipeline
def output_songs(text, *, pipeline=None, song_recommender=None, top_k=15):
    """Rank stored songs against a free-text playlist description.

    The description is passed to the LLM chain's ``process_user_description``.
    Each item it returns is unpacked as a pair whose second element is used as a
    song description (presumably ``(label, description)`` — confirm against
    ``LLMChain``). The top ``top_k`` matches for every description are pooled,
    sorted by descending ``score``, and de-duplicated on ``track_id``, keeping
    the highest-scoring row per track.

    Args:
        text: the user's playlist description.
        pipeline: object exposing ``process_user_description(text)``; defaults
            to the module-level ``llm_pipeline``.
        song_recommender: object exposing ``list_top_k_songs(desc, k=...)`` that
            returns a list of records with at least ``score`` and ``track_id``;
            defaults to the module-level ``recommender``.
        top_k: number of matches requested per song description.

    Returns:
        pandas.DataFrame with a fresh 0..n-1 index, or ``None`` when the LLM
        produced nothing or no song matched any description.
    """
    pipeline = llm_pipeline if pipeline is None else pipeline
    song_recommender = recommender if song_recommender is None else song_recommender

    output = pipeline.process_user_description(text)
    if not output:
        return None

    song_list = []
    for _, song_desc in output:
        print(song_desc)
        song_list.extend(song_recommender.list_top_k_songs(song_desc, k=top_k))

    if not song_list:
        # An empty frame has no "score" column, so sort_values would raise KeyError.
        return None

    return (
        pd.DataFrame(song_list)
        # Stable sort keeps equal-score rows in a deterministic order.
        .sort_values("score", ascending=False, kind="stable")
        .drop_duplicates(subset=["track_id"])
        .reset_index(drop=True)
    )
st.title("""Curate me a Playlist.""")
st.info("""
Hey there, introducing the Music Playlist Curator AI! It's designed to craft playlists based on your descriptions.
Here's the breakdown: we've got a Mistral 7B-Instruct 5-bit quantized version LLM running on the CPU to handle user inputs,
and a Contrastive Learning model from the Amazing [LAION AI](https://github.com/LAION-AI/CLAP) team for Audio-Text joint embeddings, scoring song similarity.
The songs are pulled from my personal Spotify Liked Songs through API. Using an automated data extraction pipeline,
I queried each song on my list on YouTube, downloaded it,
extracted audio features, and stored them on MongoDB.
TODOs:
- [ ] Making playlists on users' own Spotify Tracks,
- [ ] Display leaderboard to show the best playlist curated,
- [ ] Generate the playlist on Spotify directly
""")
st.success("The pipeline runs on CPU, so processing might take a few minutes.")
st.warning("""
A caveat: because the audio data is retrieved from YouTube,
there's a chance some songs might not be top-notch quality or could be live versions, impacting the audio features' quality.
Another caveat: I've given it a spin with some Turkish descriptions, had some wins and some misses. I might wanna upgrade to a GPU powered environment
to enhance LLM capacity in the future.
Give it a shot and see how it goes! 🎶
""")

# Load the recommender and LLM only after the intro has been rendered, so the
# page is not blank while the models load. output_songs() reads these
# module-level names.
recommender, llm_pipeline = load_resources()
# st.success(""" | |
# """) | |
# Inputs are mirrored into session_state under "text_input" / "slider_count".
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="How many tracks", min_value=5, max_value=35, step=5)
# The outer columns are empty spacers that position the button.
buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
is_clicked = col1.button(label="Curate")
if is_clicked:
    if not session.text_input.strip():
        # Skip the slow LLM call when there is nothing to describe.
        st.warning("Please describe the playlist you want first.")
    else:
        with st.spinner("Curating your playlist..."):
            dataframe = output_songs(session.text_input)
        if isinstance(dataframe, pd.DataFrame):
            # Chained, non-inplace operations build a new frame instead of
            # mutating a slice (which triggers SettingWithCopyWarning).
            # De-duplicating before truncating keeps up to `slider_count`
            # distinct tracks; the id/score columns are internal only.
            dataframe = (
                dataframe.drop_duplicates(subset=["track_id"])
                .iloc[: session.slider_count]
                .drop(columns=["track_id", "score"])
            )
            st.data_editor(
                dataframe,
                column_config={
                    "link": st.column_config.LinkColumn(
                        "link",
                    )
                },
                hide_index=False,
                use_container_width=True,
            )
        else:
            st.warning("User prompt could not be processed.")
# with st.form(key="spotiform"): | |
# st.form_submit_button(on_click=authenticate_spotify, args=(session.access_url, )) | |
# st.markdown(session.access_url) | |