File size: 5,269 Bytes
614d543
 
dfbce2c
 
 
614d543
 
dfbce2c
614d543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eee83c
614d543
dfbce2c
 
614d543
7eee83c
 
dfbce2c
 
7eee83c
 
 
 
 
 
 
 
 
 
 
dfbce2c
 
 
 
 
 
 
 
7eee83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfbce2c
 
0e7f280
 
dfbce2c
 
 
 
7eee83c
dfbce2c
 
7eee83c
 
614d543
 
7eee83c
0e7f280
7eee83c
 
 
 
 
 
 
 
 
 
 
 
 
614d543
7eee83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import logging

import torch

from src.constants import MAX_ATTEMPTS
from src.state import STATE
from src.state import model
from src.state import tokenizer
from src.text import get_text

logger = logging.getLogger(__name__)

# Token ids of the full game text; sliced/indexed via STATE.current_word_index
# by get_current_text(), guess_is_correct() and lm_is_correct() below.
all_tokens = tokenizer.encode(get_text())


def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the ids of the top 3 tokens the model predicts to follow *input_text*.

    Fixed the return annotation: ``torch.topk(...).indices.tolist()`` yields a
    plain Python ``list[int]``, not a ``torch.Tensor`` as previously annotated.
    """
    inputs = tokenizer(input_text, return_tensors="pt")

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Logits at the last position score candidates for the NEXT token.
    last_token_logits = logits[0, -1]
    return torch.topk(last_token_logits, 3).indices.tolist()


def guess_is_correct(text: str) -> bool:
    """
    Check whether the player's guess matches the next token, accepting either
    the token as typed or its leading-whitespace variant.
    """
    target = all_tokens[STATE.current_word_index]
    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([target])))
    start_token, whitespace_token = get_start_and_whitespace_tokens(text)
    logger.debug(tokenizer.convert_ids_to_tokens([start_token, whitespace_token]))
    return target == start_token or target == whitespace_token


def lm_is_correct() -> bool:
    """
    Return True if the LM's guess for the current attempt matches the target token.

    Attempt k (0-based) uses the model's k-th ranked guess:
    ``lm_guesses[MAX_ATTEMPTS - remaining_attempts]``.
    """
    # Guard fixed: the original `> 1` check both blocked early attempts and
    # still allowed remaining_attempts == 0 through, where the index becomes
    # MAX_ATTEMPTS and reads past the end of lm_guesses (the out-of-range
    # case the original NOTE warned about).
    if STATE.remaining_attempts < 1:
        return False

    current_guess = STATE.lm_guesses[MAX_ATTEMPTS - STATE.remaining_attempts]
    current_target = all_tokens[STATE.current_word_index]
    return current_guess == current_target


def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
    """
    Return two token ids for *word*: its first token when encoded on its own,
    and its first token when preceded by whitespace (encoded after ". ").
    """
    bare = tokenizer.encode(word, add_special_tokens=False)
    # Index 1 skips the "." token; what remains starts with " <word>".
    spaced = tokenizer.encode(". " + word, add_special_tokens=False)
    return bare[0], spaced[1]


def get_current_text():
    """Decode and return the portion of the text revealed so far."""
    revealed = all_tokens[: STATE.current_word_index]
    return tokenizer.decode(revealed)


def handle_player_win():
    """Credit the player with a point and switch the button to 'Next word'."""
    # TODO: point system
    points = 1
    STATE.button_label = "Next word"
    STATE.player_points += points
    message = f"Player gets {points} point!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)


def handle_lm_win():
    """Credit GPT2 with a point and switch the button to 'Next word'."""
    points = 1
    STATE.button_label = "Next word"
    STATE.lm_points += points
    message = f"GPT2 gets {points} point!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)


def handle_out_of_attempts():
    """End the round with no points awarded once all attempts are spent."""
    STATE.button_label = "Next word"
    message = "Out of attempts. No one gets points!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)


def handle_tie():
    """End the round as a tie (both guessed correctly) with no points."""
    STATE.button_label = "Next word"
    message = "TIE! No one gets points!"
    return STATE.get_tuple(get_current_text(), bottom_html=message)


def handle_next_attempt():
    """Consume one attempt and prompt the player to try again."""
    STATE.remaining_attempts -= 1
    message = f"That was not it... {STATE.remaining_attempts} attempts left"
    return STATE.get_tuple(get_current_text(), bottom_html=message)


def handle_no_input():
    """Ask the player for input when the guess field is submitted empty."""
    return STATE.get_tuple(get_current_text(), bottom_html="Please write something")


def handle_next_word():
    """Advance the game to the next word and precompute the LM's guesses."""
    STATE.next_word()
    new_context = get_current_text()
    STATE.lm_guesses = get_model_predictions(new_context)
    return STATE.get_tuple()


def handle_guess(
    text: str,
    *args,
    **kwargs,
) -> str:
    """
    Handle one submission from the guess field.

    Retrieves the outcome of the player's guess and the LM's top-3 guess for
    the current attempt, then dispatches to the matching outcome handler.
    """
    logger.debug(f"Params:\ntext = {text}\nargs = {args}\nkwargs = {kwargs}\n")
    logger.debug(f"Initial STATE:\n{STATE}")

    # Guard clauses: "Next word" mode and empty input short-circuit the round.
    if STATE.button_label == "Next word":
        return handle_next_word()
    if not text:
        return handle_no_input()

    STATE.player_guesses.append(text)
    player_correct = guess_is_correct(text)
    lm_correct = lm_is_correct()

    if player_correct:
        return handle_tie() if lm_correct else handle_player_win()
    if lm_correct:
        return handle_lm_win()
    if STATE.remaining_attempts == 0:
        return handle_out_of_attempts()
    return handle_next_attempt()


# Module-import side effect: seed the LM's guesses for the first word so a
# round can be scored immediately.
STATE.lm_guesses = get_model_predictions(get_current_text())


#     # STATE.correct_guess()
#     # remaining_attempts = 0
# # elif lm_guess_is_correct():
# #     pass
# else:
#     return handle_incorrect_guess()
# # elif remaining_attempts == 0:
# #     return handle_out_of_attempts()

#     remaining_attempts -= 1
#     STATE.player_guesses.append(text)

# if remaining_attempts == 0:
#     STATE.next_word()
#     current_tokens = all_tokens[: STATE.current_word_index]
#     remaining_attempts = MAX_ATTEMPTS

# # FIXME: unoptimized, computing all three every time
# current_text = tokenizer.decode(current_tokens)
# logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
# logger.debug(f"Pre-return STATE:\n{STATE}")

# # BUG: if you enter the word guess field when it says next
# #      word, it will guess it as the next
# return (
#     current_text,
#     STATE.player_points,
#     STATE.lm_points,
#     STATE.player_guess_str,
#     STATE.get_lm_guess_display(remaining_attempts),
#     remaining_attempts,
#     "",
#     "Guess!" if remaining_attempts else "Next word",
# )