Bug in 4-bit quantization?
Hi,
I am following the tutorial from the model card, with minor modifications such as device_map='auto' and os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6,7'. I am running it inside a Jupyter notebook. It works well with 8-bit quantization, but with 4-bit quantization the model answers only "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!". The problem disappears when I replace load_in_4bit with load_in_8bit (but then the model is too heavy to run even on 7 Titan RTX 24 GB GPUs with somewhat large images). Here is the code and output:
import os
os.environ['TRANSFORMERS_CACHE'] = './HFCache'
os.environ['HF_HOME'] = './HFCache'
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6,7'
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
model_id = "llava-hf/llava-onevision-qwen2-72b-ov-hf"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    load_in_4bit=True,
    device_map='auto',
)
processor = AutoProcessor.from_pretrained(model_id)
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(torch.float16)#.to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))
Output:
What are these?assistant
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Output when using load_in_8bit (correct):
What are these?assistant
These are two cats lying on a pink blanket.
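For what it's worth, passing load_in_4bit=True directly to from_pretrained is deprecated in recent transformers versions in favor of an explicit BitsAndBytesConfig, which also exposes the 4-bit compute dtype. A minimal sketch of that variant (the nf4 quant type and fp16 compute dtype are illustrative assumptions, not settings from the original report):

```python
import torch
from transformers import BitsAndBytesConfig, LlavaOnevisionForConditionalGeneration

# Explicit quantization config instead of the bare load_in_4bit=True kwarg.
# The nf4 quant type and fp16 compute dtype below are illustrative assumptions.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    "llava-hf/llava-onevision-qwen2-72b-ov-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
    device_map="auto",
)
```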
Environment details
Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
accelerate 0.34.2 pypi_0 pypi
asttokens 2.0.5 pyhd3eb1b0_0
bitsandbytes 0.43.3 pypi_0 pypi
bzip2 1.0.8 h5eee18b_6
ca-certificates 2024.7.2 h06a4308_0
certifi 2024.8.30 pypi_0 pypi
charset-normalizer 3.3.2 pypi_0 pypi
comm 0.2.1 py311h06a4308_0
debugpy 1.6.7 py311h6a678d5_0
decorator 5.1.1 pyhd3eb1b0_0
executing 0.8.3 pyhd3eb1b0_0
filelock 3.13.1 pypi_0 pypi
fsspec 2024.2.0 pypi_0 pypi
huggingface-hub 0.24.6 pypi_0 pypi
idna 3.8 pypi_0 pypi
ipykernel 6.28.0 py311h06a4308_0
ipython 8.25.0 py311h06a4308_0
jedi 0.19.1 py311h06a4308_0
jinja2 3.1.3 pypi_0 pypi
jupyter_client 8.6.0 py311h06a4308_0
jupyter_core 5.7.2 py311h06a4308_0
ld_impl_linux-64 2.38 h1181459_1
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libsodium 1.0.18 h7b6447c_0
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
markupsafe 2.1.5 pypi_0 pypi
matplotlib-inline 0.1.6 py311h06a4308_0
mpmath 1.3.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nest-asyncio 1.6.0 py311h06a4308_0
networkx 3.2.1 pypi_0 pypi
numpy 1.26.3 pypi_0 pypi
nvidia-cublas-cu11 11.11.3.6 pypi_0 pypi
nvidia-cuda-cupti-cu11 11.8.87 pypi_0 pypi
nvidia-cuda-nvrtc-cu11 11.8.89 pypi_0 pypi
nvidia-cuda-runtime-cu11 11.8.89 pypi_0 pypi
nvidia-cudnn-cu11 9.1.0.70 pypi_0 pypi
nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi
nvidia-curand-cu11 10.3.0.86 pypi_0 pypi
nvidia-cusolver-cu11 11.4.1.48 pypi_0 pypi
nvidia-cusparse-cu11 11.7.5.86 pypi_0 pypi
nvidia-nccl-cu11 2.20.5 pypi_0 pypi
nvidia-nvtx-cu11 11.8.86 pypi_0 pypi
openssl 3.0.15 h5eee18b_0
packaging 24.1 py311h06a4308_0
parso 0.8.3 pyhd3eb1b0_0
pexpect 4.8.0 pyhd3eb1b0_3
pillow 10.2.0 pypi_0 pypi
pip 24.2 py311h06a4308_0
platformdirs 3.10.0 py311h06a4308_0
prompt-toolkit 3.0.43 py311h06a4308_0
prompt_toolkit 3.0.43 hd3eb1b0_0
psutil 5.9.0 py311h5eee18b_0
ptyprocess 0.7.0 pyhd3eb1b0_2
pure_eval 0.2.2 pyhd3eb1b0_0
pygments 2.15.1 py311h06a4308_1
python 3.11.9 h955ad1f_0
python-dateutil 2.9.0post0 py311h06a4308_2
pyyaml 6.0.2 pypi_0 pypi
pyzmq 25.1.2 py311h6a678d5_0
readline 8.2 h5eee18b_0
regex 2024.7.24 pypi_0 pypi
requests 2.32.3 pypi_0 pypi
safetensors 0.4.5 pypi_0 pypi
setuptools 72.1.0 py311h06a4308_0
six 1.16.0 pyhd3eb1b0_1
sqlite 3.45.3 h5eee18b_0
stack_data 0.2.0 pyhd3eb1b0_0
sympy 1.12 pypi_0 pypi
tk 8.6.14 h39e8969_0
tokenizers 0.19.1 pypi_0 pypi
torch 2.4.1+cu118 pypi_0 pypi
torchaudio 2.4.1+cu118 pypi_0 pypi
torchvision 0.19.1+cu118 pypi_0 pypi
tornado 6.4.1 py311h5eee18b_0
tqdm 4.66.5 pypi_0 pypi
traitlets 5.14.3 py311h06a4308_0
transformers 4.45.0.dev0 pypi_0 pypi
triton 3.0.0 pypi_0 pypi
typing_extensions 4.11.0 py311h06a4308_0
tzdata 2024a h04d1e81_0
urllib3 2.2.2 pypi_0 pypi
wcwidth 0.2.5 pyhd3eb1b0_0
wheel 0.43.0 py311h06a4308_0
xz 5.4.6 h5eee18b_1
zeromq 4.3.5 h6a678d5_0
zlib 1.2.13 h5eee18b_1
@prasb
Hmm, I tried the same code with 4-bit and got "These are two cats lying on a pink blanket." as the reply. This could also be hardware-related, since I have the same package versions as you.