import logging

import torch

from src.constants import MAX_ATTEMPTS
from src.state import STATE, model, tokenizer
from src.text import get_text

logger = logging.getLogger(__name__)

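# Tokenize the full target text once at import time; the game walks through these
# ids one position at a time via STATE.current_word_index (presumably advanced by
# STATE.next_word()).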
all_tokens = tokenizer.encode(get_text())


def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the ids of the 3 most likely next tokens, given the text so far.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    last_token_logits = logits[0, -1]
    return torch.topk(last_token_logits, 3).indices.tolist()
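

# Illustrative usage (a sketch only; the prompt below is made up and the exact
# output depends on the GPT-2 tokenizer and model loaded in src.state):
#
#   top_3 = get_model_predictions("The quick brown")
#   tokenizer.convert_ids_to_tokens(top_3)  # e.g. ['Ġfox', ...]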


def guess_is_correct(text: str) -> bool:
    """
    Check whether the guessed word matches the next target token, either exactly
    or as its leading-whitespace variant.
    """
    current_target = all_tokens[STATE.current_word_index]
    logger.debug(f"Next token: {tokenizer.convert_ids_to_tokens([current_target])}")
    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return current_target in (predicted_token_start, predicted_token_whitespace)


def lm_is_correct() -> bool:
    # lm_guesses holds only MAX_ATTEMPTS guesses, so indexing with
    # remaining_attempts <= 0 would be out of range.
    if STATE.remaining_attempts <= 0:
        return False
    current_guess = STATE.lm_guesses[MAX_ATTEMPTS - STATE.remaining_attempts]
    current_target = all_tokens[STATE.current_word_index]
    return current_guess == current_target
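
# Note on lm_is_correct: with MAX_ATTEMPTS == 3 (assumed, matching the 3 predictions
# returned by get_model_predictions), remaining_attempts 3, 2, 1 index
# lm_guesses[0], [1], [2] respectively, i.e. one precomputed guess per attempt.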


def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
    predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
    predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
    return predicted_token_start, predicted_token_whitespace
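
# Why two variants: GPT-2's byte-level BPE encodes a word differently when it is
# preceded by a space ("word" vs " word", the latter rendered as "Ġword" in token
# form). Encoding ". " + word and taking the second id recovers the leading-space
# form, so a bare player guess can match a target token that sits mid-sentence.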


def get_current_text():
    return tokenizer.decode(all_tokens[: STATE.current_word_index])


def handle_player_win():
    # TODO: point system
    points = 1
    STATE.player_points += points
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html=f"Player gets {points} point!",
    )


def handle_lm_win():
    points = 1
    STATE.lm_points += points
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html=f"GPT2 gets {points} point!",
    )


def handle_out_of_attempts():
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html="Out of attempts. No one gets points!",
    )


def handle_tie():
    STATE.button_label = "Next word"
    return STATE.get_tuple(
        get_current_text(),
        bottom_html="TIE! No one gets points!",
    )


def handle_next_attempt():
    STATE.remaining_attempts -= 1
    return STATE.get_tuple(
        get_current_text(),
        bottom_html=f"That was not it... {STATE.remaining_attempts} attempts left",
    )


def handle_no_input():
    return STATE.get_tuple(
        get_current_text(),
        bottom_html="Please write something",
    )


def handle_next_word():
    STATE.next_word()
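    # Precompute the LM's top-3 guesses for the new prefix once, so each later
    # attempt can be scored without re-running the model.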
    STATE.lm_guesses = get_model_predictions(get_current_text())
    return STATE.get_tuple()


def handle_guess(
    text: str,
    *args,
    **kwargs,
) -> tuple:
    """
    Handle one round of the game: either advance to the next word, or compare the
    player's guess and the LM's precomputed guess against the next target token
    and update the score accordingly.
    """
    logger.debug(f"Params:\ntext = {text}\nargs = {args}\nkwargs = {kwargs}\n")
    logger.debug(f"Initial STATE:\n{STATE}")

    if STATE.button_label == "Next word":
        return handle_next_word()
    if not text:
        return handle_no_input()

    STATE.player_guesses.append(text)
    player_correct = guess_is_correct(text)
    lm_correct = lm_is_correct()

    if player_correct and lm_correct:
        return handle_tie()
    elif player_correct and not lm_correct:
        return handle_player_win()
    elif lm_correct and not player_correct:
        return handle_lm_win()
    elif STATE.remaining_attempts == 0:
        return handle_out_of_attempts()
    else:
        return handle_next_attempt()