# Hugging Face Spaces status at capture time: Runtime error
import html
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from joblib import load
# Neural network that scores every known song title for one encoded input.
class ImprovedSongRecommender(nn.Module):
    """Feed-forward classifier over song titles.

    Three hidden stages (widths 128 -> 256 -> 128), each Linear -> BatchNorm1d
    -> ReLU -> Dropout(0.5), followed by a Linear head producing one score per
    title. The attribute names (fc1..fc3, bn1..bn3, output) become the
    state_dict keys, so they must stay in sync with the saved checkpoint.

    Args:
        input_size: Number of input features per example.
        num_titles: Number of output scores (one per song title).
    """

    def __init__(self, input_size, num_titles):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.output = nn.Linear(128, num_titles)
        # A single Dropout module is reused after every hidden stage.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw (unnormalized) title scores.

        For a 2-D input of shape (batch, input_size) the result has shape
        (batch, num_titles).
        """
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in stages:
            x = self.dropout(torch.relu(norm(linear(x))))
        return self.output(x)
# Load the trained model | |
model_path = "models/improved_model.pth" | |
num_unique_titles = 4855 | |
model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles) | |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) | |
model.eval() | |
# Load the label encoders | |
label_encoders_path = "data/new_label_encoders.joblib" | |
label_encoders = load(label_encoders_path) | |
def _encode_or_unknown(encoder, value):
    """Encode one value, falling back to the encoder's 'unknown' label.

    The fallback is looked up only when `value` itself is rejected, so an
    encoder without an 'unknown' class still works for values it knows.
    """
    try:
        return encoder.transform([value])[0]
    except ValueError:
        # The encoder raises ValueError for labels it was not fitted on.
        return encoder.transform(["unknown"])[0]


def encode_input(tags, artist_name, encoders=None):
    """Encode the user's tags and artist into the model's two input features.

    Args:
        tags: Comma-separated tag string, e.g. "rock, pop". Blank entries
            (from "", "rock, " or ", ,") are ignored.
        artist_name: Artist name to encode.
        encoders: Mapping with 'tags' and 'artist_name' encoders exposing
            ``transform``. Defaults to the module-level ``label_encoders``.

    Returns:
        ``[encoded_tags, encoded_artist]`` as Python ints. ``encoded_tags`` is
        the truncated mean of the per-tag codes, or the 'unknown' code when no
        non-blank tag was given. Unseen tags or artists map to the 'unknown'
        code. NOTE(review): this assumes each encoder was fitted with an
        'unknown' class; if not, the fallback itself raises ValueError.
    """
    if encoders is None:
        encoders = label_encoders
    tag_encoder = encoders['tags']

    # Drop blank tokens so a trailing comma or empty box does not inject an
    # 'unknown' code into the average.
    tags_list = [tag.strip() for tag in (tags or "").split(',') if tag.strip()]
    if tags_list:
        codes = [_encode_or_unknown(tag_encoder, tag) for tag in tags_list]
        # NOTE(review): averaging categorical ids mirrors the original
        # encoding; whether training used the same scheme cannot be seen here.
        encoded_tags = int(np.mean(codes))
    else:
        encoded_tags = int(tag_encoder.transform(['unknown'])[0])

    encoded_artist = int(_encode_or_unknown(encoders['artist_name'], artist_name))
    return [encoded_tags, encoded_artist]
def recommend_songs(tags, artist_name, top_k=5):
    """Return the song titles the model scores highest for the given input.

    Args:
        tags: Comma-separated tag string, encoded via ``encode_input``.
        artist_name: Artist name, encoded via ``encode_input``.
        top_k: Number of titles to return (default 5). It is capped at the
            number of scores the model outputs.

    Returns:
        List of decoded title strings, highest-scoring first.
    """
    encoded_input = encode_input(tags, artist_name)
    # Single-row batch of float features.
    input_tensor = torch.tensor([encoded_input], dtype=torch.float32)
    with torch.no_grad():
        scores = model(input_tensor)

    k = min(top_k, scores.shape[-1])
    # Take row 0 explicitly instead of squeeze(): for k == 1, squeeze() would
    # collapse to a 0-d tensor whose tolist() is an int, not a list.
    top_indices = torch.topk(scores, k).indices[0].tolist()
    # Decode all indices in one call; tolist() yields plain Python values.
    return label_encoders['title'].inverse_transform(top_indices).tolist()
def record_feedback(tags, recommendations, feedbacks,
                    path="feedback_data/feedback_data.csv"):
    """Append rated recommendations to the feedback CSV.

    Args:
        tags: The tag string the user searched with; stored on every row.
        recommendations: Song titles that were shown, in display order.
        feedbacks: Ratings aligned with ``recommendations``. A missing rating
            (None/NaN) means the slot was left unrated and is skipped.
        path: CSV file to append to. Its parent directory is created if
            needed. Defaults to "feedback_data/feedback_data.csv".

    Returns:
        The confirmation string "Feedback recorded!".
    """
    # zip() pairs each title with its rating and stops at the shorter list.
    # Mismatched lengths (e.g. ratings submitted before any recommendations
    # exist) therefore no longer raise pandas' "arrays must be of the same
    # length" ValueError.
    pairs = list(zip(recommendations, feedbacks))
    new_feedbacks = pd.DataFrame({
        "Tags": [tags] * len(pairs),
        "Recommendation": [rec for rec, _ in pairs],
        "Feedback": [fb for _, fb in pairs],
    })
    # Only keep rows where both a song recommendation and a rating are present.
    new_feedbacks = new_feedbacks[
        (new_feedbacks["Recommendation"] != "No recommendations found")
        & new_feedbacks["Feedback"].notna()
    ]

    try:
        existing = pd.read_csv(path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # First write, or a zero-byte file: start from the new rows alone.
        feedback_df = new_feedbacks
    else:
        feedback_df = (
            existing if new_feedbacks.empty
            else pd.concat([existing, new_feedbacks], ignore_index=True)
        )

    # to_csv does not create missing directories.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    feedback_df.to_csv(path, index=False)
    return "Feedback recorded!"
app = gr.Blocks()
with app:
    gr.Markdown("## Music Recommendation System")
    tags_input = gr.Textbox(
        label="Enter Tags (e.g., rock, jazz, pop)", placeholder="rock, pop")
    submit_button = gr.Button("Get Recommendations")

    # Build each recommendation/rating pair inside Row/Column contexts so the
    # layout is set where the components are created. The previous code
    # called gr.Column([...]) on already-created components, but gr.Column's
    # positional argument is a layout setting, not a list of children.
    recommendation_outputs = []
    feedback_inputs = []
    for i in range(5):
        with gr.Row():
            with gr.Column():
                recommendation_outputs.append(
                    gr.HTML(label=f"Recommendation {i+1}"))
                feedback_inputs.append(gr.Radio(
                    choices=["Thumbs Up", "Thumbs Down"],
                    label=f"Feedback {i+1}"))

    with gr.Row():
        with gr.Column():
            feedback_submit_button = gr.Button("Submit Feedback")
            feedback_confirmation = gr.Markdown("")

    # Titles currently shown to this browser session. A module-level global
    # is shared by every concurrent user, so one user's search could change
    # which songs another user's ratings are recorded against.
    song_recommendations = gr.State([])

    def display_recommendations(tags):
        """Fetch recommendations for `tags` and push them into the page."""
        songs = recommend_songs(tags, "")
        # gr.HTML renders markup, so escape titles to display them literally.
        updated_recommendations = [
            gr.update(value=html.escape(str(song))) for song in songs]
        # Relabel each rating control with its song. Clear any rating left
        # from the previous search so it is not recorded for the new song.
        updated_feedbacks = [
            gr.update(label=str(song), value=None) for song in songs]
        return [songs] + updated_recommendations + updated_feedbacks

    submit_button.click(
        fn=display_recommendations,
        inputs=[tags_input],
        outputs=[song_recommendations] + recommendation_outputs + feedback_inputs,
    )

    def collect_feedback(tags, songs, *feedbacks):
        """Persist the ratings given for the songs shown in this session."""
        record_feedback(tags, songs, list(feedbacks))
        return "Feedback submitted successfully!"

    feedback_submit_button.click(
        fn=collect_feedback,
        inputs=[tags_input, song_recommendations] + feedback_inputs,
        outputs=feedback_confirmation,
    )

app.launch()