Wizardry in how generate() uses key-value cache?
Hi, I'm trying to replicate the behavior of the generate() method with a custom loop (for greedy search).
For your demo prompt:
prompt = [{'role': 'user', 'content': 'Which famous math number begins with 1.6 ...?'}]
inputs = tokenizer.apply_chat_template(
    prompt,
    add_generation_prompt=True,
    return_tensors='pt',
    max_length=input_token_size,
    truncation=True
)
I then create left-side padding and an attention mask that goes along with the padding like so:
padding = torch.full((1, input_token_size - inputs.shape[1]), tokenizer.pad_token_id, dtype=torch.int64)
padded_input = torch.cat((padding, inputs), dim=1)
original_mask = torch.ones(inputs.shape, dtype=torch.int64)
padding_mask = torch.zeros(padding.shape, dtype=torch.int64)
attention_mask = torch.cat((padding_mask, original_mask), dim=1)
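With left padding, one detail worth keeping in view is which position each real token ends up with. The pure-Python sketch below is only an illustration: the helper name position_ids_from_mask is hypothetical, and the cumulative-sum rule reflects how transformers appears to prepare inputs inside generate() for many decoder-only models, which is an assumption here and not confirmed for this model. It contrasts mask-derived positions with a plain 0..N-1 count:

```python
from itertools import accumulate

def position_ids_from_mask(attention_mask):
    """Derive per-token positions from a 0/1 attention mask (one sequence)."""
    running = accumulate(attention_mask)  # running count of real tokens
    # Real tokens get (count - 1); padded slots get a dummy value of 1
    return [count - 1 if keep else 1 for count, keep in zip(running, attention_mask)]

mask = [0, 0, 0, 1, 1, 1, 1]  # three pad slots, then four real tokens
print(position_ids_from_mask(mask))  # [1, 1, 1, 0, 1, 2, 3]
print(list(range(len(mask))))        # [0, 1, 2, 3, 4, 5, 6]
```

Under the mask-derived rule the first real token sits at position 0 regardless of how much padding precedes it; under the plain count it sits at position 3.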
Using a Google Colab high-RAM CPU runtime, I first tried the vanilla HF generate() method like so:
tokens = model.generate(
    inputs,
    max_new_tokens=35,
    attention_mask=attention_mask
)
print(tokenizer.decode(tokens[0], skip_special_tokens=False))
Which yielded:
TIME: about 9 seconds to generate 35 tokens
OUTPUT: "The number you are referring to is 1.6091. This is a well-known value in the field of mathematics, particularly in the area of approximation. It"
Then, I tried a custom loop without using the key-value cache like so:
generated_tokens = []
next_token_id = padded_input  # full sequence so far (grows by one token per step)
with torch.no_grad():
    for _ in range(35):
        # No cache: the model re-processes the entire sequence on every step
        next_logits, _ = model(next_token_id, attention_mask=attention_mask)
        next_logits = next_logits[:, -1:]  # logits at the last position only
        next_logit = torch.argmax(next_logits, dim=-1)  # greedy pick (a token id, despite the name)
        next_token_id = torch.cat([next_token_id, next_logit], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=torch.int64)], dim=1)
        generated_tokens.append(next_logit.item())
print(tokenizer.decode(generated_tokens, skip_special_tokens=False))
TIME: 51 seconds (!!!)
OUTPUT: "The number you are referring to is 1.6091. This is a well-known value in the field of mathematics, particularly in the area of approximation. It"
As expected, the generated text is identical, but without the cache it takes far too long.
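The slowdown comes from re-processing the whole sequence on every step. As a rough sketch of the bookkeeping only, the snippet below counts how many token positions are fed through the model in each mode. The prompt length of 64 is a made-up value (the actual input_token_size isn't given here), positions_processed is a hypothetical helper, and the count ignores that attention cost also grows with context length, so it shows the direction of the gap and not the measured ratio:

```python
def positions_processed(prompt_len, new_tokens, use_cache):
    """Count token positions fed through the model over a whole generation."""
    total = 0
    for step in range(new_tokens):
        if use_cache and step > 0:
            total += 1  # only the newest token is fed; the rest comes from the cache
        else:
            total += prompt_len + step  # the whole sequence so far is re-processed
    return total

prompt_len, new_tokens = 64, 35  # prompt_len is a made-up value for illustration
print(positions_processed(prompt_len, new_tokens, use_cache=False))  # 2835
print(positions_processed(prompt_len, new_tokens, use_cache=True))   # 98
```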
Now, here's my implementation that uses the cache:
past_key_values = None  # past_key_values is the key-value cache
generated_tokens = []
next_token_id = padded_input
with torch.no_grad():
    for cycle in range(35):
        # First step feeds the whole padded prompt; later steps feed only the newest token.
        # Note that no position_ids are passed here.
        next_logits, past_key_values = model(
            next_token_id,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )  # use_cache=True was left commented out
        next_logits = next_logits[:, -1:]
        next_token_id = torch.argmax(next_logits, dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=torch.int64)], dim=1)
        generated_tokens.append(next_token_id.item())
print(tokenizer.decode(generated_tokens, skip_special_tokens=False))
TIME: about 9 seconds
OUTPUT: "The number you are referring to is 1.6, and it is a common notation used in the field of mathematics to represent a decimal number. In mathematics. It's"
As you can see, this implementation matches the generation time of the vanilla HF generate() method exactly, which is awesome. BUT the generated text is both DIFFERENT from the vanilla generate() output and wrong (the model doesn't realize we are talking about the golden ratio, phi).
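For context, a key-value cache on its own should not change greedy output: if the cached keys and values equal what a full forward pass would compute, the last position attends over the same context. The toy below illustrates this in pure Python. It is not the model's actual attention; embed and the two last_output_* helpers are hypothetical, and keys, values, and queries are all just the toy embedding. The second half shows that giving the newest token a different position is one kind of mismatch that changes the result; whether that is what happens in the cached loop above is not established.

```python
import math

def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    norm = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / norm
            for i in range(len(values[0]))]

def embed(token, position):
    """Toy embedding mixing token identity with position (hypothetical)."""
    return [math.sin(token + 0.5 * position), math.cos(token - 0.3 * position)]

def last_output_full(tokens, positions):
    """No cache: rebuild every key/value, then attend from the last token."""
    vectors = [embed(t, p) for t, p in zip(tokens, positions)]
    return attend(vectors[-1], vectors, vectors)

def last_output_cached(cache, token, position):
    """Cache: append one new key/value and attend over the stored ones."""
    vector = embed(token, position)
    cache.append(vector)
    return attend(vector, cache, cache)

tokens = [5, 2, 7, 1]
full = last_output_full(tokens, [0, 1, 2, 3])

# Incremental decoding with consistent positions reproduces the full pass
cache = []
for pos, tok in enumerate(tokens):
    cached = last_output_cached(cache, tok, pos)
print(all(math.isclose(a, b) for a, b in zip(cached, full)))  # True

# Same cached context, but the newest token gets a different position
cache = [embed(t, p) for t, p in zip(tokens[:-1], [0, 1, 2])]
shifted = last_output_cached(cache, tokens[-1], 9)
print(all(math.isclose(a, b) for a, b in zip(shifted, full)))  # False
```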
Clearly, the generate() method uses a key-value cache, but in a different way.
Can you please clarify how the generate() method implements it?