# Residue from the Hugging Face file view, kept as comments so the module parses:
# author: berkaygkv54 · commit 24510fe ("llm integration") · 2.86 kB · malware scan: no virus
import streamlit as st
from streamlit import session_state as session
from src.laion_clap.inference import AudioEncoder
from src.utils.spotify import SpotifyHandler, SpotifyAuthentication
import pandas as pd
from dotenv import load_dotenv
from langchain.llms import CTransformers, Ollama
from src.llm.chain import LLMChain
from pymongo.mongo_client import MongoClient
import os
# Page configuration; Streamlit requires this to be the first `st.*` call in the script.
st.set_page_config(
    page_title="Curate me a playlist",
    layout="wide",
)
# Pull variables from a local .env file into os.environ (the MongoDB settings
# read with os.getenv further down come from here when not set externally).
load_dotenv()
def load_llm_pipeline():
    """Build the LLM chain that turns a playlist description into song queries.

    A local Mistral-7B-Instruct GGUF model is loaded through CTransformers with
    deterministic decoding (temperature 0, top_k 1) and wrapped in the
    project's ``LLMChain``.

    Returns:
        LLMChain: the chain wrapping the loaded model.
    """
    # NOTE(review): max_new_tokens (3000) is larger than context_length (2800);
    # generation is presumably bounded by the context window — confirm intent.
    generation_settings = dict(
        max_new_tokens=3000,
        temperature=0,
        top_k=1,
        top_p=1,
        context_length=2800,
    )
    model = CTransformers(
        model="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
        model_file="mistral-7b-instruct-v0.1.Q5_K_M.gguf",
        config=generation_settings,
    )
    # Alternative backend, kept for reference:
    # model = Ollama(temperature=0, model="mistral:7b-instruct-q8_0", top_k=1, top_p=1, num_ctx=2800)
    return LLMChain(model)
@st.cache_resource
def load_resources():
    """Connect to MongoDB and build the audio recommender and the LLM chain.

    Wrapped in ``st.cache_resource`` so the database client, the loaded audio
    vectors and the LLM are created once and reused across Streamlit reruns.

    Environment variables:
        MONGODB_PASSWORD: password of the MongoDB user (required).
        MONGODB_URL: host part of the ``mongodb+srv`` URI (required).
        MONGODB_USER: MongoDB user name (optional, defaults to ``berkaygkv``).

    Returns:
        tuple: ``(recommender, llm_pipeline)`` — an ``AudioEncoder`` backed by the
        ``spoti.saved_tracks`` collection with its stored vectors loaded, and the
        chain returned by ``load_llm_pipeline()``.

    Raises:
        RuntimeError: if a required environment variable is missing or empty.
    """
    from urllib.parse import quote_plus

    password = os.getenv("MONGODB_PASSWORD")
    url = os.getenv("MONGODB_URL")
    missing = [
        name
        for name, value in (("MONGODB_PASSWORD", password), ("MONGODB_URL", url))
        if not value
    ]
    if missing:
        # Without this check the URI would silently contain the text "None".
        raise RuntimeError(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )
    user = os.getenv("MONGODB_USER", "berkaygkv")

    # Credentials must be percent-escaped inside a MongoDB URI; characters such
    # as '@', ':' or '/' in the password would otherwise break URI parsing.
    uri = (
        f"mongodb+srv://{quote_plus(user)}:{quote_plus(password)}@{url}"
        "/?retryWrites=true&w=majority"
    )
    client = MongoClient(uri)
    mongo_db_collection = client.spoti.saved_tracks

    recommender = AudioEncoder(mongo_db_collection)
    recommender.load_existing_audio_vectors()
    llm_pipeline = load_llm_pipeline()
    return recommender, llm_pipeline
recommender, llm_pipeline = load_resources()

st.title("""Curate me a Playlist.""")
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="Track counts", min_value=5, max_value=35, step=5)

# Four columns; only col1 is used (for the button), the others act as spacing.
buffer1, col1, col2, buffer2 = st.columns([1.45, 1, 1, 1])
is_clicked = col1.button(label="Curate")

if is_clicked:
    if not session.text_input.strip():
        # Don't spend an LLM call on an empty description.
        st.warning("Please describe the playlist you want before clicking Curate.")
    else:
        track_count = session.slider_count
        # Ask each query for at least `track_count` candidates so a single-query
        # answer can still fill the largest playlist; 15 remains the floor, so
        # playlists of 15 tracks or fewer query exactly as before.
        per_query_k = max(15, track_count)

        song_list = []
        # The pipeline output is iterated as 2-tuples; only the second item
        # (a song description) is used as a query against the audio vectors.
        for _, song_desc in llm_pipeline.process_user_description(session.text_input):
            print(song_desc)
            song_list += recommender.list_top_k_songs(song_desc, k=per_query_k)

        if not song_list:
            # An empty DataFrame has no "score" column; sort_values would raise KeyError.
            st.info("No matching tracks were found. Try a different description.")
        else:
            dataframe = (
                pd.DataFrame(song_list)
                .sort_values("score", ascending=False)
                # Sorting first means the highest-scoring copy of each track is kept.
                .drop_duplicates(subset=["track_id"])
                .drop(columns=["track_id"])
                .reset_index(drop=True)
                .iloc[:track_count]
            )
            st.data_editor(
                dataframe,
                column_config={"link": st.column_config.LinkColumn("link")},
                hide_index=False,
                use_container_width=True,
            )

# TODO: Spotify export is disabled. `authenticate_spotify` and `session.access_url`
# are not defined anywhere in the visible script; wire them up before re-enabling.
# with st.form(key="spotiform"):
#     st.form_submit_button(on_click=authenticate_spotify, args=(session.access_url, ))
#     st.markdown(session.access_url)