Spaces:
Runtime error
Runtime error
File size: 5,269 Bytes
614d543 dfbce2c 614d543 dfbce2c 614d543 7eee83c 614d543 dfbce2c 614d543 7eee83c dfbce2c 7eee83c dfbce2c 7eee83c dfbce2c 0e7f280 dfbce2c 7eee83c dfbce2c 7eee83c 614d543 7eee83c 0e7f280 7eee83c 614d543 7eee83c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import logging
import torch
from src.constants import MAX_ATTEMPTS
from src.state import STATE
from src.state import model
from src.state import tokenizer
from src.text import get_text
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# Token ids of the full game text, computed once at import time.
# NOTE(review): runs tokenizer at import — any failure aborts the import.
all_tokens = tokenizer.encode(get_text())
def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the ids of the top 3 tokens the model predicts to follow ``input_text``.

    Note: despite using torch internally, the result is a plain ``list[int]``
    (``topk(...).indices.tolist()``), not a ``torch.Tensor`` as the original
    annotation claimed.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Logits at the final position are the next-token distribution.
    last_token_logits = logits[0, -1]
    return torch.topk(last_token_logits, 3).indices.tolist()
def guess_is_correct(text: str) -> bool:
    """
    Check whether the player's guess matches the next target token.

    A guess counts as correct if either its leading token, or the leading
    token of the guess when preceded by whitespace, equals the target
    token id at ``STATE.current_word_index``.
    """
    current_target = all_tokens[STATE.current_word_index]
    # Lazy %-style args: the tokenizer call/format only runs if DEBUG is enabled.
    logger.debug("Next token: '%s'", tokenizer.convert_ids_to_tokens([current_target]))
    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
    logger.debug("%s", tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return current_target in (predicted_token_start, predicted_token_whitespace)
def lm_is_correct() -> bool:
    """
    Check whether the LM's guess for the current attempt matches the target token.

    As before, the LM is only scored when exactly one attempt remains.
    This version also fixes the IndexError the original NOTE warned about:
    with ``remaining_attempts == 0`` the index ``MAX_ATTEMPTS - 0`` is out
    of range for the 3-element ``lm_guesses`` list.
    """
    if STATE.remaining_attempts != 1:
        # > 1: LM not scored yet (original behavior).
        # <= 0: indexing would overrun lm_guesses — return False instead of crashing.
        return False
    current_guess = STATE.lm_guesses[MAX_ATTEMPTS - STATE.remaining_attempts]
    current_target = all_tokens[STATE.current_word_index]
    return current_guess == current_target
def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
    """
    Return two candidate token ids for ``word``.

    The first is the leading token of the bare word; the second is the
    leading token of the word when preceded by whitespace, obtained by
    encoding ``". " + word`` and skipping the "." token.
    """
    bare = tokenizer.encode(word, add_special_tokens=False)
    spaced = tokenizer.encode(". " + word, add_special_tokens=False)
    return bare[0], spaced[1]
def get_current_text():
    """Decode and return the portion of the game text revealed so far."""
    revealed_tokens = all_tokens[: STATE.current_word_index]
    return tokenizer.decode(revealed_tokens)
def handle_player_win():
    """Award a point to the player and switch the button to 'Next word'."""
    # TODO: point system
    awarded = 1
    STATE.player_points += awarded
    STATE.button_label = "Next word"
    message = f"Player gets {awarded} point!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)
def handle_lm_win():
    """Award a point to the language model and switch the button to 'Next word'."""
    awarded = 1
    STATE.lm_points += awarded
    STATE.button_label = "Next word"
    message = f"GPT2 gets {awarded} point!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)
def handle_out_of_attempts():
    """End the round with no points awarded: all attempts were used up."""
    STATE.button_label = "Next word"
    message = "Out of attempts. No one gets points!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)
def handle_tie():
    """End the round as a tie (both guessed correctly): nobody scores."""
    STATE.button_label = "Next word"
    message = "TIE! No one gets points!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)
def handle_next_attempt():
    """Consume one attempt and prompt the player to try again."""
    STATE.remaining_attempts -= 1
    message = f"That was not it... {STATE.remaining_attempts} attempts left"
    return STATE.get_tuple(get_current_text(), bottom_html=message)
def handle_no_input():
    """Ask the player for input when the guess field was left empty."""
    message = "Please write something"
    return STATE.get_tuple(get_current_text(), bottom_html=message)
def handle_next_word():
    """Advance to the next word and refresh the LM's guesses for it."""
    STATE.next_word()
    # Recompute the model's top-3 predictions for the newly revealed prefix.
    STATE.lm_guesses = get_model_predictions(get_current_text())
    return STATE.get_tuple()
def handle_guess(
    text: str,
    *args,
    **kwargs,
) -> tuple:
    """
    Process one player guess, or a "Next word" button press.

    * Retrieves model predictions and compares the top 3 predicted tokens
      against the player's guess and the LM's guess for this attempt.

    Returns the UI state tuple from ``STATE.get_tuple(...)`` — every branch
    does, so the original ``-> str`` annotation was inaccurate.
    """
    # Lazy %-args: skip formatting entirely when DEBUG is off.
    logger.debug("Params:\ntext = %s\nargs = %s\nkwargs = %s\n", text, args, kwargs)
    logger.debug("Initial STATE:\n%s", STATE)
    # The single button doubles as "Next word" after a round ends.
    if STATE.button_label == "Next word":
        return handle_next_word()
    if not text:
        return handle_no_input()
    STATE.player_guesses.append(text)
    player_correct = guess_is_correct(text)
    lm_correct = lm_is_correct()
    # Guard clauses: each outcome returns immediately.
    if player_correct and lm_correct:
        return handle_tie()
    if player_correct:
        return handle_player_win()
    if lm_correct:
        return handle_lm_win()
    if STATE.remaining_attempts == 0:
        return handle_out_of_attempts()
    return handle_next_attempt()
# Prime the LM's guesses for the initial text at import time.
# NOTE(review): the viewer stripped indentation — if this line were inside
# handle_guess it would be unreachable (all branches return), so it is
# presumably a module-level statement; confirm against the repo.
STATE.lm_guesses = get_model_predictions(get_current_text())
# # STATE.correct_guess()
# # remaining_attempts = 0
# # elif lm_guess_is_correct():
# # pass
# else:
# return handle_incorrect_guess()
# # elif remaining_attempts == 0:
# # return handle_out_of_attempts()
# remaining_attempts -= 1
# STATE.player_guesses.append(text)
# if remaining_attempts == 0:
# STATE.next_word()
# current_tokens = all_tokens[: STATE.current_word_index]
# remaining_attempts = MAX_ATTEMPTS
# # FIXME: unoptimized, computing all three every time
# current_text = tokenizer.decode(current_tokens)
# logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
# logger.debug(f"Pre-return STATE:\n{STATE}")
# # BUG: if you enter the word guess field when it says next
# # word, it will guess it as the next
# return (
# current_text,
# STATE.player_points,
# STATE.lm_points,
# STATE.player_guess_str,
# STATE.get_lm_guess_display(remaining_attempts),
# remaining_attempts,
# "",
# "Guess!" if remaining_attempts else "Next word",
# )
|