import logging
import time

import torch
import torch.nn.functional as F
import transformers

logger = logging.getLogger(__name__)
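
# Note on naming: tensor-name suffixes appear to encode shapes: B for batch,
# L for sequence length, V for vocabulary size (e.g. `output_BL`, `logits_BLV`,
# `p_V`).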


def isolate_responses_BL(output_BL, prompt_len, eos_token_id):
    """Strip the prompt from each generated sequence and truncate at the first EOS token."""
    responses_BL = []
    resp_BL = output_BL[:, prompt_len:]
    for i in range(resp_BL.shape[0]):
        resp_L = resp_BL[i]
        resplen = resp_L.shape[0]
        for j in range(resplen):
            if resp_L[j] == eos_token_id:
                resplen = j + 1
                break
        response_L = resp_L[:resplen].detach().cpu().numpy()
        responses_BL.append(response_L)
    return responses_BL


def tokenize_prompt(device, tokenizer, chat, quiet=False):
    """Apply the chat template to a single chat and return its input ids on `device`."""
    prompt_BL = tokenizer.apply_chat_template(
        [chat],
        tokenize=True, add_generation_prompt=True, return_tensors='pt',
    ).to(device)
    if not quiet:
        print('PROMPT:')
        print(tokenizer.decode(prompt_BL[0]))
    return prompt_BL
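
# Example (sketch): `chat` is a single conversation in the usual chat-template
# format, e.g. [{'role': 'user', 'content': 'Hello!'}].  tokenize_prompt wraps
# it in a batch of one, so the returned input ids have shape (1, L).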


def generate(device, model, tokenizer, chat):
    """Generate a response using huggingface's generation."""
    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    prompt_len = prompt_BL.shape[1]

    # Temporarily relax the global deterministic-algorithms setting for generation, then restore it.
    torch.use_deterministic_algorithms(False)
    generation_output = model.generate(
        inputs=prompt_BL, max_new_tokens=512, do_sample=True,
        return_dict_in_generate=True,
    )
    torch.use_deterministic_algorithms(True)

    output_BL = generation_output.sequences

    responses = isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id)
    for response_L in responses:
        print('RESPONSE:')
        print(tokenizer.decode(response_L))

    # The batch holds a single chat, so return its (only) decoded response.
    return tokenizer.decode(responses[-1])


def generate_with_logits(device, model, tokenizer, chat, seed=None):
    """Generate a short response with huggingface's generation and also return per-step logits."""
    if seed is not None:
        transformers.set_seed(seed)

    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    prompt_len = prompt_BL.shape[1]

    torch.use_deterministic_algorithms(False)
    generation_output = model.generate(
        inputs=prompt_BL, max_new_tokens=16, do_sample=True,
        return_dict_in_generate=True, output_logits=True,
    )
    torch.use_deterministic_algorithms(True)

    output_BL = generation_output.sequences
    # generation_output.logits is a tuple with one (B, V) tensor per generated step.
    logits_BLV = torch.stack(generation_output.logits, dim=1)

    responses = isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id)
    for response_L in responses:
        print('RESPONSE:')
        print(tokenizer.decode(response_L))

    return responses[-1], logits_BLV[0]


def response_logits(device, model, tokenizer, chat, response_L):
    """Re-run the model over prompt + response and return the logits predicting each response token."""
    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    # The final response token never needs to be fed: the logits that predict it
    # come from the position before it.
    response_pt_L = torch.from_numpy(response_L[:-1]).to(device)

    input_ids = torch.cat((prompt_BL, response_pt_L[None]), dim=1)

    outputs = model(input_ids)
    logits_BLV = outputs.logits

    # The last len(response_L) positions predict response_L[0], ..., response_L[-1].
    return logits_BLV[0][-len(response_L):]
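
# Example usage (sketch; `model_b` is a hypothetical second model sharing the
# same tokenizer), scoring a sampled response under another model:
#
#     response_L, logits_LV = generate_with_logits(device, model, tokenizer, chat)
#     other_logits_LV = response_logits(device, model_b, tokenizer, chat, response_L)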


def generate_custom(device, model, tokenizer, chat, max_tokens=512, seed=None, return_tokens=False, quiet=False, return_lnprobs=False):
    """Generate a response using custom generation."""
    if seed is not None:
        torch.manual_seed(seed)

    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=quiet)

    input_ids = prompt_BL
    past_key_values = None

    response_L = []
    lnprobs_L = []
    while True:
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values

        logits = outputs.logits[0, -1, :]
        p_V = F.softmax(logits, dim=-1)
        token = torch.multinomial(p_V, num_samples=1).item()
        response_L.append(token)
        lnprobs_L.append(F.log_softmax(logits, dim=-1)[token].item())

        p_token = p_V[token].item()
        if not quiet:
            print(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or len(response_L) >= max_tokens:
            break

        # With the KV cache in place, only the newly sampled token is fed on the next step.
        input_ids = torch.tensor([[token]], device=device)

    if not quiet:
        print('RESPONSE:')
        print(tokenizer.decode(response_L))

    if return_tokens:
        if return_lnprobs:
            return response_L, lnprobs_L
        else:
            return response_L
    else:
        assert not return_lnprobs
        return tokenizer.decode(response_L)
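
# Example usage (sketch; assumes `model` and `tokenizer` are an already-loaded
# Hugging Face causal LM and its tokenizer):
#
#     chat = [{'role': 'user', 'content': 'Write a haiku about autumn.'}]
#     tokens, lnprobs = generate_custom(device, model, tokenizer, chat, seed=0,
#                                       return_tokens=True, return_lnprobs=True)
#     text = tokenizer.decode(tokens)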


def format_token(tokenizer, token_id):
    return repr(tokenizer.decode(token_id))
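
# APOC couples two sampling chains, response a from model_x given chat_x and
# response b from model_y given chat_y, so that the pair shares a common prefix
# for as long as the two next-token distributions allow, while each response is
# intended to remain an ordinary sample from its own model.  A sketch of the
# arithmetic as implemented below, with D the running log-ratio
# ln p_y(prefix | chat_y) - ln p_x(prefix | chat_x):
#
#     pmeet[t] = min(p_x(t | prefix) * exp(max(0, -D)),
#                    p_y(t | prefix) * exp(max(0,  D)))
#
# With probability sum(pmeet) both responses take the same next token, sampled
# in proportion to pmeet; otherwise the chains split and each continues by
# sampling from its residual weights, e.g. max(0, p_x - p_y * exp(D)) for
# response a.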


def apoc(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None, return_tokens=False, quiet=False):
    """Generate a response using APOC unconditional sampling."""
    if seed is not None:
        torch.manual_seed(seed)

    # Log-probabilities for the coupling arithmetic are kept in float64.
    logit_dtype = torch.float64

    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=quiet)
    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=quiet)

    # Two responses (a and b), each scored under both models: four decoding streams.
    input_ids_xa = prompt_x_BL
    input_ids_ya = prompt_y_BL
    input_ids_xb = prompt_x_BL
    input_ids_yb = prompt_y_BL

    past_key_values_xa = None
    past_key_values_ya = None
    past_key_values_xb = None
    past_key_values_yb = None

    equal = True
    a_eos = False
    b_eos = False

    def zero():
        return torch.zeros(1, dtype=logit_dtype, device=device)

    # Running log-ratios of each response's probability under the two models.
    ln_pya_m_ln_pxa = zero()
    ln_pxb_m_ln_pyb = zero()

    response_a_L = []
    response_b_L = []
    i = 0
    while True:
        if i >= max_tokens or (a_eos and b_eos):
            break
        if not quiet:
            print(f'Generating response token {i}')
        i += 1

        forward_passes_start = time.perf_counter()

        if not a_eos:
            outputs = model_x(input_ids_xa, past_key_values=past_key_values_xa, use_cache=True)
            past_key_values_xa = outputs.past_key_values
            logits = outputs.logits[0, -1, :].to(logit_dtype)
            ln_pxa_V = F.log_softmax(logits, dim=-1)

            outputs = model_y(input_ids_ya, past_key_values=past_key_values_ya, use_cache=True)
            past_key_values_ya = outputs.past_key_values
            logits = outputs.logits[0, -1, :].to(logit_dtype)
            ln_pya_V = F.log_softmax(logits, dim=-1)

        if not b_eos:
            if equal:
                # While the responses share a common prefix, stream b is identical
                # to stream a, so reuse its caches and distributions.
                assert not a_eos
                past_key_values_xb = past_key_values_xa
                ln_pxb_V = ln_pxa_V
                past_key_values_yb = past_key_values_ya
                ln_pyb_V = ln_pya_V
            else:
                outputs = model_x(input_ids_xb, past_key_values=past_key_values_xb, use_cache=True)
                past_key_values_xb = outputs.past_key_values
                logits = outputs.logits[0, -1, :].to(logit_dtype)
                ln_pxb_V = F.log_softmax(logits, dim=-1)

                outputs = model_y(input_ids_yb, past_key_values=past_key_values_yb, use_cache=True)
                past_key_values_yb = outputs.past_key_values
                logits = outputs.logits[0, -1, :].to(logit_dtype)
                ln_pyb_V = F.log_softmax(logits, dim=-1)

        forward_passes_end = time.perf_counter()
        if not quiet:
            print(f' Forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')

        if equal:
            # Per-token probability that both chains can take the same next token,
            # given the shared prefix (expressed via the accumulated log-ratio).
            ln_pmeet_V = torch.minimum(
                ln_pxa_V + torch.maximum(zero(), -ln_pya_m_ln_pxa),
                ln_pya_V + torch.maximum(zero(), ln_pya_m_ln_pxa),
            )
            pmeet_V = torch.exp(ln_pmeet_V)
            pmeet = torch.sum(pmeet_V)

            if not quiet:
                print(f' Equal mode (pmeet={pmeet*100.0:.3f}%)')

            if torch.rand_like(pmeet) < pmeet:
                token_a = token_b = torch.multinomial(pmeet_V, num_samples=1).item()

                p_token_a = (pmeet_V[token_a] / pmeet).item()
                if not quiet:
                    print(f' Sampled {format_token(tokenizer, token_a)} ({p_token_a*100.0:.1f}%)')
            else:
                if not quiet:
                    print(' Exited equal mode')
                equal = False

        if not equal:
            if not a_eos:
                # Unnormalised residual weights for continuing response a under
                # model x, with model y's ratio-scaled mass removed.
                wxt_V = torch.maximum(zero(), torch.exp(ln_pxa_V) - torch.exp(ln_pya_V + ln_pya_m_ln_pxa))
                token_a = torch.multinomial(wxt_V, num_samples=1).item()

                p_token_a = (wxt_V[token_a] / torch.sum(wxt_V)).item()
                if not quiet:
                    print(f' Sampled token_a {format_token(tokenizer, token_a)} ({p_token_a*100.0:.3f}%)')
            if not b_eos:
                wyt_V = torch.maximum(zero(), torch.exp(ln_pyb_V) - torch.exp(ln_pxb_V + ln_pxb_m_ln_pyb))
                token_b = torch.multinomial(wyt_V, num_samples=1).item()

                p_token_b = (wyt_V[token_b] / torch.sum(wyt_V)).item()
                if not quiet:
                    print(f' Sampled token_b {format_token(tokenizer, token_b)} ({p_token_b*100.0:.3f}%)')

        if not a_eos:
            response_a_L.append(token_a)
            input_ids_xa = input_ids_ya = torch.tensor([[token_a]], device=device)
            ln_pya_m_ln_pxa += ln_pya_V[token_a] - ln_pxa_V[token_a]
            if token_a == tokenizer.eos_token_id:
                a_eos = True
        if not b_eos:
            response_b_L.append(token_b)
            input_ids_xb = input_ids_yb = torch.tensor([[token_b]], device=device)
            ln_pxb_m_ln_pyb += ln_pxb_V[token_b] - ln_pyb_V[token_b]
            if token_b == tokenizer.eos_token_id:
                b_eos = True

    if not quiet:
        print('RESPONSE X:')
        print(tokenizer.decode(response_a_L))
        print('RESPONSE Y:')
        print(tokenizer.decode(response_b_L))

    if return_tokens:
        return response_a_L, response_b_L
    else:
        return tokenizer.decode(response_a_L), tokenizer.decode(response_b_L)
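
# Example usage (sketch; assumes `device`, `model_x`, `model_y`, and a shared
# `tokenizer` are already set up):
#
#     chat_x = [{'role': 'user', 'content': 'Tell me about Paris.'}]
#     chat_y = [{'role': 'user', 'content': 'Tell me about Paris, briefly.'}]
#     text_x, text_y = apoc(device, model_x, model_y, tokenizer, chat_x, chat_y, seed=0)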


@torch.no_grad()
def apoc_alt(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
    """Generate a pair of APOC responses (token lists) using the ModelPair-based implementation."""
    if seed is not None:
        torch.manual_seed(seed)

    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=True)
    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=True)
    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)

    logger.debug('PROMPT X:')
    logger.debug(tokenizer.decode(prompt_x_BL[0]))
    logger.debug('PROMPT Y:')
    logger.debug(tokenizer.decode(prompt_y_BL[0]))

    return _apoc_impl(model_pair, tokenizer, max_tokens)


# dtype used for the coupling log-probability arithmetic.
LOGIT_DTYPE = torch.float64


class ModelPair:
    """Two models decoding in lockstep over their own prompts, with KV-cache
    snapshot/rewind and the option to swap which model's log-probs come first."""

    def __init__(self, model_x, model_y, prompt_x_BL, prompt_y_BL):
        self._model_x = model_x
        self._model_y = model_y
        self._prompt_x_BL = prompt_x_BL
        self._prompt_y_BL = prompt_y_BL
        self._is_swapped = False

    def start(self):
        """Run both models over their prompts and return their next-token log-probs."""
        outputs = self._model_x(self._prompt_x_BL, use_cache=True)
        self._past_key_values_x = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpx_V = F.log_softmax(logits, dim=-1)

        outputs = self._model_y(self._prompt_y_BL, use_cache=True)
        self._past_key_values_y = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpy_V = F.log_softmax(logits, dim=-1)

        return self._maybe_swap(lnpx_V, lnpy_V)

    def step(self, token):
        """Feed `token` to both models and return their next-token log-probs."""
        forward_passes_start = time.perf_counter()

        input_ids = torch.tensor([[token]], device=self._prompt_x_BL.device)

        outputs = self._model_x(input_ids, past_key_values=self._past_key_values_x, use_cache=True)
        self._past_key_values_x = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpx_V = F.log_softmax(logits, dim=-1)

        outputs = self._model_y(input_ids, past_key_values=self._past_key_values_y, use_cache=True)
        self._past_key_values_y = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpy_V = F.log_softmax(logits, dim=-1)

        forward_passes_end = time.perf_counter()
        logger.debug(f'Incremental forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')

        return self._maybe_swap(lnpx_V, lnpy_V)

    def get_position(self):
        """Snapshot the current KV caches so decoding can later branch from this point."""
        return self._past_key_values_x, self._past_key_values_y

    def rewind_to(self, position):
        """Restore KV caches previously returned by get_position()."""
        self._past_key_values_x, self._past_key_values_y = position

    def swap_models(self):
        """Swap which model's log-probs are returned first by start() and step()."""
        self._is_swapped = not self._is_swapped

    def _maybe_swap(self, a, b):
        if self._is_swapped:
            return b, a
        else:
            return a, b


def _apoc_impl(model_pair, tokenizer, max_tokens):
    prefix = []
    lnpx_V, lnpy_V = model_pair.start()
    # Running log-ratio ln p_y(prefix) - ln p_x(prefix) along the common prefix.
    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)

    while True:
        # Per-token probability that both responses can agree on the next token.
        ln_peq_V = torch.minimum(
            lnpx_V + F.relu(-lnpy_m_lnpx),
            lnpy_V + F.relu(lnpy_m_lnpx),
        )
        peq_V = torch.exp(ln_peq_V)
        peq = torch.sum(peq_V)

        if torch.rand_like(peq) > peq:
            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
            break
        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')

        token = torch.multinomial(peq_V, 1).item()
        prefix.append(token)
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (peq_V[token] / peq).item()
        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or len(prefix) >= max_tokens:
            return prefix, prefix

        lnpx_V, lnpy_V = model_pair.step(token)

    # The responses diverge here: generate each suffix from its model's residual
    # distribution, rewinding the KV caches to the split point in between.
    remaining_tokens = max_tokens - len(prefix)
    split_pos = model_pair.get_position()
    response_a = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx)
    logger.debug('First suffix complete; rewinding')
    model_pair.rewind_to(split_pos)
    model_pair.swap_models()
    response_b = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx)

    return response_a, response_b


def _apoc_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
    lnpy_m_lnpx = lnpy_m_lnpx.clone()
    suffix = []
    while True:
        # Unnormalised residual of model x's probability over model y's at this step.
        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
        token = torch.multinomial(wx_V, 1).item()
        suffix.append(token)
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (wx_V[token] / torch.sum(wx_V)).item()
        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or len(suffix) >= max_tokens:
            return suffix

        lnpx_V, lnpy_V = model_pair.step(token)
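
# Note: the residual weights above are the same max(0, p_x - p_y * exp(D)) form
# used by apoc() after its chains split; the ModelPair path just generates the
# two suffixes one after the other instead of interleaving them.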


def generate_streaming(device, model, tokenizer, chat, max_tokens=512, seed=None):
    """Stream a response using custom generation."""
    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=True)
    logger.debug('PROMPT:')
    logger.debug(tokenizer.decode(prompt_BL[0]))

    if seed is not None:
        torch.manual_seed(seed)

    return _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens)


def _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens):
    input_ids = prompt_BL
    past_key_values = None

    n_tokens = 0
    while True:
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values

        logits = outputs.logits[0, -1, :]
        p_V = F.softmax(logits, dim=-1)
        token = torch.multinomial(p_V, num_samples=1).item()

        p_token = p_V[token].item()
        logger.debug(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        yield token
        n_tokens += 1

        if token == tokenizer.eos_token_id or n_tokens >= max_tokens:
            break

        input_ids = torch.tensor([[token]], device=device)
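
# Example usage (sketch): decode tokens incrementally as they are produced.
#
#     for token in generate_streaming(device, model, tokenizer, chat, seed=0):
#         print(tokenizer.decode([token]), end='', flush=True)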


@torch.no_grad()
def apoc_streaming(model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
    """Stream a pair of APOC responses as (token_a, token_b) tuples."""
    if seed is not None:
        torch.manual_seed(seed)

    prompt_x_BL = tokenize_prompt(model_x.device, tokenizer, chat_x, quiet=True)
    prompt_y_BL = tokenize_prompt(model_y.device, tokenizer, chat_y, quiet=True)
    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)

    logger.debug('PROMPT X:')
    logger.debug(tokenizer.decode(prompt_x_BL[0]))
    logger.debug('PROMPT Y:')
    logger.debug(tokenizer.decode(prompt_y_BL[0]))

    return _apoc_streaming_impl(model_pair, tokenizer, max_tokens)


def _apoc_streaming_impl(model_pair, tokenizer, max_tokens):
    """Like _apoc_impl, but yields (token_a, token_b) pairs as tokens are sampled."""
    remaining_tokens = max_tokens
    lnpx_V, lnpy_V = model_pair.start()
    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)

    while True:
        ln_peq_V = torch.minimum(
            lnpx_V + F.relu(-lnpy_m_lnpx),
            lnpy_V + F.relu(lnpy_m_lnpx),
        )
        peq_V = torch.exp(ln_peq_V)
        peq = torch.sum(peq_V)

        if torch.rand_like(peq) > peq:
            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
            break
        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')

        token = torch.multinomial(peq_V, 1).item()
        remaining_tokens -= 1
        yield token, token
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (peq_V[token] / peq).item()
        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or remaining_tokens == 0:
            return

        lnpx_V, lnpy_V = model_pair.step(token)

    split_pos = model_pair.get_position()
    for token_a in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
        yield token_a, None
    logger.debug('Suffix a complete; rewinding')
    model_pair.rewind_to(split_pos)
    model_pair.swap_models()
    for token_b in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx):
        yield None, token_b
    logger.debug('Suffix b complete')


def _apoc_streaming_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
    remaining_tokens = max_tokens
    lnpy_m_lnpx = lnpy_m_lnpx.clone()
    while True:
        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
        token = torch.multinomial(wx_V, 1).item()
        remaining_tokens -= 1
        yield token
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (wx_V[token] / torch.sum(wx_V)).item()
        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or remaining_tokens == 0:
            return

        lnpx_V, lnpy_V = model_pair.step(token)
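
# Example usage (sketch): consume the paired stream; an element is None once the
# responses have diverged and only the other response is still being generated.
#
#     tokens_a, tokens_b = [], []
#     for token_a, token_b in apoc_streaming(model_x, model_y, tokenizer, chat_x, chat_y, seed=0):
#         if token_a is not None:
#             tokens_a.append(token_a)
#         if token_b is not None:
#             tokens_b.append(token_b)
#     print(tokenizer.decode(tokens_a))
#     print(tokenizer.decode(tokens_b))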