# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/multimodal.py
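"""Wrappers around several vision-language models (CLIP, OpenCLIP, FLAVA,
ALIGN, BLIP) behind a common `encode_text` / `encode_image` interface, plus a
small registry keyed by the model names used in `Config.data.backbone`."""
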
from abc import ABC, abstractmethod
from typing import Dict, Optional, Type

import clip
import open_clip
from transformers import (
    AlignModel,
    AlignProcessor,
    BlipForImageTextRetrieval,
    BlipProcessor,
    FlavaModel,
    FlavaProcessor,
)

from app_lib.config import Config
from app_lib.config import Constants as c


class VisionLanguageModel(ABC):
    """Base class for vision-language models with paired text/image encoders."""

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        pass

    @abstractmethod
    def encode_text(self, text):
        """Embed a string (or list of strings) into the joint embedding space."""

    @abstractmethod
    def encode_image(self, image):
        """Embed an image into the joint embedding space."""


models: Dict[str, Type[VisionLanguageModel]] = {}
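# Wrapper classes register themselves via `@register_model(...)` below, so
# e.g. `models["clip"]` resolves to `CLIPModel` once this module is imported.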


def register_model(name):
    """Class decorator that registers a VisionLanguageModel subclass under `name`."""

    def register(cls: Type[VisionLanguageModel]):
        if name in models:
            raise ValueError(f"Model {name} is already registered")
        models[name] = cls
        # Return the class so the decorated name is not rebound to None.
        return cls

    return register


def get_model_name_and_backbone(config: Config):
    """Split `config.data.backbone` of the form "name" or "name:backbone"."""
    # Split on the first ":" only, so backbone names may themselves contain ":".
    backbone = config.data.backbone.split(":", 1)
    if len(backbone) == 1:
        backbone.append(None)
    return backbone


def get_model(config: Config, device=c.DEVICE) -> VisionLanguageModel:
    """Instantiate the model named in `config.data.backbone` from the registry."""
    model_name, backbone = get_model_name_and_backbone(config)
    return models[model_name](backbone, device=device)


def get_text_encoder(config: Config, device=c.DEVICE):
    model = get_model(config, device=device)
    return model.encode_text


def get_image_encoder(config: Config, device=c.DEVICE):
    model = get_model(config, device=device)
    return model.encode_image
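

# Example usage (hypothetical `config`; assumes `Config.data.backbone` is set
# to, e.g., "clip:ViT-B/32"):
#
#   encode_text = get_text_encoder(config)
#   encode_image = get_image_encoder(config)
#   text_embedding = encode_text(["a photo of a dog"])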


@register_model(name="clip")
class CLIPModel(VisionLanguageModel):
    def __init__(self, backbone: str, device=c.DEVICE):
        self.model, self.preprocess = clip.load(backbone, device=device)
        self.tokenize = clip.tokenize
        self.device = device

    def encode_text(self, text):
        text = self.tokenize(text).to(self.device)
        return self.model.encode_text(text)

    def encode_image(self, image):
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.model.encode_image(image)


@register_model(name="open_clip")
class OpenClipModel(VisionLanguageModel):
    OPENCLIP_WEIGHTS = {
        "ViT-B-32": "laion2b_s34b_b79k",
        "ViT-L-14": "laion2b_s32b_b82k",
    }

    def __init__(self, backbone: str, device=c.DEVICE):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            backbone, pretrained=self.OPENCLIP_WEIGHTS[backbone], device=device
        )
        self.tokenize = open_clip.get_tokenizer(backbone)
        self.device = device

    def encode_text(self, text):
        text = self.tokenize(text).to(self.device)
        return self.model.encode_text(text)

    def encode_image(self, image):
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.model.encode_image(image)
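

# Note: only the backbones listed in OPENCLIP_WEIGHTS are supported by
# OpenClipModel; any other name raises a KeyError before
# `open_clip.create_model_and_transforms` is called.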


@register_model(name="flava")
class FLAVAModel(VisionLanguageModel):
    HF_MODEL = "facebook/flava-full"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = FlavaModel.from_pretrained(backbone).to(device)
        self.processor = FlavaProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        # FLAVA returns per-token features; keep the [CLS] token as the embedding.
        return self.model.get_text_features(**text_inputs)[:, 0, :]

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        # Likewise, keep the [CLS] patch token as the image embedding.
        return self.model.get_image_features(**image_inputs)[:, 0, :]


@register_model(name="align")
class ALIGNModel(VisionLanguageModel):
    HF_MODEL = "kakaobrain/align-base"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = AlignModel.from_pretrained(backbone).to(device)
        self.processor = AlignProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        return self.model.get_text_features(**text_inputs)

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        return self.model.get_image_features(**image_inputs)


@register_model(name="blip")
class BLIPModel(VisionLanguageModel):
    HF_MODEL = "Salesforce/blip-itm-base-coco"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = BlipForImageTextRetrieval.from_pretrained(backbone).to(device)
        self.processor = BlipProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        # Run the text encoder and project its [CLS] token into the shared
        # retrieval embedding space.
        question_embeds = self.model.text_encoder(**text_inputs)[0]
        return self.model.text_proj(question_embeds[:, 0, :])

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        # Run the vision tower and project its [CLS] token into the shared
        # retrieval embedding space.
        image_embeds = self.model.vision_model(**image_inputs)[0]
        return self.model.vision_proj(image_embeds[:, 0, :])
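

# Sketch of a typical downstream use (assumes torch is available; the encoders
# above all return torch tensors in a joint text/image embedding space):
#
#   import torch
#   import torch.nn.functional as F
#
#   model = get_model(config)
#   with torch.no_grad():
#       z_text = model.encode_text(["a photo of a dog"])
#       z_image = model.encode_image(image)
#   sim = F.cosine_similarity(z_text, z_image)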