Polo123 commited on
Commit
4ee9a82
·
verified ·
1 Parent(s): 6e56fb5

Update logic.py

Browse files
Files changed (1) hide show
  1. logic.py +25 -1
logic.py CHANGED
@@ -424,9 +424,33 @@ def make_pyg_graph(movie_rec_db):
424
  rev_edge_types=[('movie', 'rev_rates', 'user')],
425
  )(data)
426
 
427
- return train_data, val_data, test_data
428
 
429
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  def train(train_data, val_data, test_data):
432
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
424
  rev_edge_types=[('movie', 'rev_rates', 'user')],
425
  )(data)
426
 
427
+ return data,train_data, val_data, test_data
428
 
429
 
430
+ def load_model(train_data, val_data, test_data):
431
+ model = Model(hidden_channels=32)
432
+ with torch.no_grad():
433
+ model.encoder(train_data.x_dict, train_data.edge_index_dict)
434
+ model.load_state_dict(torch.load('model.pt'))
435
+ model.eval()
436
+ return model
437
+
438
+ def get_recommendation(model,data,user_id):
439
+
440
+ movies = movie_rec_db.collection('Movie')
441
+ total_movies = len(movies)
442
+
443
+ user_row = torch.tensor([user_id] * total_movies)
444
+ all_movie_ids = torch.arange(total_movies)
445
+ edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
446
+ pred = model(data.x_dict, data.edge_index_dict,edge_label_index)
447
+ pred = pred.clamp(min=0, max=5)
448
+ # we will only select movies for the user where the predicting rating is =5
449
+ rec_movie_ids = (pred == 5).nonzero(as_tuple=True)
450
+ top_ten_recs = [rec_movies for rec_movies in rec_movie_ids[0].tolist()[:10]]
451
+ return {'user': user_id, 'rec_movies': top_ten_recs}
452
+
453
+
454
 
455
  def train(train_data, val_data, test_data):
456
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')