Flash attention with 14k tokens produces off-topic results

#62
by cclevenger - opened

CUDA version: 12.4
CUDA driver: G05

Model setup:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# 4-bit quantization config (use_4bit, bnb_4bit_quant_type,
# bnb_4bit_compute_dtype and use_nested_quant are defined elsewhere)
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Load the model with FlashAttention-2 as the attention implementation
model = AutoModelForCausalLM.from_pretrained(
    self.model_id,
    device_map='auto',
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    trust_remote_code=True,
    attn_implementation='flash_attention_2',
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

self.pipeline = pipeline(
    'text-generation',
    model=model,
    tokenizer=self.tokenizer,
)
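
For illustration, the pipeline is then invoked roughly like this (the prompt text and generation parameters below are placeholders, not the exact call):

# Hypothetical invocation for illustration only; the real prompt is ~14k
# tokens long and the generation parameters may differ.
long_prompt = '...'  # placeholder for the ~14k-token prompt
outputs = self.pipeline(
    long_prompt,
    max_new_tokens=512,
    do_sample=False,
    return_full_text=False,
)
print(outputs[0]['generated_text'])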

I am passing in a prompt that consumes ~14k tokens, and the generated results are completely off topic. When I remove attn_implementation='flash_attention_2' and lower the total to ~8k tokens, I get results closer to what I would expect; at least they are on topic.
Likewise, ~8k total tokens works fine with attn_implementation='flash_attention_2' enabled; at least, I get an on-topic response.
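
In case it is relevant, here is a rough check I can run to compare the tokenized prompt length against the model's configured context window (a sketch; max_position_embeddings may be named differently for some architectures):

# Sanity check (sketch): is the prompt longer than the model's context window?
prompt_tokens = len(self.tokenizer(long_prompt)['input_ids'])
max_ctx = getattr(model.config, 'max_position_embeddings', None)
print(f'prompt tokens: {prompt_tokens}, max positions: {max_ctx}')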

Could someone shed some light on this?
