# HNSCC-MultiOmics-Risk-Feature-Extraction / EdgeWeightPredictorModel.py
# Source: Hugging Face repository by VatsalPatel18 ("Model files", commit c238491).
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn import GATv2Conv
from torch_geometric.utils import negative_sampling
from transformers import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel

from OmicsConfig import OmicsConfig


class EdgeWeightPredictorModel(PreTrainedModel):
    """MLP decoder that produces one score per graph edge from node embeddings.

    For every edge in ``edge_index`` the embeddings of its two endpoint nodes
    are concatenated and fed through a feed-forward network ending in a single
    ``nn.Linear(..., 1)``. No final activation is applied, so the output is a
    raw score.

    The network is built from the config:

    * ``config.out_channels`` -- width of one node embedding. The decoder input
      is ``2 * out_channels`` because source and target embeddings are
      concatenated.
    * ``config.edge_decoder_hidden_sizes`` -- output widths of the hidden
      ``nn.Linear`` layers, in order.
    * ``config.edge_decoder_activations`` -- activation name placed after each
      hidden layer. ``'ReLU'``, ``'Sigmoid'`` and ``'Tanh'`` are supported, and
      ``None`` means "no activation".
    """

    config_class = OmicsConfig
    base_model_prefix = "edge_weight_predictor"

    def __init__(self, config):
        super().__init__(config)

        # Name -> module class for the activations this decoder understands.
        supported_activations = {
            "ReLU": nn.ReLU,
            "Sigmoid": nn.Sigmoid,
            "Tanh": nn.Tanh,
        }

        hidden_sizes = list(config.edge_decoder_hidden_sizes)
        activations = list(config.edge_decoder_activations)
        if len(hidden_sizes) != len(activations):
            # zip() below stops at the shorter list; say so instead of silently
            # dropping configured layers/activations.
            warnings.warn(
                f"edge_decoder_hidden_sizes has {len(hidden_sizes)} entries but "
                f"edge_decoder_activations has {len(activations)}; only the first "
                f"{min(len(hidden_sizes), len(activations))} of each are used.",
                stacklevel=2,
            )

        layers = []
        input_size = 2 * config.out_channels
        for hidden_size, activation in zip(hidden_sizes, activations):
            layers.append(nn.Linear(input_size, hidden_size))
            activation_cls = (
                supported_activations.get(activation)
                if isinstance(activation, str)
                else None
            )
            if activation_cls is not None:
                layers.append(activation_cls())
            elif activation is not None:
                # Unknown names still add no module, so the indices inside
                # ``self.predictor`` (and the ``predictor.<i>.*`` state-dict keys
                # of saved checkpoints) are unchanged. The warning only surfaces
                # the likely configuration mistake.
                warnings.warn(
                    f"Unsupported edge decoder activation {activation!r}; no "
                    f"activation is added after the hidden layer of size "
                    f"{hidden_size}. Supported: {sorted(supported_activations)}.",
                    stacklevel=2,
                )
            input_size = hidden_size
        layers.append(nn.Linear(input_size, 1))
        self.predictor = nn.Sequential(*layers)

    def forward(self, z, edge_index):
        """Score the edges listed in ``edge_index``.

        Args:
            z: Node-embedding tensor indexed by node id along dim 0. Its last
                dimension must equal ``config.out_channels`` so that the
                concatenated pair matches the first ``nn.Linear`` layer.
            edge_index: Integer index tensor whose row 0 holds source node ids
                and row 1 holds target node ids. This is presumably the
                PyTorch Geometric ``(2, num_edges)`` layout -- confirm against
                callers.

        Returns:
            Raw (un-activated) edge scores. For a ``(2, num_edges)``
            ``edge_index`` the shape is ``(num_edges, 1)``.
        """
        source_embeddings = z[edge_index[0]]
        target_embeddings = z[edge_index[1]]
        edge_embeddings = torch.cat([source_embeddings, target_embeddings], dim=-1)
        return self.predictor(edge_embeddings)