"""Wraps `big_vision` PaliGemma model for easy use in demo.""" | |
from collections.abc import Callable | |
import dataclasses | |
from typing import Any | |
import jax | |
import jax.numpy as jnp | |
import ml_collections | |
import numpy as np | |
import PIL.Image | |
from big_vision import sharding | |
from big_vision import utils | |
from big_vision.models.proj.paligemma import paligemma | |
from big_vision.pp import builder as pp_builder | |
from big_vision.pp import ops_general # pylint: disable=unused-import | |
from big_vision.pp import ops_image # pylint: disable=unused-import | |
from big_vision.pp import ops_text # pylint: disable=unused-import | |
from big_vision.pp import tokenizer | |
from big_vision.pp.proj.paligemma import ops as ops_paligemma # pylint: disable=unused-import | |
from big_vision.trainers.proj.paligemma import predict_fns | |
mesh = jax.sharding.Mesh(jax.devices(), 'data') | |


def _recover_bf16(x):
  """Reinterprets raw 2-byte void arrays as bfloat16 (see comment in `_load`)."""
  if x.dtype == np.dtype('V2'):
    x = x.view('bfloat16')
  return x


def _load(
    path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152
):
  """Loads model, params, decode functions and tokenizer."""
  tok = tokenizer.get_tokenizer(tokenizer_spec)
  config = ml_collections.FrozenConfigDict(dict(
      llm_model='proj.paligemma.gemma_bv',
      llm=dict(vocab_size=vocab_size, variant='gemma_2b'),
      img=dict(variant='So400m/14', pool_type='none', scan=True),
  ))
  model = paligemma.Model(**config)
  decode = predict_fns.get_all(model)['decode']
  beam_decode = predict_fns.get_all(model)['beam_decode']
  params_cpu = paligemma.load(None, path, config)
  # Some numpy versions don't load bfloat16 correctly:
  params_cpu = jax.tree.map(_recover_bf16, params_cpu)
  return model, params_cpu, decode, beam_decode, tok


def _shard_params(params_cpu):
  """Shards `params_cpu` with fsdp strategy on all available devices."""
  params_sharding = sharding.infer_sharding(
      params_cpu, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh
  )
  params = jax.tree.map(utils.reshard, params_cpu, params_sharding)
  return params


def _pil2np(img):
  """Accepts `PIL.Image` or `np.ndarray` and returns a 3-channel `np.ndarray`."""
  if isinstance(img, PIL.Image.Image):
    img = np.array(img)
  if img.ndim == 2:
    img = img[..., None]  # Grayscale without a channel axis: add one first.
  img = img[..., :3]  # Drop an alpha channel, if present.
  if img.shape[-1] == 1:
    img = np.repeat(img, 3, axis=-1)  # Single channel -> RGB.
  return img
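
# For example (illustrative shapes): an RGBA array of shape (224, 224, 4) comes
# back as (224, 224, 3), and a grayscale array of shape (224, 224) is expanded
# and repeated to (224, 224, 3).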


def _prepare_batch(
    images,
    prefixes,
    *,
    res=224,
    tokenizer_spec='gemma(tokensets=("loc", "seg"))',
    suffixes=None,
    text_len=64,
):
  """Returns non-sharded batch."""
  pp_fn = pp_builder.get_preprocess_fn('|'.join([
      f'resize({res}, antialias=True)|value_range(-1, 1)',
      f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')",
      f"tok(key='septok', text='\\n', model='{tokenizer_spec}')",
      f"tok(key='suffix', model='{tokenizer_spec}')",
      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])',  # pylint: disable=line-too-long
      f'tolen({text_len}, pad_value=0, key="text")',
      f'tolen({text_len}, pad_value=1, key="mask_ar")',
      f'tolen({text_len}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)
  assert not isinstance(prefixes, str), f'expected batch: {prefixes}'
  assert (
      isinstance(images, (list, tuple)) or images.ndim == 4
  ), f'expected batch: {images.shape}'
  if suffixes is None:
    suffixes = [''] * len(prefixes)
  assert len(prefixes) == len(suffixes) == len(images)
  examples = [{'_mask': True, **pp_fn({
      'image': np.asarray(_pil2np(image)),
      'prefix': np.array(prefix),
      'suffix': np.array(suffix),
  })} for image, prefix, suffix in zip(images, prefixes, suffixes)]
  batch = jax.tree.map(lambda *xs: np.stack(xs), *examples)
  return batch
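
# Example call (illustrative; `img` is assumed to be a `PIL.Image` or HWC array,
# and the prompt string follows the usual PaliGemma task format):
#
#   batch = _prepare_batch([img], ['caption en'], res=224, text_len=64)
#   # batch['image'] has shape (1, 224, 224, 3), batch['text'] has shape (1, 64).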


def _shard_batch(batch, n=None):
  """Pads and shards `batch` along the 'data' axis on all available devices."""
  if n is None:
    n = jax.local_device_count()

  def pad(x):
    # Pad the leading (batch) dimension up to the next multiple of `n`.
    return jnp.pad(x, [(0, -len(x) % n)] + [(0, 0)] * (x.ndim - 1))

  batch = {k: pad(v) for k, v in batch.items()}
  data_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('data')
  )
  batch_on_device = utils.reshard(batch, data_sharding)
  return batch_on_device
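
# For example (illustrative): with 4 local devices, a batch of 3 examples is
# padded to 4 rows; the padded rows end up with `_mask == False`, so downstream
# code can drop them from the outputs.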


@dataclasses.dataclass(frozen=True)
class PaligemmaConfig:
  """Describes a `big_vision` PaliGemma model."""

  ckpt: str
  res: int
  text_len: int
  tokenizer: str
  vocab_size: int


@dataclasses.dataclass(frozen=True)
class PaliGemmaModel:
  """Wraps a `big_vision` PaliGemma model."""

  config: PaligemmaConfig
  tokenizer: tokenizer.Tokenizer
  decode: Callable[..., Any]
  beam_decode: Callable[..., Any]

  @classmethod
  def shard_batch(cls, batch):
    return _shard_batch(batch)

  @classmethod
  def shard_params(cls, params_cpu):
    return _shard_params(params_cpu)

  def prepare_batch(self, images, texts, suffixes=None):
    return _prepare_batch(
        images=images,
        prefixes=texts,
        suffixes=suffixes,
        res=self.config.res,
        tokenizer_spec=self.config.tokenizer,
        text_len=self.config.text_len,
    )

  def predict(
      self,
      params,
      batch,
      devices=None,
      max_decode_len=128,
      sampler='greedy',
      **kw,
  ):
    """Returns tokens."""
    if devices is None:
      devices = jax.devices()
    if sampler == 'beam':
      decode = self.beam_decode
    else:
      decode = self.decode
    kw['sampler'] = sampler
    return decode(
        {'params': params},
        batch=batch,
        devices=devices,
        eos_token=self.tokenizer.eos_token,
        max_decode_len=max_decode_len,
        **kw,
    )


# Type alias for the (possibly nested) CPU parameter tree returned by `_load`.
ParamsCpu = Any


def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]:
  """Loads model from config."""
  model, params_cpu, decode, beam_decode, tok = _load(
      path=config.ckpt,
      tokenizer_spec=config.tokenizer,
      vocab_size=config.vocab_size,
  )
  del model
  return PaliGemmaModel(
      config=config, tokenizer=tok, decode=decode, beam_decode=beam_decode,
  ), params_cpu
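
# End-to-end usage sketch (illustrative; the checkpoint path, image, and prompt
# below are placeholders, and decoding tokens back to text assumes the
# tokenizer exposes `to_str`):
#
#   config = PaligemmaConfig(
#       ckpt='<path/to/paligemma/checkpoint.npz>',
#       res=224,
#       text_len=64,
#       tokenizer='gemma(tokensets=("loc", "seg"))',
#       vocab_size=257_152,
#   )
#   model, params_cpu = load_model(config)
#   params = model.shard_params(params_cpu)
#   batch = model.prepare_batch([image], ['caption en'])
#   batch = model.shard_batch(batch)
#   tokens = model.predict(params, batch, max_decode_len=128, sampler='greedy')
#   texts = [model.tokenizer.to_str(t) for t in tokens]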