Update logic2.py
Browse files
logic2.py
CHANGED
@@ -74,4 +74,18 @@ def load_model(data):
|
|
74 |
model.load_state_dict(torch.load('model.pt',map_location=torch.device('cpu')))
|
75 |
model.eval()
|
76 |
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
|
|
74 |
model.load_state_dict(torch.load('model.pt',map_location=torch.device('cpu')))
|
75 |
model.eval()
|
76 |
return model
|
77 |
+
|
78 |
+
def get_recommendation(model,data,user_id):
|
79 |
+
|
80 |
+
total_movies = 9025
|
81 |
+
|
82 |
+
user_row = torch.tensor([user_id] * total_movies)
|
83 |
+
all_movie_ids = torch.arange(total_movies)
|
84 |
+
edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
|
85 |
+
pred = model(data.x_dict, data.edge_index_dict,edge_label_index)
|
86 |
+
pred = pred.clamp(min=0, max=5)
|
87 |
+
# we will only select movies for the user where the predicting rating is =5
|
88 |
+
rec_movie_ids = (pred == 5).nonzero(as_tuple=True)
|
89 |
+
top_ten_recs = [rec_movies for rec_movies in rec_movie_ids[0].tolist()[:10]]
|
90 |
+
return {'user': user_id, 'rec_movies': top_ten_recs}
|
91 |
|