Polo123's picture
Update logic2.py
fe189c2 verified
raw
history blame
3.54 kB
import pandas as pd
from tqdm import tqdm
import numpy as np
import itertools
import requests
import sys
import torch
import torch.nn.functional as F
from torch.nn import Linear
from arango import ArangoClient
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from sentence_transformers import SentenceTransformer
from torch_geometric.data import HeteroData
import yaml
import pickle
#----------------------------------------------
# SAGE model
class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder written for a single node/edge type.

    The module is meant to be passed through ``torch_geometric.nn.to_hetero``,
    which clones each SAGEConv once per edge type of the heterogeneous graph.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): source/target input sizes are inferred lazily on the first
        # forward pass, so one layer definition serves bipartite edge types.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Return node embeddings computed as conv1 -> ReLU -> conv2."""
        h = self.conv1(x, edge_index)
        h = h.relu()
        return self.conv2(h, edge_index)
class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP that scores (user, movie) pairs from their embeddings."""

    def __init__(self, hidden_channels):
        super().__init__()
        # Input width is doubled: a user embedding concatenated with a movie one.
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return one score per column of ``edge_label_index``.

        Row 0 of ``edge_label_index`` indexes ``z_dict['user']`` and row 1
        indexes ``z_dict['movie']``; the result is flattened to 1-D.
        """
        user_idx, movie_idx = edge_label_index
        pair = torch.cat(
            [z_dict['user'][user_idx], z_dict['movie'][movie_idx]], dim=-1
        )
        hidden = self.lin1(pair).relu()
        score = self.lin2(hidden)
        return score.view(-1)
class Model(torch.nn.Module):
    """GraphSAGE encoder (made heterogeneous via ``to_hetero``) + edge decoder.

    Args:
        hidden_channels: width of the encoder output and decoder hidden layer.
        metadata: graph metadata as returned by ``HeteroData.metadata()``.
            When omitted, falls back to ``data.metadata()`` on the module-level
            ``data`` global (set by ``load_hetero_data``), which is the
            original behaviour. Passing it explicitly removes the need for
            that global to exist.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback to the module-level graph.
            metadata = data.metadata()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        # to_hetero clones each SAGEConv per edge type and sums the messages
        # that arrive at the same destination node type.
        self.encoder = to_hetero(self.encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then score the pairs in ``edge_label_index``."""
        # z_dict maps node type -> embeddings produced by the GraphSAGE encoder.
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
#----------------------------------------------
def load_hetero_data(path='Hgraph.pkl'):
    """Unpickle the graph from ``path`` and publish it as the global ``data``.

    Other code in this module (e.g. ``Model``) reads the module-level
    ``data`` global, so this is also stored there, not only returned.

    Args:
        path: pickle file to read (defaults to the original 'Hgraph.pkl').

    Returns:
        The unpickled object (presumably a ``HeteroData`` graph — confirm).
    """
    global data
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load trusted, locally produced files.
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return data
def load_model(data, weights_path='model.pt', hidden_channels=32):
    """Build a ``Model``, load saved weights on CPU, and switch to eval mode.

    Args:
        data: graph whose ``x_dict`` / ``edge_index_dict`` drive one warm-up
            forward pass of the encoder.
        weights_path: state_dict checkpoint to load (defaults to 'model.pt').
        hidden_channels: must match the width the checkpoint was trained with.

    Returns:
        The model with weights loaded, in eval mode.

    Note:
        ``Model(...)`` is called without metadata here, so it takes its graph
        metadata from the module-level ``data`` global rather than from this
        ``data`` argument; both should describe the same graph.
    """
    model = Model(hidden_channels=hidden_channels)
    # The SAGEConv layers use lazy (-1) input sizes, so their weights only
    # exist after a forward pass; run one before loading the state_dict.
    with torch.no_grad():
        model.encoder(data.x_dict, data.edge_index_dict)
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.eval()
    return model
# Lookup tables loaded once at import time and read by get_movie().
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# ship trusted, locally produced pickles.
with open('id_map.pkl', 'rb') as file:
    # Indexed by get_movie() via its 'movieId' column.
    id_map = pickle.load(file)
with open('m_id.pkl', 'rb') as file:
    # Indexed by get_movie() with a movie index; presumably maps model movie
    # index -> raw movieId (confirm).
    m_id = pickle.load(file)
def get_movie(idx):
    """Return the ``id_map`` index labels whose 'movieId' equals ``m_id[idx]``.

    Reads the module-level ``id_map`` and ``m_id`` tables. The result is a
    pandas Index and may hold zero, one, or several labels.
    """
    wanted = m_id[idx]
    matches = id_map['movieId'] == wanted
    return id_map.loc[matches].index
def get_recommendation(model, data, user_id, top_k=10, total_movies=None):
    """Recommend movies whose predicted rating for ``user_id`` hits the max (5).

    Every (user_id, movie) pair is scored. Predictions are clamped to [0, 5],
    and the movies whose clamped score equals 5 are kept. The first ``top_k``
    of those (in movie-index order) are mapped through ``get_movie``.

    Args:
        model: callable ``model(x_dict, edge_index_dict, edge_label_index)``
            returning one score per candidate pair.
        data: graph providing ``x_dict`` and ``edge_index_dict``.
        user_id: index of the user node to score.
        top_k: maximum number of recommendations (default 10, as before).
        total_movies: number of movie nodes to score. Defaults to the number
            of rows in ``data.x_dict['movie']`` rather than a hard-coded count,
            so the candidate set always matches the graph.

    Returns:
        ``{'user': user_id, 'rec_movies': [...]}``, one ``get_movie`` result
        per recommended movie index.
    """
    if total_movies is None:
        total_movies = data.x_dict['movie'].size(0)
    user_row = torch.tensor([user_id] * total_movies)
    all_movie_ids = torch.arange(total_movies)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index)
    # Clamping maps every prediction >= 5 to exactly 5.0, so the equality
    # below selects all pairs predicted at (or above) the top rating.
    pred = pred.clamp(min=0, max=5)
    rec_movie_ids = (pred == 5).nonzero(as_tuple=True)[0][:top_k].tolist()
    return {'user': user_id,
            'rec_movies': [get_movie(idx) for idx in rec_movie_ids]}