Getting Error in Forward Pass - model.generate

#14
by neelabhsinha - opened

Hello,

I am following the guide provided with the model card, but I am getting the following error:

Traceback (most recent call last):
  File "/path/to/project/main.py", line 84, in <module>
    execution_flow()
  File "/path/to/project/main.py", line 77, in execution_flow
    execute_vlm(model_name, args.batch_size, args.do_sample, args.top_k, args.top_p)
  File "/path/to/project/src/utils/execute.py", line 37, in execute_vlm
    results_df = evaluation_loop(dataloader, model, model_name)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/project/src/utils/execute.py", line 49, in evaluation_loop
    response = model(questions, images)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/project/src/model/vlm.py", line 232, in __call__
    outputs = self.model.generate(**inputs, **gen_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/transformers/generation/utils.py", line 1894, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/transformers/generation/utils.py", line 2631, in _sample
    outputs = self(
              ^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/cache/huggingface/modules/transformers_modules/THUDM/cogvlm2-llama3-chat-19B/2bf7de6892877eb50142395af14847519ba95998/modeling_cogvlm.py", line 649, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/conda_env/lib/python3.12/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/cache/huggingface/modules/transformers_modules/THUDM/cogvlm2-llama3-chat-19B/2bf7de6892877eb50142395af14847519ba95998/modeling_cogvlm.py", line 403, in forward
    return self.llm_forward(
           ^^^^^^^^^^^^^^^^^
  File "/path/to/cache/huggingface/modules/transformers_modules/THUDM/cogvlm2-llama3-chat-19B/2bf7de6892877eb50142395af14847519ba95998/modeling_cogvlm.py", line 452, in llm_forward
    past_key_values_length = past_key_values[0][0].shape[2]
                             ^^^^^^^^^^^^^^^^^^^^^^^^^

AttributeError: 'str' object has no attribute 'shape'
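For context, the failing line indexes past_key_values[0][0].shape[2], which only works if past_key_values is the legacy tuple-of-tensors cache. A minimal sketch of what that line seems to assume (the shapes below are placeholders I made up for illustration, not CogVLM2's real dimensions):

import torch

# Legacy cache format: one (key, value) tensor pair per layer.
batch, num_heads, past_len, head_dim = 1, 8, 10, 128
legacy_cache = tuple(
    (torch.zeros(batch, num_heads, past_len, head_dim),  # key
     torch.zeros(batch, num_heads, past_len, head_dim))  # value
    for _ in range(2)  # pretend there are 2 layers
)
print(legacy_cache[0][0].shape[2])  # -> 10, the cached sequence length

If generate() passes anything else here (apparently a str in my run), past_key_values[0][0] has no .shape attribute, which matches the error above.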

Code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# cache_dir and TORCH_DTYPE are defined elsewhere in the project
class CogVLM2:
    def __init__(self, model_name, do_sample, top_k, top_p, checkpoint):
        self.model_name = checkpoint if checkpoint is not None else f'THUDM/{model_name}'
        self.image_size = 800
        self.nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            cache_dir=cache_dir,
            torch_dtype=TORCH_DTYPE,
            trust_remote_code=True,
            quantization_config=self.nf4_config,  # use the NF4 config defined above
            low_cpu_mem_usage=True
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=cache_dir, trust_remote_code=True)
        self.device = next(self.model.parameters()).device
        self.prompt_prefix = 'Only answer below the question. Do not provide any additional information.\n'
        self.gen_kwargs = {
            "max_new_tokens": 2048,
            "pad_token_id": 128002,
        }
        print_model_info(self.model, self.model_name)
        
    def __call__(self, questions, images):
        query = questions[0]
        image = images[0]
        history = []
        input_by_model = self.model.build_conversation_input_ids(
            self.tokenizer,
            query=query,
            history=history,
            images=[image],
            template_version='chat'
        )
        inputs = {
            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(self.device),
            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(self.device),
            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(self.device),
            'images': [[input_by_model['images'][0].to(self.device).to(TORCH_DTYPE)]] if image is not None else None,
        }
        gen_kwargs = {**self.gen_kwargs, "top_k": 1}  # reuse the defaults from __init__, add top_k
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print("\nCogVLM2:", response)
            return response  # the evaluation loop expects the response back

I am not sure what is wrong. I know the tokenizer decode step needs additional handling, but the code is failing at model.generate itself. Why?
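In case it helps, here is how I am capturing the package versions involved (a standard importlib.metadata check, nothing model-specific):

from importlib.metadata import PackageNotFoundError, version

# Print the versions most relevant to this traceback.
for pkg in ("transformers", "torch", "accelerate", "bitsandbytes"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")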

Knowledge Engineering Group (KEG) & Data Mining at Tsinghua University org

Use it with transformers == 4.40.
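If downgrading resolves it, a runtime guard like the following can catch version drift early (a minimal sketch; I am assuming any 4.40.x patch release works, based on the reply above):

from packaging.version import Version
import transformers

# Fail fast if the installed transformers is outside the known-good 4.40 series.
# Assumption: any 4.40.x patch works, per the maintainer's reply above.
v = Version(transformers.__version__)
if not (Version("4.40.0") <= v < Version("4.41.0")):
    raise RuntimeError(
        f"transformers {v} detected; this CogVLM2 code path is reported to work "
        "with transformers==4.40.*"
    )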
