# Device-independent generation algorithms for LLMs.
import logging
import time
import torch
import torch.nn.functional as F
import transformers
logger = logging.getLogger(__name__)
# Helper to pull out the response tokens.
def isolate_responses_BL(output_BL, prompt_len, eos_token_id):
responses_BL = []
resp_BL = output_BL[:, prompt_len:]
for i in range(resp_BL.shape[0]):
resp_L = resp_BL[i]
resplen = resp_L.shape[0]
for j in range(resplen):
if resp_L[j] == eos_token_id:
resplen = j+1
break
response_L = resp_L[:resplen].cpu().detach().numpy()
responses_BL.append(response_L)
return responses_BL
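# Editorial sketch (not part of the original file): a toy call showing what
# isolate_responses_BL returns. With eos_token_id=2 and prompt_len=2, each row
# keeps the tokens after the prompt up to and including the first EOS.
def _demo_isolate_responses():
    output_BL = torch.tensor([
        [5, 6, 7, 8, 2, 0, 0],  # prompt 5,6; response 7,8,EOS; then padding
        [5, 6, 9, 9, 9, 9, 9],  # no EOS: the whole tail is kept
    ])
    return isolate_responses_BL(output_BL, prompt_len=2, eos_token_id=2)
    # -> [array([7, 8, 2]), array([9, 9, 9, 9, 9])]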
def tokenize_prompt(device, tokenizer, chat, quiet=False):
# Tokenize the prompt.
prompt_BL = tokenizer.apply_chat_template(
[chat],
        tokenize=True, add_generation_prompt=True, return_tensors='pt'
).to(device)
if not quiet:
print('PROMPT:')
print(tokenizer.decode(prompt_BL[0]))
return prompt_BL
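# Usage sketch (editorial example; the message content is made up): a chat is a
# list of {'role', 'content'} dicts, as expected by apply_chat_template.
def _example_tokenize_prompt(device, tokenizer):
    chat = [
        {'role': 'system', 'content': 'You are a concise assistant.'},
        {'role': 'user', 'content': 'Hello!'},
    ]
    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=True)
    return prompt_BL.shape  # (1, prompt_len)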
def generate(device, model, tokenizer, chat):
"""Generate a response using huggingface's generation."""
prompt_BL = tokenize_prompt(device, tokenizer, chat)
prompt_len = prompt_BL.shape[1]
# Generate response.
    # Unfortunately, Hugging Face's generation code uses 'cumsum', which doesn't
    # have a deterministic implementation, so deterministic mode is disabled
    # around the call and restored immediately afterwards.
torch.use_deterministic_algorithms(False)
generation_output = model.generate(
inputs=prompt_BL, max_new_tokens=512, do_sample=True,
return_dict_in_generate=True,
)
torch.use_deterministic_algorithms(True)
output_BL = generation_output.sequences
for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
        print('RESPONSE:')
print(tokenizer.decode(response_L))
return tokenizer.decode(response_L)
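# Editorial sketch of how generate() might be wired up end to end. The
# checkpoint name is only an illustrative assumption; any chat-tuned causal LM
# with a chat template should work the same way.
def _example_generate():
    from transformers import AutoModelForCausalLM, AutoTokenizer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    name = 'Qwen/Qwen2.5-0.5B-Instruct'  # assumed example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name).to(device)
    chat = [{'role': 'user', 'content': 'Write one sentence about the sea.'}]
    return generate(device, model, tokenizer, chat)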
def generate_with_logits(device, model, tokenizer, chat, seed=None):
    """Hugging Face generation that also returns the per-step logits."""
if seed is not None:
transformers.set_seed(seed)
prompt_BL = tokenize_prompt(device, tokenizer, chat)
prompt_len = prompt_BL.shape[1]
    # Unfortunately, Hugging Face's generation code uses 'cumsum', which doesn't
    # have a deterministic implementation, so deterministic mode is disabled
    # around the call and restored immediately afterwards.
torch.use_deterministic_algorithms(False)
generation_output = model.generate(
inputs=prompt_BL, max_new_tokens=16, do_sample=True,
return_dict_in_generate=True, output_logits=True,
)
torch.use_deterministic_algorithms(True)
output_BL = generation_output.sequences
    logits_BLV = torch.stack(generation_output.logits, dim=1)
for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
        print('RESPONSE:')
print(tokenizer.decode(response_L))
return response_L, logits_BLV[0]
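# Editorial sketch (hypothetical helper, not in the original): the stacked
# logits returned by generate_with_logits() line up one-to-one with the sampled
# tokens, so per-token log-probabilities can be read off directly.
def _response_lnprobs_from_logits(response_L, logits_LV):
    lnp_LV = F.log_softmax(logits_LV.float(), dim=-1)
    return [lnp_LV[i, int(t)].item() for i, t in enumerate(response_L)]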
def response_logits(device, model, tokenizer, chat, response_L):
    """Recompute the logits for an existing response in a single forward pass."""
prompt_BL = tokenize_prompt(device, tokenizer, chat)
    # Feed all but the last response token: the logits at each input position
    # predict the following token, so the final response token never needs to
    # be part of the input. Move the response onto the model's device.
    response_pt_L = torch.from_numpy(response_L[:-1]).to(device)
    # Concatenate prompt and response along the sequence dimension.
    input_ids = torch.cat((prompt_BL, response_pt_L[None]), dim=1)
outputs = model(input_ids)
logits_BLV = outputs.logits
return logits_BLV[0][-len(response_L):]
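# Editorial sketch (hypothetical helper): total log-probability of an existing
# response under the model, built on response_logits() above.
def _score_response(device, model, tokenizer, chat, response_L):
    logits_LV = response_logits(device, model, tokenizer, chat, response_L)
    lnp_LV = F.log_softmax(logits_LV.float(), dim=-1)
    idx_L = torch.arange(len(response_L), device=lnp_LV.device)
    tok_L = torch.as_tensor(response_L, device=lnp_LV.device, dtype=torch.long)
    return lnp_LV[idx_L, tok_L].sum().item()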
def generate_custom(device, model, tokenizer, chat, max_tokens=512, seed=None, return_tokens=False, quiet=False, return_lnprobs=False):
"""Generate a response using custom generation."""
if seed is not None:
torch.manual_seed(seed)
prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=quiet)
# Generate response.
input_ids = prompt_BL
past_key_values = None
response_L = []
lnprobs_L = []
while 1:
outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
logits = outputs.logits[0, -1, :]
p_V = F.softmax(logits, dim=-1)
token = torch.multinomial(p_V, num_samples=1).item()
response_L.append(token)
lnprobs_L.append(F.log_softmax(logits, dim=-1)[token].item())
p_token = p_V[token].item()
if not quiet:
print(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
if token == tokenizer.eos_token_id or len(response_L) >= max_tokens:
break
input_ids = torch.tensor([[token]], device=device)
if not quiet:
        print('RESPONSE:')
print(tokenizer.decode(response_L))
if return_tokens:
if return_lnprobs:
return response_L, lnprobs_L
else:
return response_L
else:
assert not return_lnprobs
return tokenizer.decode(response_L)
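# Usage sketch (editorial): a seeded call so repeated runs reproduce the same
# sample; returns the decoded text and its total log-probability.
def _example_generate_custom(device, model, tokenizer):
    chat = [{'role': 'user', 'content': 'Name three colors.'}]
    tokens_L, lnprobs_L = generate_custom(
        device, model, tokenizer, chat,
        max_tokens=64, seed=0, return_tokens=True, quiet=True, return_lnprobs=True,
    )
    return tokenizer.decode(tokens_L), sum(lnprobs_L)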
def format_token(tokenizer, token_id):
return repr(tokenizer.decode(token_id))
def apoc(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None, return_tokens=False, quiet=False):
"""Generate a response using APOC unconditional sampling."""
if seed is not None:
torch.manual_seed(seed)
    # This early implementation of the algorithm is not numerically robust, so
    # mitigate the problem by doing the log-probability arithmetic in float64.
logit_dtype = torch.float64
prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=quiet)
prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=quiet)
    # Four input_ids variables are needed: on the first iteration the model input
    # depends on the prompt (X vs. Y), while on later iterations it depends on
    # the response branch (a vs. b).
input_ids_xa = prompt_x_BL
input_ids_ya = prompt_y_BL
input_ids_xb = prompt_x_BL
input_ids_yb = prompt_y_BL
past_key_values_xa = None
past_key_values_ya = None
past_key_values_xb = None
past_key_values_yb = None
equal = True
a_eos = False
b_eos = False
def zero():
return torch.zeros(1, dtype=logit_dtype, device=device)
ln_pya_m_ln_pxa = zero()
ln_pxb_m_ln_pyb = zero()
response_a_L = []
response_b_L = []
i = 0
while 1:
if i >= max_tokens or (a_eos and b_eos): break
if not quiet:
print(f'Generating response token {i}')
i += 1
forward_passes_start = time.perf_counter()
if not a_eos:
outputs = model_x(input_ids_xa, past_key_values=past_key_values_xa, use_cache=True)
past_key_values_xa = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(logit_dtype)
ln_pxa_V = F.log_softmax(logits, dim=-1)
outputs = model_y(input_ids_ya, past_key_values=past_key_values_ya, use_cache=True)
past_key_values_ya = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(logit_dtype)
ln_pya_V = F.log_softmax(logits, dim=-1)
if not b_eos:
if equal:
# In equal mode, neither input_ids nor past_key_values depends on a vs b,
# so we can reuse the forward pass results for a 50% time savings.
assert not a_eos
past_key_values_xb = past_key_values_xa
ln_pxb_V = ln_pxa_V
past_key_values_yb = past_key_values_ya
ln_pyb_V = ln_pya_V
else:
outputs = model_x(input_ids_xb, past_key_values=past_key_values_xb, use_cache=True)
past_key_values_xb = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(logit_dtype)
ln_pxb_V = F.log_softmax(logits, dim=-1)
outputs = model_y(input_ids_yb, past_key_values=past_key_values_yb, use_cache=True)
past_key_values_yb = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(logit_dtype)
ln_pyb_V = F.log_softmax(logits, dim=-1)
forward_passes_end = time.perf_counter()
if not quiet:
print(f' Forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')
if equal:
ln_pmeet_V = torch.minimum(
ln_pxa_V + torch.maximum(zero(), -ln_pya_m_ln_pxa),
ln_pya_V + torch.maximum(zero(), ln_pya_m_ln_pxa),
)
pmeet_V = torch.exp(ln_pmeet_V)
pmeet = torch.sum(pmeet_V)
if not quiet:
print(f' Equal mode (pmeet={pmeet*100.0:.3f}%)')
if torch.rand_like(pmeet) < pmeet:
token_a = token_b = torch.multinomial(pmeet_V, num_samples=1).item()
p_token_a = (pmeet_V[token_a] / pmeet).item()
if not quiet:
print(f' Sampled {format_token(tokenizer, token_a)} ({p_token_a*100.0:.1f}%)')
else:
if not quiet:
print(' Exited equal mode')
equal = False
if not equal:
if not a_eos:
wxt_V = torch.maximum(zero(), torch.exp(ln_pxa_V) - torch.exp(ln_pya_V + ln_pya_m_ln_pxa))
token_a = torch.multinomial(wxt_V, num_samples=1).item()
p_token_a = (wxt_V[token_a] / torch.sum(wxt_V)).item()
if not quiet:
print(f' Sampled token_a {format_token(tokenizer, token_a)} ({p_token_a*100.0:.3f}%)')
if not b_eos:
wyt_V = torch.maximum(zero(), torch.exp(ln_pyb_V) - torch.exp(ln_pxb_V + ln_pxb_m_ln_pyb))
token_b = torch.multinomial(wyt_V, num_samples=1).item()
p_token_b = (wyt_V[token_b] / torch.sum(wyt_V)).item()
if not quiet:
print(f' Sampled token_b {format_token(tokenizer, token_b)} ({p_token_b*100.0:.3f}%)')
if not a_eos:
response_a_L.append(token_a)
input_ids_xa = input_ids_ya = torch.tensor([[token_a]], device=device)
ln_pya_m_ln_pxa += ln_pya_V[token_a] - ln_pxa_V[token_a]
if token_a == tokenizer.eos_token_id:
a_eos = True
if not b_eos:
response_b_L.append(token_b)
input_ids_xb = input_ids_yb = torch.tensor([[token_b]], device=device)
ln_pxb_m_ln_pyb += ln_pxb_V[token_b] - ln_pyb_V[token_b]
if token_b == tokenizer.eos_token_id:
b_eos = True
if not quiet:
        print('RESPONSE X:')
print(tokenizer.decode(response_a_L))
        print('RESPONSE Y:')
print(tokenizer.decode(response_b_L))
if return_tokens:
return response_a_L, response_b_L
else:
return tokenizer.decode(response_a_L), tokenizer.decode(response_b_L)
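# Usage sketch (editorial, not part of the original module): apoc() draws a
# coupled pair of responses, one for (model_x, chat_x) and one for
# (model_y, chat_y); the pair stays token-for-token identical while 'equal mode'
# holds and diverges afterwards. Here the same model is used with two made-up
# prompts, but two different models work the same way.
def _example_apoc(device, model, tokenizer):
    chat_x = [{'role': 'user', 'content': 'Describe the moon in one sentence.'}]
    chat_y = [{'role': 'user', 'content': 'Describe the sun in one sentence.'}]
    return apoc(device, model, model, tokenizer, chat_x, chat_y,
                max_tokens=128, seed=0, quiet=True)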
# Alternative APOC implementation: sample the shared prefix once via the
# ModelPair helper below, then generate each suffix separately, rewinding the
# KV caches and swapping the model roles in between.
@torch.no_grad()
def apoc_alt(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
if seed is not None:
torch.manual_seed(seed)
prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=True)
prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=True)
model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)
logger.debug('PROMPT X:')
logger.debug(tokenizer.decode(prompt_x_BL[0]))
logger.debug('PROMPT Y:')
logger.debug(tokenizer.decode(prompt_y_BL[0]))
return _apoc_impl(model_pair, tokenizer, max_tokens)
LOGIT_DTYPE = torch.float64
class ModelPair:
def __init__(self, model_x, model_y, prompt_x_BL, prompt_y_BL):
self._model_x = model_x
self._model_y = model_y
self._prompt_x_BL = prompt_x_BL
self._prompt_y_BL = prompt_y_BL
self._is_swapped = False
def start(self):
        # Run both prompts through their models and return the next-token
        # log-prob vectors (one per model).
outputs = self._model_x(self._prompt_x_BL, use_cache=True)
self._past_key_values_x = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
lnpx_V = F.log_softmax(logits, dim=-1)
outputs = self._model_y(self._prompt_y_BL, use_cache=True)
self._past_key_values_y = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
lnpy_V = F.log_softmax(logits, dim=-1)
return self._maybe_swap(lnpx_V, lnpy_V)
def step(self, token):
# Append the given token, then return logprobs for the next token.
forward_passes_start = time.perf_counter()
input_ids = torch.tensor([[token]], device=self._prompt_x_BL.device)
outputs = self._model_x(input_ids, past_key_values=self._past_key_values_x, use_cache=True)
self._past_key_values_x = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
lnpx_V = F.log_softmax(logits, dim=-1)
outputs = self._model_y(input_ids, past_key_values=self._past_key_values_y, use_cache=True)
self._past_key_values_y = outputs.past_key_values
logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
lnpy_V = F.log_softmax(logits, dim=-1)
forward_passes_end = time.perf_counter()
logger.debug(f'Incremental forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')
return self._maybe_swap(lnpx_V, lnpy_V)
def get_position(self):
# Return a position that can be rewound to.
return self._past_key_values_x, self._past_key_values_y
def rewind_to(self, position):
# Rewind the KV cache.
self._past_key_values_x, self._past_key_values_y = position
def swap_models(self):
# Exchange the order of the models.
self._is_swapped = not self._is_swapped
def _maybe_swap(self, a, b):
if self._is_swapped:
return b, a
else:
return a, b
def _apoc_impl(model_pair, tokenizer, max_tokens):
prefix = []
lnpx_V, lnpy_V = model_pair.start()
lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)
while 1:
ln_peq_V = torch.minimum(
lnpx_V + F.relu(-lnpy_m_lnpx),
lnpy_V + F.relu(lnpy_m_lnpx),
)
peq_V = torch.exp(ln_peq_V)
peq = torch.sum(peq_V)
if torch.rand_like(peq) > peq:
logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
break
logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')
token = torch.multinomial(peq_V, 1).item()
prefix.append(token)
lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
p_token = (peq_V[token] / peq).item()
logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
if token == tokenizer.eos_token_id or len(prefix) >= max_tokens:
return prefix, prefix
lnpx_V, lnpy_V = model_pair.step(token)
remaining_tokens = max_tokens - len(prefix)
split_pos = model_pair.get_position()
response_a = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx)
logger.debug('First suffix complete; rewinding')
model_pair.rewind_to(split_pos)
model_pair.swap_models()
response_b = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx)
return response_a, response_b
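# Editorial note on the common-prefix formula used above and in apoc(): with
# d = ln Py(prefix) - ln Px(prefix), the vector
#     peq_V[t] = min(px(t) * exp(relu(-d)), py(t) * exp(relu(d)))
#              = min(Px(prefix + t), Py(prefix + t)) / min(Px(prefix), Py(prefix)),
# which is the per-token probability with which the algorithm keeps both
# responses on the shared prefix. The toy check below (assumed values, not
# original code) confirms the identity numerically.
def _check_peq_identity(seed=0):
    g = torch.Generator().manual_seed(seed)
    px_V = torch.softmax(torch.randn(5, generator=g, dtype=torch.float64), dim=-1)
    py_V = torch.softmax(torch.randn(5, generator=g, dtype=torch.float64), dim=-1)
    Px_prefix, Py_prefix = 0.03, 0.05  # arbitrary positive prefix probabilities
    d = torch.tensor(Py_prefix / Px_prefix, dtype=torch.float64).log()
    peq_V = torch.minimum(px_V.log() + F.relu(-d), py_V.log() + F.relu(d)).exp()
    want_V = torch.minimum(px_V * Px_prefix, py_V * Py_prefix) / min(Px_prefix, Py_prefix)
    assert torch.allclose(peq_V, want_V)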
def _apoc_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
lnpy_m_lnpx = lnpy_m_lnpx.clone()
suffix = []
while 1:
wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
token = torch.multinomial(wx_V, 1).item()
suffix.append(token)
lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
p_token = (wx_V[token] / torch.sum(wx_V)).item()
logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
if token == tokenizer.eos_token_id or len(suffix) >= max_tokens:
return suffix
lnpx_V, lnpy_V = model_pair.step(token)
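# Usage sketch (editorial): apoc_alt() mirrors apoc() but always returns
# token-id lists, which can then be decoded.
def _example_apoc_alt(device, model_x, model_y, tokenizer, chat):
    tokens_a, tokens_b = apoc_alt(device, model_x, model_y, tokenizer,
                                  chat, chat, max_tokens=64, seed=0)
    return tokenizer.decode(tokens_a), tokenizer.decode(tokens_b)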
def generate_streaming(device, model, tokenizer, chat, max_tokens=512, seed=None):
"""Stream a response using custom generation."""
prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=True)
logger.debug('PROMPT:')
logger.debug(tokenizer.decode(prompt_BL[0]))
if seed is not None:
torch.manual_seed(seed)
return _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens)
def _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens):
input_ids = prompt_BL
past_key_values = None
n_tokens = 0
while 1:
outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
logits = outputs.logits[0, -1, :]
p_V = F.softmax(logits, dim=-1)
token = torch.multinomial(p_V, num_samples=1).item()
p_token = p_V[token].item()
logger.debug(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
yield token
n_tokens += 1
if token == tokenizer.eos_token_id or n_tokens >= max_tokens:
break
input_ids = torch.tensor([[token]], device=device)
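# Editorial sketch: consuming the streaming generator token by token, decoding
# and printing incrementally.
def _example_stream(device, model, tokenizer, chat):
    pieces = []
    for token in generate_streaming(device, model, tokenizer, chat,
                                    max_tokens=64, seed=0):
        pieces.append(tokenizer.decode([token]))
        print(pieces[-1], end='', flush=True)
    return ''.join(pieces)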
# APOC unconditional streaming
@torch.no_grad()
def apoc_streaming(model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
if seed is not None:
torch.manual_seed(seed)
prompt_x_BL = tokenize_prompt(model_x.device, tokenizer, chat_x, quiet=True)
prompt_y_BL = tokenize_prompt(model_y.device, tokenizer, chat_y, quiet=True)
model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)
logger.debug('PROMPT X:')
logger.debug(tokenizer.decode(prompt_x_BL[0]))
logger.debug('PROMPT Y:')
logger.debug(tokenizer.decode(prompt_y_BL[0]))
return _apoc_streaming_impl(model_pair, tokenizer, max_tokens)
def _apoc_streaming_impl(model_pair, tokenizer, max_tokens):
remaining_tokens = max_tokens
lnpx_V, lnpy_V = model_pair.start()
lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)
while 1:
ln_peq_V = torch.minimum(
lnpx_V + F.relu(-lnpy_m_lnpx),
lnpy_V + F.relu(lnpy_m_lnpx),
)
peq_V = torch.exp(ln_peq_V)
peq = torch.sum(peq_V)
if torch.rand_like(peq) > peq:
logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
break
logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')
token = torch.multinomial(peq_V, 1).item()
remaining_tokens -= 1
yield token, token
lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
p_token = (peq_V[token] / peq).item()
logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
if token == tokenizer.eos_token_id or remaining_tokens == 0:
return
lnpx_V, lnpy_V = model_pair.step(token)
split_pos = model_pair.get_position()
for token_a in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
yield token_a, None
logger.debug('Suffix a complete; rewinding')
model_pair.rewind_to(split_pos)
model_pair.swap_models()
for token_b in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx):
yield None, token_b
logger.debug('Suffix b complete')
def _apoc_streaming_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
remaining_tokens = max_tokens
lnpy_m_lnpx = lnpy_m_lnpx.clone()
while 1:
wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
token = torch.multinomial(wx_V, 1).item()
remaining_tokens -= 1
yield token
lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
p_token = (wx_V[token] / torch.sum(wx_V)).item()
logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
if token == tokenizer.eos_token_id or remaining_tokens == 0:
return
lnpx_V, lnpy_V = model_pair.step(token)
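# Editorial sketch: consuming the paired stream from apoc_streaming(). While the
# common prefix lasts both slots are filled; afterwards exactly one slot of each
# yielded pair is None.
def _example_apoc_streaming(model_x, model_y, tokenizer, chat_x, chat_y):
    tokens_a, tokens_b = [], []
    for token_a, token_b in apoc_streaming(model_x, model_y, tokenizer,
                                           chat_x, chat_y, max_tokens=64, seed=0):
        if token_a is not None:
            tokens_a.append(token_a)
        if token_b is not None:
            tokens_b.append(token_b)
    return tokenizer.decode(tokens_a), tokenizer.decode(tokens_b)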