omics-plip-1 / scripts /genomic_plip_model.py
VatsalPatel18's picture
Upload 8 files
8381e8e verified
raw
history blame
724 Bytes
import torch
from transformers import CLIPVisionModel
class GenomicPLIPModel(torch.nn.Module):
    """Vision encoder plus two linear heads producing same-width embeddings.

    The module reuses ``original_model.vision_model`` as its image encoder,
    projects the encoder's ``pooler_output`` with ``vision_projection``, and
    maps a score vector with ``fc_layer``. Both linear layers emit
    ``projection_dim`` features.

    Attribute names (``vision_model``, ``vision_projection``, ``fc_layer``)
    are kept stable so existing ``state_dict`` checkpoints keep loading.
    """

    def __init__(
        self,
        original_model,
        *,
        vision_hidden_size=None,
        projection_dim=512,
        score_dim=4,
    ):
        """Build the projection layers around an existing vision encoder.

        Args:
            original_model: Object exposing a ``vision_model`` attribute. That
                attribute is called with the pixel tensor in ``forward`` and
                must return an object with a ``pooler_output`` attribute.
            vision_hidden_size: Input width of ``vision_projection``. When
                ``None``, it is read from ``vision_model.config.hidden_size``
                if that attribute exists, otherwise it falls back to 768
                (the previously hard-coded value).
            projection_dim: Output width of both linear layers (default 512).
            score_dim: Length of the score vector fed to ``fc_layer``
                (default 4).
        """
        super().__init__()
        self.vision_model = original_model.vision_model

        if vision_hidden_size is None:
            # NOTE(review): assumes config.hidden_size equals the width of
            # pooler_output, as in Hugging Face CLIP vision encoders — confirm
            # for other backbones or pass vision_hidden_size explicitly.
            encoder_config = getattr(self.vision_model, "config", None)
            vision_hidden_size = getattr(encoder_config, "hidden_size", 768)

        # Creation order matches the original so seeded initialisation of the
        # two layers is unchanged for the default sizes.
        self.vision_projection = torch.nn.Linear(vision_hidden_size, projection_dim)
        self.fc_layer = torch.nn.Linear(score_dim, projection_dim)

    def forward(self, pixel_values, score_vector):
        """Embed an image batch and a score-vector batch.

        Args:
            pixel_values: Input passed unchanged to ``self.vision_model``.
            score_vector: Tensor whose last dimension matches ``fc_layer``'s
                input width.

        Returns:
            A ``(vision_features, score_features)`` tuple: the projected
            ``pooler_output`` of the vision encoder and the output of
            ``fc_layer`` applied to ``score_vector``.
        """
        vision_output = self.vision_model(pixel_values)
        pooled_output = vision_output.pooler_output
        vision_features = self.vision_projection(pooled_output)
        score_features = self.fc_layer(score_vector)
        return vision_features, score_features