Joash2024 commited on
Commit
b7e4e4f
1 Parent(s): 6d2fac8

Organize files into folders

Browse files
data/new_label_encoders.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3592327f59f9de84ff0d96dd4a48c1785380fe523cfe11c980336780b54eb5da
3
+ size 5370329
data/new_scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6dd8e3793f8411eeef102ec005d2f5ad5c2ef127fd3e890c176817170b9b25d
3
+ size 1063
models/improved_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1cc30b55b50184f85132d396837b1fbef7ccbdb2f3f967f44c58a3a02270f84
3
+ size 2785870
src/app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from joblib import load
5
+
6
+ # Define the same neural network model
7
+ class ImprovedSongRecommender(nn.Module):
8
+ def __init__(self, input_size, num_titles):
9
+ super(ImprovedSongRecommender, self).__init__()
10
+ self.fc1 = nn.Linear(input_size, 128)
11
+ self.bn1 = nn.BatchNorm1d(128)
12
+ self.fc2 = nn.Linear(128, 256)
13
+ self.bn2 = nn.BatchNorm1d(256)
14
+ self.fc3 = nn.Linear(256, 128)
15
+ self.bn3 = nn.BatchNorm1d(128)
16
+ self.output = nn.Linear(128, num_titles)
17
+ self.dropout = nn.Dropout(0.5)
18
+
19
+ def forward(self, x):
20
+ x = torch.relu(self.bn1(self.fc1(x)))
21
+ x = self.dropout(x)
22
+ x = torch.relu(self.bn2(self.fc2(x)))
23
+ x = self.dropout(x)
24
+ x = torch.relu(self.bn3(self.fc3(x)))
25
+ x = self.dropout(x)
26
+ x = self.output(x)
27
+ return x
28
+
29
+ # Load the trained model
30
+ model_path = "C:/Users/joash/Desktop/Neurobytes_final_project/Neurobytes_Music_Recommender/models/improved_model.pth"
31
+ num_unique_titles = 4855
32
+
33
+ model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles)
34
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
35
+ model.eval()
36
+
37
+ # Load the label encoders and scaler
38
+ label_encoders_path = "C:/Users/joash/Desktop/Neurobytes_final_project/Neurobytes_Music_Recommender/data/new_label_encoders.joblib"
39
+ scaler_path = "C:/Users/joash/Desktop/Neurobytes_final_project/Neurobytes_Music_Recommender/data/new_scaler.joblib"
40
+
41
+ label_encoders = load(label_encoders_path)
42
+ scaler = load(scaler_path)
43
+
44
+ # Create a mapping from encoded indices to actual song titles
45
+ index_to_song_title = {index: title for index, title in enumerate(label_encoders['title'].classes_)}
46
+
47
+ def encode_input(tags, artist_name):
48
+ tags = tags.strip().replace('\n', '')
49
+ artist_name = artist_name.strip().replace('\n', '')
50
+
51
+ try:
52
+ encoded_tags = label_encoders['tags'].transform([tags])[0]
53
+ except ValueError:
54
+ encoded_tags = label_encoders['tags'].transform(['unknown'])[0]
55
+
56
+ if artist_name:
57
+ try:
58
+ encoded_artist = label_encoders['artist_name'].transform([artist_name])[0]
59
+ except ValueError:
60
+ encoded_artist = label_encoders['artist_name'].transform(['unknown'])[0]
61
+ else:
62
+ encoded_artist = label_encoders['artist_name'].transform(['unknown'])[0]
63
+
64
+ return [encoded_tags, encoded_artist]
65
+
66
+ def recommend_songs(tags, artist_name):
67
+ encoded_input = encode_input(tags, artist_name)
68
+ input_tensor = torch.tensor([encoded_input]).float()
69
+
70
+ with torch.no_grad():
71
+ output = model(input_tensor)
72
+
73
+ recommendations_indices = torch.topk(output, 5).indices.squeeze().tolist()
74
+ recommendations = [index_to_song_title.get(idx, "Unknown song") for idx in recommendations_indices]
75
+
76
+ formatted_output = [f"Recommendation {i+1}: {rec}" for i, rec in enumerate(recommendations)]
77
+ return formatted_output
78
+
79
+ # Set up the Gradio interface
80
+ interface = gr.Interface(
81
+ fn=recommend_songs,
82
+ inputs=[gr.Textbox(lines=1, placeholder="Enter Tags (e.g., rock)"), gr.Textbox(lines=1, placeholder="Enter Artist Name (optional)")],
83
+ outputs=gr.Textbox(label="Recommendations"),
84
+ title="Music Recommendation System",
85
+ description="Enter tags and (optionally) artist name to get music recommendations."
86
+ )
87
+
88
+ interface.launch()