import torch
import torch.nn as nn
from transformers import CLIPImageProcessor
# ImageBind is an optional dependency; it is only required once the tower is loaded.
try:
    from imagebind.models import imagebind_model
    from imagebind.models.imagebind_model import ModalityType
    from imagebind.data import load_and_transform_audio_data
except ImportError:
    pass
class ImageBindWrapper(nn.Module):
    """Frozen ImageBind tower used as a vision/audio feature extractor."""

    def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = select_layer
        self.select_feature = select_feature

        if not delay_load:
            self.load_model()

    def load_model(self):
        # The CLIP processor is reused only for image preprocessing; the
        # embeddings themselves come from ImageBind.
        self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)

        # Freeze the tower: it is used purely as a feature extractor.
        for p in self.vision_tower.parameters():
            p.requires_grad = False
        self.vision_tower.eval()

        self.is_loaded = True
    def train(self, mode=True):
        # Override nn.Module.train so the frozen tower always stays in eval mode.
        self.training = mode
        if self.is_loaded:
            self.vision_tower.eval()
    @torch.no_grad()
    def forward(self, x):
        # Dict inputs carry audio; plain tensors are treated as vision input.
        if isinstance(x, dict):
            if x["audios"] is not None:
                inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()}
                embeddings = self.vision_tower(inputs)
                audio_embedding = embeddings[ModalityType.AUDIO]
                # Add a sequence dimension: (B, 1024) -> (B, 1, 1024).
                return audio_embedding.unsqueeze(1)
        else:
            inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
            embeddings = self.vision_tower(inputs)
            vision_embedding = embeddings[ModalityType.VISION]
            if vision_embedding.ndim == 2:
                # Pooled embedding: add a sequence dimension.
                return vision_embedding.unsqueeze(1)
            if vision_embedding.shape[1] == 257:
                # Patch-level features: drop the leading CLS token.
                return vision_embedding[:, 1:]
            raise ValueError(f"Unexpected shape: {vision_embedding.shape}")
    @property
    def dummy_feature(self):
        return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Infer dtype/device from a parameter of the wrapped tower.
        return self.vision_tower.modality_preprocessors.vision.cls_token.dtype

    @property
    def device(self):
        return self.vision_tower.modality_preprocessors.vision.cls_token.device

    @property
    def hidden_size(self):
        return 1024
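

# Minimal usage sketch (not part of the original file): assumes the ImageBind
# and CLIP checkpoints can be downloaded, and that images are already
# preprocessed to (B, 3, 224, 224) tensors. The constructor arguments below
# are illustrative only.
if __name__ == "__main__":
    tower = ImageBindWrapper(vision_tower="imagebind_huge", select_layer=-2)
    images = torch.randn(2, 3, 224, 224).to(device=tower.device, dtype=tower.dtype)
    feats = tower(images)  # expected shape: (2, 1, 1024) pooled vision embedding
    print(feats.shape)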