"""Model-related code and constants.""" | |
import spaces | |
import dataclasses | |
import os | |
import re | |
import PIL.Image | |
# pylint: disable=g-bad-import-order | |
import gradio_helpers | |
import llama_cpp | |

ORGANIZATION = 'abetlen'

BASE_MODELS = [
    ('paligemma-3b-mix-224-gguf', 'paligemma-3b-mix-224'),
]

MODELS = {
    model_name: (
        f'{ORGANIZATION}/{repo}',
        (
            f'{model_name}-text-model-q4_k_m.gguf',
            f'{model_name}-mmproj-f16.gguf',
        ),
    )
    for repo, model_name in BASE_MODELS
}
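# For reference, with the single base model above this expands to:
# MODELS == {
#     'paligemma-3b-mix-224': (
#         'abetlen/paligemma-3b-mix-224-gguf',
#         ('paligemma-3b-mix-224-text-model-q4_k_m.gguf',
#          'paligemma-3b-mix-224-mmproj-f16.gguf'),
#     ),
# }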

MODELS_INFO = {
    'paligemma-3b-mix-224': (
        'GGUF PaliGemma 3B weights quantized in Q4_K_M format, fine-tuned with 224x224 input '
        'images and 256-token input/output text sequences on a mixture of downstream academic '
        'datasets. The original (non-quantized) checkpoints are available in float32, bfloat16 '
        'and float16 formats for research purposes only.'
    ),
}

MODELS_RES_SEQ = {
    'paligemma-3b-mix-224': (224, 256),
}
# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM. | |
# Below value should be smaller than "available RAM - one model". | |
# A single bf16 is about 5860 MB. | |
MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9) | |
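# Example: RAM_CACHE_GB=12 gives MAX_RAM_CACHE == 12_000_000_000 bytes (12 GB),
# i.e. room for roughly two bf16 checkpoints of ~5860 MB each in the RAM cache.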

# Legacy config for the paligemma_bv (JAX) code path, kept commented out for reference:
# config = paligemma_bv.PaligemmaConfig(
#     ckpt='',  # will be set below
#     res=224,
#     text_len=64,
#     tokenizer='gemma(tokensets=("loc", "seg"))',
#     vocab_size=256_000 + 1024 + 128,
# )


def get_cached_model(
    model_name: str,
):  # -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
  """Returns model and params, using RAM cache.

  Note: this belongs to the legacy paligemma_bv (JAX) code path. It relies on
  the commented-out `config` above and the `paligemma_bv` module, and is not
  used by the llama.cpp path in `generate()` below.
  """
  res, seq = MODELS_RES_SEQ[model_name]
  model_path = gradio_helpers.get_paths()[model_name]
  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
  model, params_cpu = gradio_helpers.get_memory_cache(
      config_,
      lambda: paligemma_bv.load_model(config_),
      max_cache_size_bytes=MAX_RAM_CACHE,
  )
  return model, params_cpu


def pil_image_to_base64(image: PIL.Image.Image) -> str:
  """Converts a PIL image to a base64-encoded JPEG string."""
  buffered = io.BytesIO()
  # JPEG does not support an alpha channel, so convert to RGB first.
  image.convert('RGB').save(buffered, format='JPEG')
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
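# Usage sketch (illustrative; 'cat.jpg' is a placeholder path): build a data URI
# suitable for an OpenAI-style "image_url" chat message, as done in generate() below.
#
#   data_uri = 'data:image/jpeg;base64,' + pil_image_to_base64(PIL.Image.open('cat.jpg'))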


def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
  """Generates output with the specified `model_name` and `sampler`.

  Note: `sampler` is only used by the legacy paligemma_bv code path below; the
  llama.cpp path currently ignores it.
  """
  # Legacy paligemma_bv (JAX) code path, kept commented out for reference:
  # model, params_cpu = get_cached_model(model_name)
  # batch = model.shard_batch(model.prepare_batch([image], [prompt]))
  # with gradio_helpers.timed('sharding'):
  #   params = model.shard_params(params_cpu)
  # with gradio_helpers.timed('computation', start_message=True):
  #   tokens = model.predict(params, batch, sampler=sampler)

  model_path, clip_path = gradio_helpers.get_paths()[model_name]
  print(model_path)
  print(gradio_helpers.get_paths())
  model = llama_cpp.Llama(
      model_path,
      # The mmproj file provides the vision encoder / multimodal projector
      # used by the chat handler for the image part of the prompt.
      chat_handler=llama_cpp.llama_chat_format.PaliGemmaChatHandler(clip_path),
      n_ctx=1024,  # Context window (image tokens + prompt + output).
      n_ubatch=512,
      n_batch=512,
      n_gpu_layers=-1,  # Offload all layers to the GPU if one is available.
  )
  return model.create_chat_completion(messages=[{
      'role': 'user',
      'content': [
          {'type': 'text', 'text': prompt},
          {
              'type': 'image_url',
              'image_url': 'data:image/jpeg;base64,' + pil_image_to_base64(image),
          },
      ],
  }])['choices'][0]['message']['content']