AJ-Gazin commited on
Commit
5cc7af1
·
1 Parent(s): f3cdf0c

first commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ etc/
2
+ include/
3
+ lib/
4
+ library/
5
+ Scripts/
6
+ share/
7
+ __pycache__/
8
+ test.ipynb
README.md CHANGED
@@ -1,12 +0,0 @@
1
- ---
2
- title: GNN GradioDemo
3
- emoji: 🚀
4
- colorFrom: green
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.31.4
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import plotly.express as px
3
+ import pandas as pd
4
+ import torch
5
+ from sklearn.decomposition import PCA
6
+ from sklearn.manifold import TSNE
7
+ import umap
8
+ from recommender import get_recommendations
9
+
10
+ # Loading pre-trained product embeddings for visualization
11
+ product_embeddings_path = 'data/product_embeddings.pt'
12
+ product_emb = torch.load(product_embeddings_path, map_location=torch.device('cpu'))
13
+
14
+ # Loading pre-trained user embeddings for visualization
15
+ user_embeddings_path = 'data/user_embeddings.pt'
16
+ user_emb = torch.load(user_embeddings_path, map_location=torch.device('cpu'))
17
+
18
+ # Loading the reviews dataframe for visualization purposes
19
+ reviews_df = pd.read_csv('data/organized_reviews.csv')
20
+
21
+ # Loading the training and validation loss data
22
+ loss_data_path = 'data/loss_data.csv'
23
+ loss_df = pd.read_csv(loss_data_path)
24
+ loss_df.columns = ['Epoch', 'Training Loss', 'Validation Loss']
25
+
26
+ # Creating a user dataframe by extracting unique user IDs and usernames
27
+ user_df = reviews_df[['user_id', 'username']].drop_duplicates()
28
+
29
+
30
+ # Function to perform dimensionality reduction on embeddings
31
+ # This function reduces the high-dimensional embeddings to a lower-dimensional space for visualization
32
+ def reduce_dimensions(embeddings, method, n_components=3):
33
+ # Selecting the appropriate dimensionality reduction technique based on the specified method
34
+ if method == "PCA":
35
+ reducer = PCA(n_components=n_components)
36
+ else:
37
+ # Performing initial PCA to reduce dimensionality before applying t-SNE or UMAP
38
+ pca = PCA(n_components=50)
39
+ reduced_embeddings = pca.fit_transform(embeddings)
40
+ reducer = TSNE(n_components=n_components) if method == "TSNE" else umap.UMAP(n_components=n_components)
41
+ embeddings = reduced_embeddings
42
+
43
+ # Applying the selected dimensionality reduction technique to the embeddings
44
+ reduced_embeddings = reducer.fit_transform(embeddings)
45
+
46
+ # Assigning appropriate column names based on the dimensionality reduction method
47
+ columns = ['PC1', 'PC2', 'PC3'] if method == "PCA" else ['TSNE1', 'TSNE2', 'TSNE3'] if method == "TSNE" else ['UMAP1', 'UMAP2', 'UMAP3']
48
+
49
+ return reduced_embeddings, columns
50
+
51
+ # Function to visualize embeddings using interactive 3D scatter plots
52
+ # This function creates an interactive plot to explore the embeddings in a three-dimensional space
53
+ def visualize_embeddings(embeddings, df, method, is_product=True):
54
+ reduced_embeddings, columns = reduce_dimensions(embeddings, method)
55
+ df_reduced = pd.DataFrame(reduced_embeddings, columns=columns)
56
+
57
+ if is_product:
58
+ # Adding product-related information to the dataframe for hover interactions
59
+ df_reduced['product_id'] = df['product_id']
60
+ df_reduced['category'] = df['category']
61
+ fig = px.scatter_3d(df_reduced, x=columns[0], y=columns[1], z=columns[2], color='category', hover_data=['product_id'], opacity=0.9)
62
+ else:
63
+ # Adding user-related information to the dataframe for hover interactions
64
+ df_reduced['user_id'] = df['user_id']
65
+ df_reduced['username'] = df['username']
66
+ fig = px.scatter_3d(df_reduced, x=columns[0], y=columns[1], z=columns[2], hover_data=['user_id', 'username'], opacity=0.9)
67
+
68
+ return fig
69
+
70
+ # Function to visualize product embeddings
71
+ # This function specifically visualizes the product embeddings using the selected dimensionality reduction method
72
+ def visualize_product_embeddings(method):
73
+ return visualize_embeddings(product_emb.cpu().numpy(), reviews_df, method)
74
+
75
+ # Function to visualize user embeddings
76
+ # This function specifically visualizes the user embeddings using the selected dimensionality reduction method
77
+ def visualize_user_embeddings(method):
78
+ return visualize_embeddings(user_emb.cpu().numpy(), user_df, method, is_product=False)
79
+
80
+ # Function to visualize training and validation loss
81
+ # This function creates a line plot to visualize the model's training and validation loss over epochs
82
+ def visualize_loss():
83
+ fig = px.line(loss_df, x='Epoch', y=['Training Loss', 'Validation Loss'], labels={
84
+ 'Epoch': 'Epoch',
85
+ 'value': 'Loss',
86
+ 'variable': 'Loss Type'
87
+ })
88
+ fig.update_layout(title='Training and Validation Loss', legend_title='Loss Type')
89
+ return fig
90
+
91
+ # Function to generate product recommendations for a given username
92
+ # This function retrieves the user ID based on the provided username and generates personalized product recommendations
93
+ def recommend(username, method):
94
+ user_id = user_df[user_df['username'] == username]['user_id'].values[0]
95
+ recommendations_title, recommendations = get_recommendations(user_id)
96
+ recommendations_list = [[rec[0], rec[1], rec[2]] for rec in recommendations]
97
+ return recommendations_title, recommendations_list
98
+
99
+
100
+ # Sampling a subset of usernames for the dropdown menu
101
+ sample_usernames = user_df['username'].sample(5, random_state=42).tolist()
102
+
103
+ # Creating the Gradio interface for the recommendation system
104
+ with gr.Blocks() as demo:
105
+ gr.Markdown("# Amazon Product Recommendation System")
106
+
107
+ with gr.Column():
108
+ username_input = gr.Dropdown(label="Select Username", choices=sample_usernames, value=sample_usernames[0])
109
+ recommendations_output = gr.Textbox(label="Recommendations")
110
+ recommendations_list = gr.Dataframe(headers=["Product ID", "Category", "Subcategory"])
111
+ recommend_button = gr.Button("Get Recommendations")
112
+
113
+ with gr.Row():
114
+ with gr.Column():
115
+ gr.Markdown("### Product Embeddings Visualization")
116
+ method_input_product = gr.Dropdown(label="Visualization Method", choices=["PCA", "TSNE", "UMAP"], value="PCA")
117
+ embeddings_plot_product = gr.Plot(value=visualize_product_embeddings("PCA"))
118
+ with gr.Column():
119
+ gr.Markdown("### User Embeddings Visualization")
120
+ method_input_user = gr.Dropdown(label="Visualization Method", choices=["PCA", "TSNE", "UMAP"], value="PCA")
121
+ embeddings_plot_user = gr.Plot(value=visualize_user_embeddings("PCA"))
122
+
123
+ gr.Markdown("### Training and Validation Loss")
124
+ loss_plot = gr.Plot(value=visualize_loss())
125
+
126
+ # Event triggers and their corresponding actions
127
+ recommend_button.click(recommend, inputs=[username_input], outputs=[recommendations_output, recommendations_list])
128
+ method_input_product.change(visualize_product_embeddings, inputs=[method_input_product], outputs=[embeddings_plot_product])
129
+ method_input_user.change(visualize_user_embeddings, inputs=[method_input_user], outputs=[embeddings_plot_user])
130
+
131
+ # Running the Gradio interface
132
+ if __name__ == "__main__":
133
+ demo.launch()
data/amazon_reviews.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a88348c20ed3f85d647e8fbaac0a730ab2f09f95e5d1f4bcf1f9e3650ef624d7
3
+ size 300904694
data/loss_data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd25a48b83864f4fc1248ff466dcbd5b38dddcd35cab44f883c0b8a8be8dff27
3
+ size 1179
data/organized_reviews.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9617a398f69a0fac22cd43a2588dda166bd089882d9c356e6e67883654a70066
3
+ size 9020826
data/product_data_PYG.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c3ea2997f012564fcb92760bb74842ae906cfcfe9f651834f2db27c445f7354
3
+ size 10706333
data/product_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d69dd6db7b9190fd247673da1130ee73037ae66b789a5640ae5fa3c0569c5708
3
+ size 1484243
data/rev_user_mapping.json ADDED
The diff for this file is too large to render. See raw diff
 
data/reviews_sample.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21822c89e46dc619af4220e930f501a9220d8a2b29f26e19ebe21e158b6d0dfd
3
+ size 9135407
data/sample_metadata.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21822c89e46dc619af4220e930f501a9220d8a2b29f26e19ebe21e158b6d0dfd
3
+ size 9135407
data/user_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e69bd6fa86cdf10af5ff335a6a42a7d948ad26dbe24a71bc646fc31ffbb2419
3
+ size 15396548
data/user_mapping.json ADDED
The diff for this file is too large to render. See raw diff
 
models/amazon_best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f97e25eb3065749e9f14c7ed161561bc64d86f26fd9f9e7078a499b274373114
3
+ size 231662
recommender.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ from torch_geometric.data import HeteroData
4
+ from torch_geometric.nn import SAGEConv, to_hetero
5
+ from torch.nn import Linear
6
+
7
+ # Load the trained model
8
+ class GNNEncoder(torch.nn.Module):
9
+ def __init__(self, hidden_channels, out_channels):
10
+ super().__init__()
11
+ self.conv1 = SAGEConv((-1, -1), hidden_channels)
12
+ self.conv2 = SAGEConv((-1, -1), out_channels)
13
+
14
+ def forward(self, x, edge_index):
15
+ x = self.conv1(x, edge_index).relu()
16
+ x = self.conv2(x, edge_index)
17
+ return x
18
+
19
+ class EdgeDecoder(torch.nn.Module):
20
+ def __init__(self, hidden_channels):
21
+ super().__init__()
22
+ self.lin1 = Linear(2 * hidden_channels, hidden_channels)
23
+ self.lin2 = Linear(hidden_channels, 1)
24
+
25
+ def forward(self, z_dict, edge_label_index):
26
+ row, col = edge_label_index
27
+ z = torch.cat([z_dict['user'][row], z_dict['products'][col]], dim=-1)
28
+ z = self.lin1(z).relu()
29
+ z = self.lin2(z)
30
+ return z.view(-1)
31
+
32
+ class Model(torch.nn.Module):
33
+ def __init__(self, hidden_channels):
34
+ super().__init__()
35
+ self.encoder = GNNEncoder(hidden_channels, hidden_channels)
36
+ self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
37
+ self.decoder = EdgeDecoder(hidden_channels)
38
+
39
+ def forward(self, x_dict, edge_index_dict, edge_label_index):
40
+ z_dict = self.encoder(x_dict, edge_index_dict)
41
+ return self.decoder(z_dict, edge_label_index)
42
+
43
+ # Load data and model
44
+ data_path = 'data/product_data_PYG.pt'
45
+ model_path = 'models/amazon_best_model.pt'
46
+ reviews_path = 'data/organized_reviews.csv'
47
+ user_mapping_path = 'data/user_mapping.json'
48
+ rev_user_mapping_path = 'data/rev_user_mapping.json'
49
+
50
+ print("Loading data...")
51
+ data = torch.load(data_path, map_location=torch.device('cpu'))
52
+ device = 'cpu'
53
+ data = data.to(device)
54
+
55
+ print("Loading model...")
56
+ model = Model(hidden_channels=32).to(device)
57
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
58
+
59
+ print("Loading reviews dataframe...")
60
+ reviews_df = pd.read_csv(reviews_path)
61
+
62
+ print("Loading user mappings...")
63
+ user_mapping = pd.read_json(user_mapping_path, typ='series').to_dict()
64
+ rev_user_mapping = pd.read_json(rev_user_mapping_path, typ='series').to_dict()
65
+
66
+ # Function to get the username from user_id
67
+ def get_username(user_id):
68
+ if user_id not in reviews_df['user_id'].values:
69
+ raise ValueError(f"User ID {user_id} not found in reviews_df")
70
+ return reviews_df[reviews_df['user_id'] == user_id]['username'].iloc[0]
71
+
72
+ # Function to get product recommendations
73
+ def get_product_recommendations(model, data, user_id, total_products):
74
+ user_idx = user_mapping[user_id] # Get the embedding index for the user_id
75
+ user_row = torch.tensor([user_idx] * total_products).to(device)
76
+ all_product_ids = torch.arange(total_products).to(device)
77
+ edge_label_index = torch.stack([user_row, all_product_ids], dim=0)
78
+ pred = model(data.x_dict, data.edge_index_dict, edge_label_index).cpu()
79
+ top_five_indices = pred.topk(5).indices.numpy() # Ensure indices are integers for indexing
80
+ recommendations = []
81
+ for idx in top_five_indices:
82
+ idx = int(idx) # Convert to integer for indexing
83
+ product_id = reviews_df.iloc[idx]['product_id']
84
+ category = reviews_df.iloc[idx]['category']
85
+ subcategory = reviews_df.iloc[idx]['subcategory']
86
+ recommendations.append((product_id, category, subcategory))
87
+ return recommendations
88
+
89
+ # Function to get and print recommendations for a given user
90
+ def get_recommendations(user_id):
91
+ try:
92
+ user_id = str(user_id)
93
+ username = get_username(user_id)
94
+ recommendations = get_product_recommendations(model, data, user_id, data['products'].x.shape[0])
95
+ return f"Recommendations for {username} (User ID: {user_id}):", recommendations
96
+ except Exception as e:
97
+ return f"Error: {str(e)}", []
98
+
99
+ if __name__ == "__main__":
100
+ # For testing the recommendation functionality
101
+ user_id = 'A314APAWYQFKBJ' # Example user ID
102
+ recommendations_title, recommendations = get_recommendations(user_id)
103
+ print(recommendations_title)
104
+ print(recommendations)
requirements.txt ADDED
Binary file (5.09 kB). View file