# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/multimodal.py
from abc import ABC, abstractmethod
from typing import Dict, Optional, Type

import clip
import open_clip
from transformers import (
    AlignModel,
    AlignProcessor,
    BlipForImageTextRetrieval,
    BlipProcessor,
    FlavaModel,
    FlavaProcessor,
)

from app_lib.config import Config
from app_lib.config import Constants as c
class VisionLanguageModel(ABC):
    """Common interface implemented by every vision-language backbone below."""

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        pass

    @abstractmethod
    def encode_text(self, text):
        pass

    @abstractmethod
    def encode_image(self, image):
        pass
models: Dict[str, Type[VisionLanguageModel]] = {}


def register_model(name: str):
    """Class decorator that adds a model class to the registry under `name`."""

    def register(cls: Type[VisionLanguageModel]):
        if name in models:
            raise ValueError(f"Model {name} is already registered")
        models[name] = cls
        return cls

    return register
def get_model_name_and_backbone(config: Config):
    """Split `config.data.backbone` ("<model>" or "<model>:<backbone>") into a
    (model_name, backbone) pair, with backbone set to None when omitted."""
    backbone = config.data.backbone.split(":")
    if len(backbone) == 1:
        backbone.append(None)
    return backbone
def get_model(config: Config, device=c.DEVICE) -> VisionLanguageModel:
    model_name, backbone = get_model_name_and_backbone(config)
    return models[model_name](backbone, device=device)


def get_text_encoder(config: Config, device=c.DEVICE):
    model = get_model(config, device=device)
    return model.encode_text


def get_image_encoder(config: Config, device=c.DEVICE):
    model = get_model(config, device=device)
    return model.encode_image
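
# Usage sketch (illustrative, not part of the original module): assuming a
# Config whose `data.backbone` follows the "<model>:<backbone>" convention,
# e.g. "clip:ViT-B/32", the factories above return bound encoder methods:
#
#     encode_text = get_text_encoder(config)
#     encode_image = get_image_encoder(config)
#     text_features = encode_text(["a photo of a dog"])  # torch.Tensor
#     image_features = encode_image(pil_image)           # torch.Tensor
#
# `pil_image` is assumed to be a PIL.Image; both calls return torch tensors
# on the configured device.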

# NOTE: the registry keys used below are assumed to match the "<model>" prefix
# in `config.data.backbone` (e.g. "clip:ViT-B/32", "open_clip:ViT-B-32").
@register_model("clip")
class CLIPModel(VisionLanguageModel):
    def __init__(self, backbone: str, device=c.DEVICE):
        self.model, self.preprocess = clip.load(backbone, device=device)
        self.tokenize = clip.tokenize
        self.device = device

    def encode_text(self, text):
        text = self.tokenize(text).to(self.device)
        return self.model.encode_text(text)

    def encode_image(self, image):
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.model.encode_image(image)

@register_model("open_clip")
class OpenClipModel(VisionLanguageModel):
    # OpenCLIP checkpoints: each architecture is paired with its LAION-2B
    # pretrained tag.
    OPENCLIP_WEIGHTS = {
        "ViT-B-32": "laion2b_s34b_b79k",
        "ViT-L-14": "laion2b_s32b_b82k",
    }

    def __init__(self, backbone: str, device=c.DEVICE):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            backbone, pretrained=self.OPENCLIP_WEIGHTS[backbone], device=device
        )
        self.tokenize = open_clip.get_tokenizer(backbone)
        self.device = device

    def encode_text(self, text):
        text = self.tokenize(text).to(self.device)
        return self.model.encode_text(text)

    def encode_image(self, image):
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.model.encode_image(image)

@register_model("flava")
class FLAVAModel(VisionLanguageModel):
    HF_MODEL = "facebook/flava-full"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = FlavaModel.from_pretrained(backbone).to(device)
        self.processor = FlavaProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        # FLAVA returns per-token features; keep the [CLS] token embedding.
        return self.model.get_text_features(**text_inputs)[:, 0, :]

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        return self.model.get_image_features(**image_inputs)[:, 0, :]

@register_model("align")
class ALIGNModel(VisionLanguageModel):
    HF_MODEL = "kakaobrain/align-base"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = AlignModel.from_pretrained(backbone).to(device)
        self.processor = AlignProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        return self.model.get_text_features(**text_inputs)

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        return self.model.get_image_features(**image_inputs)

@register_model("blip")
class BLIPModel(VisionLanguageModel):
    HF_MODEL = "Salesforce/blip-itm-base-coco"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = BlipForImageTextRetrieval.from_pretrained(backbone).to(device)
        self.processor = BlipProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        # Project the text encoder's [CLS] token into the retrieval embedding space.
        text_embeds = self.model.text_encoder(**text_inputs)[0]
        return self.model.text_proj(text_embeds[:, 0, :])

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        image_embeds = self.model.vision_model(**image_inputs)[0]
        return self.model.vision_proj(image_embeds[:, 0, :])
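

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not from the original source):
    # instantiate the CLIP backbone directly, bypassing Config. "ViT-B/32" is
    # an assumed example of a valid `clip.load` backbone name.
    import torch

    model = CLIPModel("ViT-B/32")
    with torch.no_grad():
        text_features = model.encode_text(["a photo of a dog", "a photo of a cat"])
    print(text_features.shape)  # expected: torch.Size([2, 512]) for ViT-B/32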