Update logic.py
Browse files
logic.py
CHANGED
@@ -424,9 +424,33 @@ def make_pyg_graph(movie_rec_db):
|
|
424 |
rev_edge_types=[('movie', 'rev_rates', 'user')],
|
425 |
)(data)
|
426 |
|
427 |
-
return train_data, val_data, test_data
|
428 |
|
429 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
def train(train_data, val_data, test_data):
|
432 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
424 |
rev_edge_types=[('movie', 'rev_rates', 'user')],
|
425 |
)(data)
|
426 |
|
427 |
+
return data,train_data, val_data, test_data
|
428 |
|
429 |
|
430 |
+
def load_model(train_data, val_data, test_data):
|
431 |
+
model = Model(hidden_channels=32)
|
432 |
+
with torch.no_grad():
|
433 |
+
model.encoder(train_data.x_dict, train_data.edge_index_dict)
|
434 |
+
model.load_state_dict(torch.load('model.pt'))
|
435 |
+
model.eval()
|
436 |
+
return model
|
437 |
+
|
438 |
+
def get_recommendation(model,data,user_id):
|
439 |
+
|
440 |
+
movies = movie_rec_db.collection('Movie')
|
441 |
+
total_movies = len(movies)
|
442 |
+
|
443 |
+
user_row = torch.tensor([user_id] * total_movies)
|
444 |
+
all_movie_ids = torch.arange(total_movies)
|
445 |
+
edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
|
446 |
+
pred = model(data.x_dict, data.edge_index_dict,edge_label_index)
|
447 |
+
pred = pred.clamp(min=0, max=5)
|
448 |
+
# we will only select movies for the user where the predicting rating is =5
|
449 |
+
rec_movie_ids = (pred == 5).nonzero(as_tuple=True)
|
450 |
+
top_ten_recs = [rec_movies for rec_movies in rec_movie_ids[0].tolist()[:10]]
|
451 |
+
return {'user': user_id, 'rec_movies': top_ten_recs}
|
452 |
+
|
453 |
+
|
454 |
|
455 |
def train(train_data, val_data, test_data):
|
456 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|