# NOTE(review): removed non-code artifacts ("Spaces:" / "Runtime error" lines)
# that appear to be residue from a copy/paste or extraction, not Python source.
import logging

import torch

from src.constants import MAX_ATTEMPTS
from src.state import STATE
from src.state import model
from src.state import tokenizer
from src.text import get_text
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Token ids for the entire game text, computed once at import time.
# NOTE(review): this is a module-level side effect — encoding runs on import.
all_tokens = tokenizer.encode(get_text())
def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the token ids of the model's top-3 next-token predictions.

    Args:
        input_text: the text prefix to condition the model on.

    Returns:
        A list of 3 token ids (ints), highest-logit first.
        NOTE: the original annotation said ``torch.Tensor``, but
        ``torch.topk(...).indices.tolist()`` produces a plain list — callers
        slice it and pass it to ``tokenizer.decode``, which accepts a list.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    # Inference only — disable gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Logits at the last position of the (single) sequence in the batch.
    last_token_logits = logits[0, -1]
    return torch.topk(last_token_logits, 3).indices.tolist()
def guess_is_correct(text: str, next_token: int) -> bool:
    """
    Check whether the player's guess matches the next token.

    The guess counts as correct if either encoding of *text* — as a
    sequence-initial token or as its leading-whitespace variant — equals
    *next_token* (tokenizers assign different ids to a word at the start of
    a sequence vs. mid-sentence).

    Args:
        text: the player's guessed word.
        next_token: the true next token id from the game text.

    Returns:
        True when the guess matches either token form.
    """
    # Lazy %-style args: the conversion only runs when DEBUG is enabled.
    logger.debug("Next token: '%s'", tokenizer.convert_ids_to_tokens([next_token]))
    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return next_token in (predicted_token_start, predicted_token_whitespace)
def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
    """
    Encode *word* two ways and return the pair of resulting token ids.

    The first id is the word encoded at the start of a sequence; the second
    is the id the tokenizer gives the word when it follows other text (the
    ". " prefix forces the leading-whitespace form).

    NOTE(review): indexing ``[1]`` assumes ". " encodes to exactly one
    token — confirm for the tokenizer in use.
    """
    start_id = tokenizer.encode(word, add_special_tokens=False)[0]
    whitespace_id = tokenizer.encode(". " + word, add_special_tokens=False)[1]
    return start_id, whitespace_id
def handle_guess(
    text: str,
    remaining_attempts: int,
) -> tuple[str, str, str, int]:
    """
    Process one player guess and advance the game state.

    Retrieves the model's predictions and compares the player's guess
    against the true next token, updating STATE as a side effect.

    Args:
        text: the player's guessed word; empty means "no guess", which
            returns the current display unchanged (minus one attempt).
        remaining_attempts: attempts left for the current word.

    Returns:
        A 4-tuple of (current revealed text, player-guess display,
        LM-guess display, remaining attempts).
        NOTE: the original annotation of ``-> str`` was wrong — every
        return path yields this tuple.
    """
    # NOTE(review): indentation of this function was lost in the source
    # mangling; the flat if/if-else structure below is the reconstruction —
    # verify against the original repository.
    logger.debug(f"Params:\ntext = {text}\nremaining_attempts = {remaining_attempts}\n")
    logger.debug(f"Initial STATE:\n{STATE}")
    current_tokens = all_tokens[: STATE.current_word_index]
    current_text = tokenizer.decode(current_tokens)
    player_guesses = ""
    lm_guesses = ""
    remaining_attempts -= 1
    if not text:
        # Empty guess: nothing to score, return the display unchanged.
        logger.debug("Returning early")
        return (current_text, player_guesses, lm_guesses, remaining_attempts)
    next_token = all_tokens[STATE.current_word_index]
    if guess_is_correct(text, next_token):
        STATE.correct_guess()
    if remaining_attempts == 0:
        # Attempts exhausted: advance to the next word and reset the counter.
        STATE.next_word()
        current_tokens = all_tokens[: STATE.current_word_index]
        remaining_attempts = MAX_ATTEMPTS
    else:
        # Record the guess so it shows in the per-word guess display.
        STATE.player_guesses.append(text)
    # FIXME: unoptimized, computing all three every time
    current_text = tokenizer.decode(current_tokens)
    STATE.lm_guesses = get_model_predictions(current_text)[: 3 - remaining_attempts]
    logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
    logger.debug(f"Pre-return STATE:\n{STATE}")
    return (
        current_text,
        STATE.player_guess_str,
        STATE.get_lm_guess_display(remaining_attempts),
        remaining_attempts,
    )