import logging
import torch
from src.constants import MAX_ATTEMPTS
from src.state import STATE
from src.state import model
from src.state import tokenizer
from src.text import get_text
logger = logging.getLogger(__name__)
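# Tokenize the full target text once; handlers index into this list to find the
# next token the players are trying to predict.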
all_tokens = tokenizer.encode(get_text())
def get_model_predictions(input_text: str) -> list[int]:
    """
    Returns the ids of the top 3 tokens predicted to follow the input text.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    last_token_logits = logits[0, -1]
    top_3 = torch.topk(last_token_logits, 3).indices.tolist()
    return top_3
def guess_is_correct(text: str) -> bool:
"""
We check if the predicted token or a corresponding one with a leading whitespace
matches that of the next token
"""
current_target = all_tokens[STATE.current_word_index]
logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
return current_target in (predicted_token_start, predicted_token_whitespace)
def lm_is_correct() -> bool:
    # The LM is only scored on its final guess; returning False otherwise also
    # avoids indexing past the guess list when no attempts remain.
    if STATE.remaining_attempts != 1:
        return False
    current_guess = STATE.lm_guesses[MAX_ATTEMPTS - STATE.remaining_attempts]
    current_target = all_tokens[STATE.current_word_index]
    return current_guess == current_target
def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
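    """
    Returns two candidate token ids for the guessed word: the first token of the
    word encoded on its own, and the first token of the word when preceded by a
    space (obtained by encoding ". " + word and taking the second token).
    """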
predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
return predicted_token_start, predicted_token_whitespace
def get_current_text():
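    """
    Decodes the part of the target text revealed so far.
    """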
return tokenizer.decode(all_tokens[: STATE.current_word_index])
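# Each outcome handler below updates the game state and returns the output tuple
# that refreshes the UI.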
def handle_player_win():
# TODO: point system
points = 1
STATE.player_points += points
STATE.button_label = "Next word"
return STATE.get_tuple(
get_current_text(),
bottom_html=f"Player gets {points} point!",
)
def handle_lm_win():
points = 1
STATE.lm_points += points
STATE.button_label = "Next word"
return STATE.get_tuple(
get_current_text(),
bottom_html=f"GPT2 gets {points} point!",
)
def handle_out_of_attempts():
STATE.button_label = "Next word"
return STATE.get_tuple(
get_current_text(),
bottom_html="Out of attempts. No one gets points!",
)
def handle_tie():
STATE.button_label = "Next word"
return STATE.get_tuple(
get_current_text(),
bottom_html="TIE! No one gets points!",
)
def handle_next_attempt():
STATE.remaining_attempts -= 1
return STATE.get_tuple(
get_current_text(), bottom_html=f"That was not it... {STATE.remaining_attempts} attempts left"
)
def handle_no_input():
return STATE.get_tuple(
get_current_text(),
bottom_html="Please write something",
)
def handle_next_word():
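    """
    Advances the game to the next target token and precomputes the LM's top 3
    guesses for it.
    """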
STATE.next_word()
STATE.lm_guesses = get_model_predictions(get_current_text())
return STATE.get_tuple()
def handle_guess(
text: str,
*args,
**kwargs,
) -> str:
"""
* Retreives model predictions and compares the top 3 predicted tokens
"""
logger.debug("Params:\n" f"text = {text}\n" f"args = {args}\n" f"kwargs = {kwargs}\n")
logger.debug(f"Initial STATE:\n{STATE}")
if STATE.button_label == "Next word":
return handle_next_word()
if not text:
return handle_no_input()
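    # Record the player's guess, then score both the player and the LM against
    # the next token.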
STATE.player_guesses.append(text)
player_correct = guess_is_correct(text)
lm_correct = lm_is_correct()
if player_correct and lm_correct:
return handle_tie()
elif player_correct and not lm_correct:
return handle_player_win()
elif lm_correct and not player_correct:
return handle_lm_win()
elif STATE.remaining_attempts == 0:
return handle_out_of_attempts()
else:
return handle_next_attempt()
STATE.lm_guesses = get_model_predictions(get_current_text())
# # STATE.correct_guess()
# # remaining_attempts = 0
# # elif lm_guess_is_correct():
# # pass
# else:
# return handle_incorrect_guess()
# # elif remaining_attempts == 0:
# # return handle_out_of_attempts()
# remaining_attempts -= 1
# STATE.player_guesses.append(text)
# if remaining_attempts == 0:
# STATE.next_word()
# current_tokens = all_tokens[: STATE.current_word_index]
# remaining_attempts = MAX_ATTEMPTS
# # FIXME: unoptimized, computing all three every time
# current_text = tokenizer.decode(current_tokens)
# logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
# logger.debug(f"Pre-return STATE:\n{STATE}")
# # BUG: if you enter the word guess field when it says next
# # word, it will guess it as the next
# return (
# current_text,
# STATE.player_points,
# STATE.lm_points,
# STATE.player_guess_str,
# STATE.get_lm_guess_display(remaining_attempts),
# remaining_attempts,
# "",
# "Guess!" if remaining_attempts else "Next word",
# )