File size: 1,247 Bytes
c238491
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers import PreTrainedModel
from OmicsConfig import OmicsConfig
from transformers import PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch_geometric.utils import negative_sampling
from torch.nn.functional import cosine_similarity
from torch.optim.lr_scheduler import StepLR


class GATv2EncoderModel(PreTrainedModel):
    """Stack of ``GATv2Conv`` layers exposed as a Hugging Face ``PreTrainedModel``.

    All hyper-parameters come from the config object:
    ``in_channels``, ``out_channels``, ``edge_attr_channels`` and ``num_layers``.
    """

    config_class = OmicsConfig
    base_model_prefix = "gatv2_encoder"

    def __init__(self, config):
        super().__init__(config)
        convs = []
        for depth in range(config.num_layers):
            # Only the first layer sees the raw input width; every later layer
            # consumes the previous layer's out_channels-wide output.
            source_dim = config.out_channels if depth else config.in_channels
            convs.append(
                GATv2Conv(
                    source_dim,
                    config.out_channels,
                    heads=1,
                    concat=True,
                    edge_dim=config.edge_attr_channels,
                    add_self_loops=False,
                )
            )
        # NOTE(review): transformers models usually end __init__ with
        # self.post_init(); it is not called here — confirm this is intended.
        self.layers = nn.ModuleList(convs)

    def forward(self, x, edge_index, edge_attr):
        """Apply every GATv2 layer in sequence and gather attention weights.

        Args:
            x: Node features fed to the first layer.
            edge_index: Graph connectivity, passed unchanged to every layer.
            edge_attr: Edge features, passed unchanged to every layer.

        Returns:
            A pair ``(out, attention)``. ``out`` is the last layer's output, or
            ``x`` itself when there are no layers. ``attention`` is a list with
            one entry per layer, holding the second value ``GATv2Conv`` returns
            when called with ``return_attention_weights=True``. Per the PyG
            docs this is presumably an ``(edge_index, alpha)`` pair; confirm.
        """
        # NOTE(review): no activation is applied between layers.
        collected = []
        hidden = x
        for conv in self.layers:
            hidden, alpha = conv(hidden, edge_index, edge_attr, return_attention_weights=True)
            collected.append(alpha)
        return hidden, collected