|
import torch |
|
from transformers import CLIPVisionModel |
|
|
|
class GenomicPLIPModel(torch.nn.Module): |
|
def __init__(self, original_model): |
|
super(GenomicPLIPModel, self).__init__() |
|
self.vision_model = original_model.vision_model |
|
self.vision_projection = torch.nn.Linear(768, 512) |
|
self.fc_layer = torch.nn.Linear(4, 512) |
|
|
|
def forward(self, pixel_values, score_vector): |
|
vision_output = self.vision_model(pixel_values) |
|
pooled_output = vision_output.pooler_output |
|
vision_features = self.vision_projection(pooled_output) |
|
score_features = self.fc_layer(score_vector) |
|
|
|
return vision_features, score_features |
|
|