# HNSCC-MultiOmics-Risk-Feature-Extraction / MultiOmicsGraphAttentionAutoencoderModel.py
import torch
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from torch_geometric.data import Batch
from torch_geometric.utils import negative_sampling

from transformers import PreTrainedModel

from OmicsConfig import OmicsConfig
from GATv2EncoderModel import GATv2EncoderModel
from GATv2DecoderModel import GATv2DecoderModel
from EdgeWeightPredictorModel import EdgeWeightPredictorModel

class MultiOmicsGraphAttentionAutoencoderModel(PreTrainedModel):
    """Graph attention autoencoder for multi-omics graphs: a GATv2 encoder maps node
    features to a latent space, and a decoder reconstructs node features and edge weights."""

    config_class = OmicsConfig
    base_model_prefix = "graph-attention-autoencoder"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = GATv2EncoderModel(config)
        self.decoder = GATv2DecoderModel(config)
        # The optimizer and LR schedule are created alongside the model so the
        # train_model/fit methods below can be called directly.
        self.optimizer = AdamW(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=config.learning_rate,
        )
        self.scheduler = StepLR(self.optimizer, step_size=30, gamma=0.7)

    def forward(self, x, edge_index, edge_attr):
        z, attention_weights = self.encoder(x, edge_index, edge_attr)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, attention_weights

    def predict_edge_weights(self, z, edge_index):
        return self.decoder.predict_edge_weights(z, edge_index)
    def train_model(self, data_loader, device):
        """Run one training epoch; returns (avg_loss, avg_cosine_similarity)."""
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.train()
        self.decoder.train()
        total_loss = 0
        total_cosine_similarity = 0
        # Relative weights of the three reconstruction objectives.
        loss_weight_node = 1.0
        loss_weight_edge = 1.0
        loss_weight_edge_attr = 1.0
        for data in data_loader:
            data = data.to(device)
            self.optimizer.zero_grad()
            z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
            x_reconstructed = self.decoder(z)
            node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
            edge_loss = edge_reconstruction_loss(z, data.edge_index)
            cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
            total_cosine_similarity += cos_sim.item()
            pred_edge_weights = self.decoder.predict_edge_weights(z, data.edge_index)
            edge_weight_loss = edge_weight_reconstruction_loss(pred_edge_weights, data.edge_attr)
            loss = (loss_weight_node * node_loss) + (loss_weight_edge * edge_loss) + (loss_weight_edge_attr * edge_weight_loss)
            print(f"node_loss: {node_loss:.4f}, edge_loss: {edge_loss:.4f}, edge_weight_loss: {edge_weight_loss:.4f}, cosine_similarity: {cos_sim:.4f}")
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(data_loader)
        avg_cosine_similarity = total_cosine_similarity / len(data_loader)
        return avg_loss, avg_cosine_similarity
    def fit(self, train_loader, validation_loader, epochs, device):
        """Train for the given number of epochs, validating after each one."""
        train_losses = []
        val_losses = []
        for epoch in range(1, epochs + 1):
            train_loss, train_cosine_similarity = self.train_model(train_loader, device)
            torch.cuda.empty_cache()
            val_loss, val_cosine_similarity = self.validate(validation_loader, device)
            # Record per-epoch losses so the returned histories are populated.
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            print(f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Cosine Similarity: {train_cosine_similarity:.4f}, Validation Loss: {val_loss:.4f}, Validation Cosine Similarity: {val_cosine_similarity:.4f}")
            self.scheduler.step()
        return train_losses, val_losses
    def validate(self, validation_loader, device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.eval()
        self.decoder.eval()
        total_loss = 0
        total_cosine_similarity = 0
        with torch.no_grad():
            for data in validation_loader:
                data = data.to(device)
                z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
                x_reconstructed = self.decoder(z)
                node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
                edge_loss = edge_reconstruction_loss(z, data.edge_index)
                cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
                total_cosine_similarity += cos_sim.item()
                loss = node_loss + edge_loss
                total_loss += loss.item()
        avg_loss = total_loss / len(validation_loader)
        avg_cosine_similarity = total_cosine_similarity / len(validation_loader)
        return avg_loss, avg_cosine_similarity
    def evaluate(self, test_loader, device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.eval()
        self.decoder.eval()
        total_loss = 0
        total_cosine_similarity = 0
        with torch.no_grad():
            for data in test_loader:
                data = data.to(device)
                z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
                x_reconstructed = self.decoder(z)
                node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
                edge_loss = edge_reconstruction_loss(z, data.edge_index)
                cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
                total_cosine_similarity += cos_sim.item()
                loss = node_loss + edge_loss
                total_loss += loss.item()
        avg_loss = total_loss / len(test_loader)
        avg_cosine_similarity = total_cosine_similarity / len(test_loader)
        return avg_loss, avg_cosine_similarity
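
# A minimal sketch (not part of the original file) of how the trained encoder could
# be used on its own to pull latent node embeddings out of a set of graphs, e.g. as
# risk features for downstream modelling. The function name `extract_latent_features`
# and the list-of-tensors return type are illustrative assumptions.
def extract_latent_features(model, data_loader, device):
    model.encoder.to(device)
    model.encoder.eval()
    embeddings = []
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            # The encoder returns (latent features, attention weights); only the
            # latent features are kept here.
            z, _ = model.encoder(data.x, data.edge_index, data.edge_attr)
            embeddings.append(z.cpu())
    return embeddings
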
# Define a collate function for the DataLoader
def collate_graph_data(batch):
    return Batch.from_data_list(batch)

# Define a function to create a DataLoader
def create_data_loader(train_data, batch_size=1, shuffle=True):
    graph_data = list(train_data.values())
    return DataLoader(graph_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_graph_data)
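
# Example usage (a sketch; the sample IDs and Data fields are assumptions):
#   train_data = {"sample_01": Data(x=..., edge_index=..., edge_attr=...), ...}
#   loader = create_data_loader(train_data, batch_size=4, shuffle=True)
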
# Define functions for the losses
def graph_reconstruction_loss(pred_features, true_features):
    return F.mse_loss(pred_features, true_features)

def edge_reconstruction_loss(z, pos_edge_index, neg_edge_index=None):
    # Score node pairs by the dot product of their embeddings; positive edges
    # should score high, negatively sampled non-edges should score low.
    pos_logits = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)
    pos_loss = F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits))
    if neg_edge_index is None:
        neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
    neg_logits = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)
    neg_loss = F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits))
    return pos_loss + neg_loss

def edge_weight_reconstruction_loss(pred_weights, true_weights):
    pred_weights = pred_weights.squeeze(-1)
    return F.mse_loss(pred_weights, true_weights)
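
# A hedged end-to-end sketch of how this module's pieces fit together. The random
# graph shapes and the OmicsConfig constructor arguments are assumptions for
# illustration only, not the repository's actual configuration.
if __name__ == "__main__":
    from torch_geometric.data import Data

    # Two tiny random graphs standing in for per-patient multi-omics networks.
    graphs = {}
    for sample_id in ("sample_01", "sample_02"):
        x = torch.randn(10, 16)                     # 10 nodes, 16 features each
        edge_index = torch.randint(0, 10, (2, 30))  # 30 random directed edges
        edge_attr = torch.rand(30)                  # one scalar weight per edge
        graphs[sample_id] = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    loader = create_data_loader(graphs, batch_size=1)
    config = OmicsConfig(learning_rate=1e-3)  # assumed constructor signature
    model = MultiOmicsGraphAttentionAutoencoderModel(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Trains and validates on the same toy loader purely for demonstration.
    train_losses, val_losses = model.fit(loader, loader, epochs=2, device=device)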