"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import re

from bubogpt.common.registry import registry
from bubogpt.processors.base_processor import BaseProcessor
from bubogpt.processors.vision_augment import RandomAugment
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class ImageBindVisionBaseProcessor(BaseProcessor):
    """Common base of the vision processors; it only builds ``self.normalize``."""

    def __init__(self, mean=None, std=None):
        """Create the normalization transform.

        Args:
            mean: Mean passed to ``transforms.Normalize``. The three-value
                default below is used when this is ``None``.
            std: Standard deviation passed to ``transforms.Normalize``. The
                three-value default below is used when this is ``None``.
        """
        super().__init__()

        # Only ``None`` triggers the fallback; any other value is used as given.
        norm_mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        norm_std = (0.26862954, 0.26130258, 0.27577711) if std is None else std

        self.normalize = transforms.Normalize(norm_mean, norm_std)


# Note: The config of caption processor is different from the ones in BLIP2 / MiniGPT4
@registry.register_processor("imagebind_caption")
class ImageBindCaptionProcessor(BaseProcessor):
    """Text processor that cleans a caption and prepends a fixed prompt."""

    def __init__(self, prompt="", max_words=50):
        """Store the prompt prefix and the word limit.

        Args:
            prompt: String prepended to every cleaned caption.
            max_words: Stored on the instance. The truncation that used it is
                commented out in ``pre_caption``, so nothing visible here
                reads it.
        """
        # Note: Actually no prompts are used here.
        super().__init__()
        self.max_words = max_words
        self.prompt = prompt

    def __call__(self, caption):
        """Return ``self.prompt`` followed by the cleaned ``caption``."""
        return self.prompt + self.pre_caption(caption)

    @classmethod
    def from_config(cls, cfg=None):
        """Build the processor from a config; an empty one is used for ``None``.

        Reads the keys ``"prompt"`` (default ``""``) and ``"max_words"``
        (default ``150``).
        """
        config = OmegaConf.create() if cfg is None else cfg
        # NOTE(review): the fallback here is 150 while ``__init__`` defaults to
        # 50. It has no visible effect because truncation is disabled.
        return cls(
            prompt=config.get("prompt", ""),
            max_words=config.get("max_words", 150),
        )

    def pre_caption(self, caption):
        """Clean up a raw caption string.

        Replaces each newline and each of the characters ``" ( ) * # ~`` with
        a space, collapses every run of two or more whitespace characters into
        one space, then strips trailing newlines and surrounding spaces.
        """
        # Step 1: blank out the unwanted characters.
        cleaned = re.sub(r"([\n\"()*#~])", " ", caption)
        # Step 2: squeeze repeated whitespace down to a single space.
        cleaned = re.sub(r"\s{2,}", " ", cleaned)
        # Step 3: trim the ends.
        cleaned = cleaned.rstrip("\n").strip(" ")

        # # truncate caption Note: Deprecated.
        # caption_words = caption.split(" ")
        # if len(caption_words) > self.max_words:
        #     caption = " ".join(caption_words[: self.max_words])

        return cleaned


# Note: The training config of vision processor keeps the same as BLIP2 / MiniGPT4
@registry.register_processor("imagebind_vision_train")
class ImageBindVisionTrainProcessor(ImageBindVisionBaseProcessor):
    """Training-time processor: random resized crop, tensor conversion, normalize."""

    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        """Assemble the training transform pipeline.

        Args:
            image_size: Size passed to ``transforms.RandomResizedCrop``.
            mean: Forwarded to the base class for normalization.
            std: Forwarded to the base class for normalization.
            min_scale: Lower bound of the crop ``scale`` range.
            max_scale: Upper bound of the crop ``scale`` range.
        """
        super().__init__(mean=mean, std=std)

        random_crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, max_scale),
            interpolation=InterpolationMode.BICUBIC,
        )
        pipeline = [random_crop, transforms.ToTensor(), self.normalize]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Apply the transform pipeline to ``item`` (presumably an image — confirm with callers)."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build the processor from a config; an empty one is used for ``None``.

        Reads ``"image_size"`` (224), ``"mean"`` (None), ``"std"`` (None),
        ``"min_scale"`` (0.5) and ``"max_scale"`` (1.0).
        """
        config = OmegaConf.create() if cfg is None else cfg
        return cls(
            image_size=config.get("image_size", 224),
            mean=config.get("mean", None),
            std=config.get("std", None),
            min_scale=config.get("min_scale", 0.5),
            max_scale=config.get("max_scale", 1.0),
        )


# Changed.
@registry.register_processor("imagebind_vision_eval")
class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
    """Evaluation-time processor: resize, center crop, tensor conversion, normalize."""

    def __init__(self, image_size=224, mean=None, std=None):
        """Assemble the evaluation transform pipeline.

        Args:
            image_size: Size passed to both ``transforms.Resize`` and
                ``transforms.CenterCrop``.
            mean: Forwarded to the base class for normalization.
            std: Forwarded to the base class for normalization.
        """
        super().__init__(mean=mean, std=std)

        resize = transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC)
        center_crop = transforms.CenterCrop(image_size)
        pipeline = [resize, center_crop, transforms.ToTensor(), self.normalize]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Apply the transform pipeline to ``item`` (presumably an image — confirm with callers)."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build the processor from a config; an empty one is used for ``None``.

        Reads ``"image_size"`` (224), ``"mean"`` (None) and ``"std"`` (None).
        """
        config = OmegaConf.create() if cfg is None else cfg
        return cls(
            image_size=config.get("image_size", 224),
            mean=config.get("mean", None),
            std=config.get("std", None),
        )