import logging

import torch

from src.constants import MAX_ATTEMPTS
from src.state import STATE
from src.state import model
from src.state import tokenizer
from src.text import get_text

logger = logging.getLogger(__name__)

# Token ids of the full target text; the game walks through these one word at a time.
all_tokens = tokenizer.encode(get_text())


def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the ids of the top 3 predicted next tokens for the given text.
    """
    inputs = tokenizer(input_text, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # Only the logits at the last position matter for next-token prediction.
    last_token = logits[0, -1]
    top_3 = torch.topk(last_token, 3).indices.tolist()
    return top_3


def guess_is_correct(text: str) -> bool:
    """
    Check whether the guessed word matches the next target token, either as typed
    or as the corresponding token with leading whitespace.
    """
    current_target = all_tokens[STATE.current_word_index]
    logger.debug(f"Next token: '{tokenizer.convert_ids_to_tokens([current_target])}'")
    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return current_target in (predicted_token_start, predicted_token_whitespace)


def lm_is_correct() -> bool:
    # Guard against indexing past the last guess: without it the lookup below goes
    # out of range once remaining_attempts reaches 0.
    if STATE.remaining_attempts < 1:
        return False
    # The model spends one of its top guesses per attempt.
    current_guess = STATE.lm_guesses[MAX_ATTEMPTS - STATE.remaining_attempts]
    current_target = all_tokens[STATE.current_word_index]
    return current_guess == current_target


def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
    """
    Return the first token id of the word on its own and of the word with a leading
    space (the ". " prefix forces the whitespace variant of the token).
    """
    predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
    predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
    return predicted_token_start, predicted_token_whitespace


def get_current_text() -> str:
    return tokenizer.decode(all_tokens[: STATE.current_word_index])


def handle_player_win():
    # TODO: point system
    points = 1
    STATE.player_points += points
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html=f"Player gets {points} point!",
    )


def handle_lm_win():
    points = 1
    STATE.lm_points += points
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html=f"GPT2 gets {points} point!",
    )


def handle_out_of_attempts():
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html="Out of attempts. No one gets points!",
    )


def handle_tie():
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html="TIE! No one gets points!",
    )


def handle_next_attempt():
    STATE.remaining_attempts -= 1
    return STATE.get_tuple(
        get_current_text(),
        bottom_html=f"That was not it...\n{STATE.remaining_attempts} attempts left",
    )

def handle_no_input():
    return STATE.get_tuple(
        get_current_text(),
        bottom_html="Please write something",
    )


def handle_next_word():
    STATE.next_word()
    STATE.lm_guesses = get_model_predictions(get_current_text())
    return STATE.get_tuple()


def handle_guess(
    text: str,
    *args,
    **kwargs,
) -> tuple:
    """
    Handle a button press: either advance to the next word, or score the player's
    guess and the model's current guess against the next target token.
    """
    logger.debug(
        "Params:\n"
        f"text = {text}\n"
        f"args = {args}\n"
        f"kwargs = {kwargs}\n"
    )
    logger.debug(f"Initial STATE:\n{STATE}")

    # BUG: if you enter the word guess field when it says next word, it will guess
    # it as the next
    if STATE.button_label == "Next word":
        return handle_next_word()
    if not text:
        return handle_no_input()

    STATE.player_guesses.append(text)
    player_correct = guess_is_correct(text)
    lm_correct = lm_is_correct()

    if player_correct and lm_correct:
        return handle_tie()
    elif player_correct and not lm_correct:
        return handle_player_win()
    elif lm_correct and not player_correct:
        return handle_lm_win()
    elif STATE.remaining_attempts == 0:
        return handle_out_of_attempts()
    else:
        return handle_next_attempt()
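
# Illustrative usage sketch, not part of the game logic above: a minimal manual
# smoke test for handle_guess. It assumes STATE exposes the attributes used in this
# module (button_label, remaining_attempts, lm_guesses, ...) and that "Guess!" is the
# label shown while a word is still being guessed; adjust to the real STATE/UI wiring
# as needed.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # Prime the model's top-3 guesses for the current word, as handle_next_word() does.
    STATE.lm_guesses = get_model_predictions(get_current_text())
    STATE.button_label = "Guess!"

    # Simulate a single player guess; handle_guess returns the tuple that feeds the UI.
    print(handle_guess("the"))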