Polo123 commited on
Commit
2e4da47
1 Parent(s): fa01a43

Update logic2.py

Browse files
Files changed (1) hide show
  1. logic2.py +14 -0
logic2.py CHANGED
@@ -74,4 +74,18 @@ def load_model(data):
74
  model.load_state_dict(torch.load('model.pt',map_location=torch.device('cpu')))
75
  model.eval()
76
  return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
74
  model.load_state_dict(torch.load('model.pt',map_location=torch.device('cpu')))
75
  model.eval()
76
  return model
77
+
78
+ def get_recommendation(model,data,user_id):
79
+
80
+ total_movies = 9025
81
+
82
+ user_row = torch.tensor([user_id] * total_movies)
83
+ all_movie_ids = torch.arange(total_movies)
84
+ edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
85
+ pred = model(data.x_dict, data.edge_index_dict,edge_label_index)
86
+ pred = pred.clamp(min=0, max=5)
87
+ # we will only select movies for the user where the predicting rating is =5
88
+ rec_movie_ids = (pred == 5).nonzero(as_tuple=True)
89
+ top_ten_recs = [rec_movies for rec_movies in rec_movie_ids[0].tolist()[:10]]
90
+ return {'user': user_id, 'rec_movies': top_ten_recs}
91