# nickkornienko's picture
# added feedback submit confirmation
# 9d7da71
import gradio as gr
import torch
import torch.nn as nn
from joblib import load
import numpy as np
import pandas as pd
# Define the neural network model
class ImprovedSongRecommender(nn.Module):
    """Feed-forward network that scores every song title for one encoded query.

    Three hidden stages (Linear -> BatchNorm1d -> ReLU -> Dropout) of widths
    128, 256 and 128 feed a final linear layer that emits one raw score per
    title; no softmax is applied here.
    """

    def __init__(self, input_size, num_titles):
        """Create the layers.

        Args:
            input_size: Number of features in each input row.
            num_titles: Number of scores produced per row (one per title).
        """
        super().__init__()
        # The attribute names below become the state_dict keys, so they have
        # to stay as they are for previously saved weights to load.
        self.fc1 = nn.Linear(input_size, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.output = nn.Linear(128, num_titles)
        # A single dropout module is shared by all three hidden stages.
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return a ``(batch, num_titles)`` tensor of raw scores for ``x``."""
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        hidden = x
        for linear, norm in stages:
            hidden = self.dropout(torch.relu(norm(linear(hidden))))
        return self.output(hidden)
# Load the trained model weights onto the CPU and switch to inference mode.
# NOTE(review): depending on the installed PyTorch version's defaults,
# torch.load may unpickle arbitrary objects -- only load checkpoints from a
# trusted source.
model_path = "models/improved_model.pth"
# Size of the model's output layer; presumably the number of distinct titles
# known to the 'title' label encoder -- TODO confirm they stay in sync.
num_unique_titles = 4855
# input_size=2 matches the two-element [tags, artist] vector that
# encode_input() returns.
model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# eval() disables dropout and makes batch norm use its stored running
# statistics, so single-row inputs can be scored.
model.eval()
# Load the label encoders: an object indexed elsewhere in this file by the
# keys 'tags', 'artist_name' and 'title'.
# NOTE(review): joblib.load unpickles the file as well -- trusted input only.
label_encoders_path = "data/new_label_encoders.joblib"
label_encoders = load(label_encoders_path)
def _encode_or_unknown(encoder, value):
    """Return ``encoder``'s code for ``value``, or its code for 'unknown'.

    The encoder's ``transform`` is expected to raise ``ValueError`` for a
    label it has never seen; that case falls back to the 'unknown' label.
    """
    try:
        return encoder.transform([value])[0]
    except ValueError:
        return encoder.transform(['unknown'])[0]


def encode_input(tags, artist_name, encoders=None):
    """Turn a comma-separated tag string and an artist name into model input.

    Args:
        tags: Comma-separated tag names, e.g. ``"rock, pop"``. Whitespace
            around each tag is ignored, and so are empty entries produced by
            trailing or doubled commas. ``None`` is treated like ``""``.
        artist_name: Artist label to encode.
        encoders: Optional mapping with ``'tags'`` and ``'artist_name'``
            entries, each offering ``transform``. Defaults to the
            module-level ``label_encoders``.

    Returns:
        ``[encoded_tags, encoded_artist]``. ``encoded_tags`` is the mean of
        the individual tag codes truncated to an integer, or the code of
        'unknown' when no tag was given. Labels the encoder does not know are
        encoded as 'unknown'.

    Raises:
        ValueError: If the encoder does not know the 'unknown' label either.
    """
    if encoders is None:
        encoders = label_encoders
    tag_encoder = encoders['tags']

    # Drop empty tokens: previously "rock, pop," produced an extra '' tag that
    # was encoded as 'unknown' and skewed the averaged code.
    tag_names = [piece.strip() for piece in (tags or "").split(',')]
    tag_codes = [
        _encode_or_unknown(tag_encoder, name) for name in tag_names if name
    ]

    if tag_codes:
        # NOTE(review): averaging label ids is kept as-is because the model
        # presumably saw the same preprocessing during training -- confirm.
        encoded_tags = np.mean(tag_codes).astype(int)
    else:
        encoded_tags = tag_encoder.transform(['unknown'])[0]

    encoded_artist = _encode_or_unknown(encoders['artist_name'], artist_name)
    return [encoded_tags, encoded_artist]
def recommend_songs(tags, artist_name, top_k=5):
    """Return the titles the model scores highest for the given query.

    Args:
        tags: Comma-separated tag string, passed to ``encode_input``.
        artist_name: Artist label, passed to ``encode_input``.
        top_k: How many titles to return (default 5, the previous fixed
            value). Values larger than the number of model outputs are
            clamped to that number.

    Returns:
        A list of decoded titles, best score first.
    """
    encoded_input = encode_input(tags, artist_name)
    # One row of features; the model expects a float batch.
    input_tensor = torch.tensor([encoded_input]).float()
    with torch.no_grad():
        output = model(input_tensor)

    k = min(top_k, output.shape[-1])
    # Index the single batch row instead of squeeze(): squeeze() would collapse
    # the result to a scalar when k == 1 and tolist() would return a bare int.
    best_indices = torch.topk(output, k).indices[0].tolist()

    # Decode all indices in one call rather than one call per index.
    titles = label_encoders['title'].inverse_transform(best_indices)
    return list(titles)
def record_feedback(tags, recommendations, feedbacks,
                    feedback_path="feedback_data/feedback_data.csv"):
    """Append user ratings for recommended songs to the feedback CSV.

    Args:
        tags: The tag string the recommendations were generated for; stored
            unchanged in every new row.
        recommendations: Recommended song titles.
        feedbacks: Ratings aligned with ``recommendations``; ``None``/NaN
            means "not rated". If the two sequences differ in length, only the
            pairs that exist in both are considered.
        feedback_path: CSV file to update. Its parent directory is created if
            it does not exist.

    Returns:
        The string ``"Feedback recorded!"``.
    """
    import os  # local import so this function does not rely on file-level imports

    columns = ["Tags", "Recommendation", "Feedback"]

    # Load existing feedback if it exists; a missing or completely empty file
    # both mean "no feedback yet".
    try:
        feedback_df = pd.read_csv(feedback_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        feedback_df = pd.DataFrame(columns=columns)

    # Only keep rows where both a song recommendation and a rating are present.
    # zip() pairs the inputs safely; building a DataFrame directly from lists
    # of different lengths raises ValueError (e.g. feedback submitted before
    # any recommendations were produced).
    rows = [
        (tags, song, rating)
        for song, rating in zip(recommendations, feedbacks)
        if song != "No recommendations found" and not pd.isna(rating)
    ]
    new_feedbacks = pd.DataFrame(rows, columns=columns)

    # Skip concatenation when one side is empty: it is a no-op and pandas
    # warns about concatenating empty frames.
    if feedback_df.empty:
        feedback_df = new_feedbacks
    elif not new_feedbacks.empty:
        feedback_df = pd.concat(
            [feedback_df, new_feedbacks], ignore_index=True)

    # Make sure the target directory exists before writing.
    parent_dir = os.path.dirname(feedback_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    feedback_df.to_csv(feedback_path, index=False)
    return "Feedback recorded!"
# Gradio UI: a tag box, five recommendation slots (each with a thumbs up/down
# rating) and a feedback submit button with a confirmation message.
app = gr.Blocks()
with app:
    gr.Markdown("## Music Recommendation System")
    tags_input = gr.Textbox(
        label="Enter Tags (e.g., rock, jazz, pop)", placeholder="rock, pop")
    submit_button = gr.Button("Get Recommendations")

    # Per-session storage for the songs currently shown. A module-level
    # global would be shared by every connected browser session, so one
    # user's ratings could be saved against another user's songs.
    recommendations_state = gr.State([])

    # Components are placed by creating them inside the Row/Column context
    # managers; each recommendation sits next to its own rating control.
    recommendation_outputs = []
    feedback_inputs = []
    for i in range(5):
        with gr.Row():
            with gr.Column():
                recommendation_outputs.append(
                    gr.HTML(label=f"Recommendation {i+1}"))
                feedback_inputs.append(gr.Radio(
                    choices=["Thumbs Up", "Thumbs Down"],
                    label=f"Feedback {i+1}"))

    with gr.Row():
        with gr.Column():
            feedback_submit_button = gr.Button("Submit Feedback")
            feedback_confirmation = gr.Markdown("")

    def display_recommendations(tags):
        """Fetch recommendations for ``tags`` and refresh the five slots.

        Returns one update per HTML slot, one per rating control (relabelled
        with the song and cleared of any earlier selection), followed by the
        list of songs to keep in the session state.
        """
        slot_count = len(recommendation_outputs)
        songs = [str(song) for song in recommend_songs(tags, "")]
        # Always fill every slot; the placeholder is the same string that
        # record_feedback() filters out.
        songs = (songs + ["No recommendations found"] * slot_count)[:slot_count]
        html_updates = [gr.update(value=song) for song in songs]
        # value=None clears a rating left over from a previous query so it
        # cannot be submitted against the new song.
        radio_updates = [gr.update(label=song, value=None) for song in songs]
        return html_updates + radio_updates + [songs]

    submit_button.click(
        fn=display_recommendations,
        inputs=[tags_input],
        outputs=recommendation_outputs + feedback_inputs
        + [recommendations_state],
    )

    def collect_feedback(tags, recommendations, *feedbacks):
        """Persist the ratings for this session's songs; return a status text."""
        if not recommendations:
            # Nothing has been recommended in this session yet.
            return "Please get recommendations before submitting feedback."
        record_feedback(tags, recommendations, list(feedbacks))
        return "Feedback submitted successfully!"

    feedback_submit_button.click(
        fn=collect_feedback,
        inputs=[tags_input, recommendations_state] + feedback_inputs,
        outputs=feedback_confirmation,
    )

app.launch()