import torch
from transformers import CLIPVisionModel


class GenomicPLIPModel(torch.nn.Module):
    """Wraps a CLIP vision encoder and projects both image features and a
    4-dimensional genomic score vector into a shared 512-d embedding space."""

    def __init__(self, original_model: CLIPVisionModel):
        super().__init__()
        self.vision_model = original_model.vision_model
        # Project the 768-d pooled vision output down to 512 dimensions.
        self.vision_projection = torch.nn.Linear(768, 512)
        # Fully connected layer that embeds the 4-d genomic score vector.
        self.fc_layer = torch.nn.Linear(4, 512)

    def forward(self, pixel_values, score_vector):
        vision_output = self.vision_model(pixel_values)
        pooled_output = vision_output.pooler_output  # [batch, 768] pooled CLS embedding
        vision_features = self.vision_projection(pooled_output)
        score_features = self.fc_layer(score_vector)

        return vision_features, score_features
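

# Usage sketch: a minimal, self-contained example of instantiating the model
# from a pretrained CLIP vision backbone and running one forward pass. The
# checkpoint name "vinid/plip" (the public PLIP weights) is an assumption;
# any CLIPVisionModel checkpoint with a 768-d pooled output would work.
if __name__ == "__main__":
    base_model = CLIPVisionModel.from_pretrained("vinid/plip")
    model = GenomicPLIPModel(base_model)

    pixel_values = torch.randn(2, 3, 224, 224)  # batch of 2 preprocessed image tiles
    score_vector = torch.randn(2, 4)            # one 4-d genomic score vector per tile
    vision_features, score_features = model(pixel_values, score_vector)
    print(vision_features.shape, score_features.shape)  # both torch.Size([2, 512])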