# Author: marksverdhei — commit dfbce2c ("Add attempt counts")
# NOTE(review): the lines above were Hugging Face file-viewer chrome
# ("raw / history blame / 2.95 kB") scraped into the source; preserved
# here as a comment so the file remains valid Python.
import logging
import torch
from src.constants import MAX_ATTEMPTS
from src.state import STATE
from src.state import model
from src.state import tokenizer
from src.text import get_text
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
# Token ids of the entire game text, computed once at import time.
# NOTE(review): assumes get_text() is deterministic and the tokenizer is
# already loaded by src.state at import — verify against src.state.
all_tokens = tokenizer.encode(get_text())
def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the token ids of the top 3 next-token predictions for *input_text*.

    Note: the original annotation claimed ``torch.Tensor``, but
    ``torch.topk(...).indices.tolist()`` produces a plain Python list of
    ints, so the annotation has been corrected to ``list[int]``.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Logits for the position following the last input token.
    last_token = logits[0, -1]
    top_3 = torch.topk(last_token, 3).indices.tolist()
    return top_3
def guess_is_correct(text: str, next_token: int) -> bool:
    """
    Decide whether the player's guess matches the upcoming token.

    A guess counts as correct if its first token id — either encoded bare
    or encoded with a leading whitespace — equals *next_token*.
    """
    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
    start_id, whitespace_id = get_start_and_whitespace_tokens(text)
    logger.debug(tokenizer.convert_ids_to_tokens([start_id, whitespace_id]))
    return next_token == start_id or next_token == whitespace_id
def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
    """
    Return two candidate first-token ids for *word*.

    The first is the leading token of the word encoded on its own; the
    second is its leading token when preceded by whitespace, obtained by
    encoding ". " + word and taking the token after the period.
    """
    bare_ids = tokenizer.encode(word, add_special_tokens=False)
    prefixed_ids = tokenizer.encode(". " + word, add_special_tokens=False)
    return bare_ids[0], prefixed_ids[1]
def handle_guess(
    text: str,
    remaining_attempts: int,
) -> tuple[str, str, str, int]:
    """
    Process a single player guess against the next token of the game text.

    Retrieves model predictions and compares the top 3 predicted tokens
    with the guessed word. One attempt is consumed per call (including
    empty input); when no attempts remain, STATE advances to the next
    word and the attempt counter resets to MAX_ATTEMPTS.

    Returns a 4-tuple
    (current_text, player_guess_display, lm_guess_display, remaining_attempts).
    Note: the original annotation said ``-> str``, but both return
    statements produce this 4-tuple; the annotation has been corrected.
    """
    logger.debug(f"Params:\ntext = {text}\nremaining_attempts = {remaining_attempts}\n")
    logger.debug(f"Initial STATE:\n{STATE}")
    current_tokens = all_tokens[: STATE.current_word_index]
    current_text = tokenizer.decode(current_tokens)
    player_guesses = ""
    lm_guesses = ""
    remaining_attempts -= 1
    # Empty input still consumes an attempt but shows no guesses.
    if not text:
        logger.debug("Returning early")
        return (current_text, player_guesses, lm_guesses, remaining_attempts)
    next_token = all_tokens[STATE.current_word_index]
    if guess_is_correct(text, next_token):
        STATE.correct_guess()
    if remaining_attempts == 0:
        # Out of attempts: advance to the next word and reset the counter.
        STATE.next_word()
        current_tokens = all_tokens[: STATE.current_word_index]
        remaining_attempts = MAX_ATTEMPTS
    else:
        STATE.player_guesses.append(text)
    # FIXME: unoptimized, computing all three every time
    current_text = tokenizer.decode(current_tokens)
    # Reveal one more LM guess per attempt used (3 - remaining).
    STATE.lm_guesses = get_model_predictions(current_text)[: 3 - remaining_attempts]
    logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
    logger.debug(f"Pre-return STATE:\n{STATE}")
    return (
        current_text,
        STATE.player_guess_str,
        STATE.get_lm_guess_display(remaining_attempts),
        remaining_attempts,
    )