berkaygkv54's picture
latest
a20c02a
raw
history blame
3.73 kB
import streamlit as st
from streamlit import session_state as session
from src.config.configs import ProjectPaths
import numpy as np
from src.laion_clap.inference import AudioEncoder
import pickle
import torch
import pandas as pd
import json
import os
import smtplib, ssl
from dotenv import load_dotenv
# Populate os.environ from a local .env file (send_curator reads its email
# settings from there), then configure the page title and wide layout.
load_dotenv()
st.set_page_config(page_title="Curate me a playlist", layout="wide")
@st.cache_data
def load_data():
    """Load the precomputed audio embeddings, track names and YouTube metadata.

    Returns:
        tuple: ``(vectors, song_names, df_youtube)`` where

        * ``vectors`` is the array stored in ``vectors/audio_representations.npy``,
        * ``song_names`` is the object unpickled from ``vectors/song_names.pkl``,
        * ``df_youtube`` is a DataFrame built from ``json/youtube_data.json``,
          indexed by ``"<artist_name> - <track_name>.wav"``.

    ``st.cache_data`` memoises the result, so the files are read once rather
    than on every script rerun.
    """
    vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))

    # NOTE(review): pickle can execute arbitrary code while loading; this is
    # only acceptable because the file ships with the app. Never point it at
    # user-supplied data.
    with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl"), "rb") as reader:
        song_names = pickle.load(reader)

    # JSON is UTF-8 by specification; pin the encoding so non-ASCII artist and
    # track names decode identically regardless of the platform's locale.
    with open(
        ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json"), "r", encoding="utf-8"
    ) as reader:
        youtube_data = json.load(reader)

    df_youtube = pd.DataFrame(youtube_data)
    # The ranking table is joined to this frame on its index, so the id is
    # built in the "<artist> - <track>.wav" form presumably used by song_names
    # — verify against the data-preparation step.
    df_youtube["id"] = df_youtube["artist_name"] + " - " + df_youtube["track_name"] + ".wav"
    df_youtube = df_youtube.set_index("id")
    return vectors, song_names, df_youtube
@st.cache_resource
def load_model():
    """Construct the ``AudioEncoder`` used for text/audio embeddings.

    ``st.cache_resource`` keeps the single instance alive across reruns and
    sessions, so the model is only built once.
    """
    return AudioEncoder()
def send_curator(text):
    """Email a playlist submission through Gmail's SMTP-over-SSL endpoint.

    Settings are read from environment variables:

    * ``PORT`` -- SMTP port; falls back to 465 (SMTP over implicit TLS) when unset.
    * ``EMAIL_ADDRESS`` -- sender address, also used as the login user name.
    * ``EMAIL_PASSWORD`` -- password for ``EMAIL_ADDRESS``.
    * ``EMAIL_RECEIVER`` -- recipient address.

    Args:
        text: Plain-text body of the email.

    Raises:
        RuntimeError: if any of the address/password variables is missing.
        smtplib.SMTPException, OSError: on connection, login or send failure.
    """
    from email.mime.text import MIMEText

    # NOTE(review): "PORT" is a generic variable name that hosting platforms
    # often set for the web server itself; a dedicated name such as SMTP_PORT
    # would avoid an accidental collision.
    port = int(os.getenv("PORT", "465"))
    smtp_server = "smtp.gmail.com"
    sender_email = os.getenv("EMAIL_ADDRESS")
    receiver_email = os.getenv("EMAIL_RECEIVER")
    password = os.getenv("EMAIL_PASSWORD")

    # Fail with a clear message instead of an opaque SMTP/login error.
    missing = [
        name
        for name, value in (
            ("EMAIL_ADDRESS", sender_email),
            ("EMAIL_RECEIVER", receiver_email),
            ("EMAIL_PASSWORD", password),
        )
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Cannot send playlist submission; missing environment variable(s): "
            + ", ".join(missing)
        )

    msg = MIMEText(text, "plain", "utf-8")
    msg["Subject"] = "Curate me a playlist submission"
    msg["From"] = sender_email
    msg["To"] = receiver_email

    context = ssl.create_default_context()
    with smtplib.SMTP_SSL(smtp_server, port, context=context) as server:
        server.login(sender_email, password)
        # send_message serialises the Message object; sendmail() only accepts
        # str/bytes and raises TypeError when handed the object itself.
        server.send_message(msg, sender_email, [receiver_email])
    print("Email sent.")
recommender = load_model()
audio_vectors, song_names, df_youtube = load_data()


def _rank_tracks(query, count):
    """Return the ``count`` best-matching tracks for the text ``query``.

    Each track's score is the product of its row in ``audio_vectors`` with the
    text embedding of ``query``. The result is indexed by the entries of
    ``song_names``, has a ``score`` column plus the ``link`` column from
    ``df_youtube`` (NaN when a track has no metadata row), and is sorted by
    descending score.
    """
    text_embed = recommender.get_text_embedding(query)
    with torch.no_grad():
        # assumes audio_vectors is (n_tracks, dim) and text_embed is (1, dim)
        # — TODO confirm; only the first column of the product is kept.
        ranking = torch.tensor(audio_vectors) @ torch.tensor(text_embed).t()
        scores = ranking[:, 0].reshape(-1, 1).numpy()
    dataframe = pd.DataFrame(scores, columns=["score"], index=song_names)
    dataframe = dataframe.merge(
        df_youtube[["link"]], left_index=True, right_index=True, how="left"
    )
    return dataframe.nlargest(count, "score")


def _submit_playlist(query, playlist_string):
    """Form-submit callback: email the curated playlist with the curator's name.

    The curator name is read from session state here, at submit time. Widgets
    inside an ``st.form`` only report their value when the form is submitted,
    so a value captured when the form was rendered would be stale (usually
    empty).
    """
    curator = session.get("curator_name", "")
    body = "\n".join(
        [
            f"Curator --> {curator}",
            f"Query --> {query}",
            "Playlist",
            playlist_string,
        ]
    ) + "\n"
    send_curator(body)


st.title("""Curate me a Playlist.""")
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="Track counts", min_value=5, max_value=30, step=5)
# Only the middle column of the [1.45, 1, 1] split is used (for the button);
# the outer columns are spacing.
_, col1, _ = st.columns([1.45, 1, 1])
is_clicked = col1.button(label="Curate")

if is_clicked and not session.text_input.strip():
    st.warning("Please describe the playlist you want before curating.")
elif is_clicked:
    dataframe = _rank_tracks(session.text_input, int(session.slider_count))
    st.data_editor(
        dataframe,
        column_config={"link": st.column_config.LinkColumn("link")},
        hide_index=False,
        use_container_width=True,
    )

    form = st.form("form")
    form.write("You can submit the playlist you've curated")
    # Keyed so _submit_playlist can read the submitted value from session state.
    form.text_input("Name of the curator", key="curator_name")
    # The query and playlist are what is on screen now, so they are bound at
    # render time; only the curator name must be read when the form is submitted.
    form.form_submit_button(
        "Submit",
        on_click=_submit_playlist,
        args=(session.text_input, "\n".join(dataframe.index.tolist())),
    )