Spaces: Runtime error

marksverdhei committed · Commit dfbce2c · Parent(s): b1e0f19

Add attempt counts

Files changed:
- src/constants.py +1 -0
- src/handler.py +48 -34
- src/interface.py +20 -7
- src/state.py +21 -3
- src/text.py +1 -1
src/constants.py (ADDED)

@@ -0,0 +1 @@
+MAX_ATTEMPTS = 3
src/handler.py (CHANGED)

@@ -1,9 +1,11 @@
-import torch
 import logging
 
+import torch
+
+from src.constants import MAX_ATTEMPTS
 from src.state import STATE
-from src.state import tokenizer
 from src.state import model
+from src.state import tokenizer
 from src.text import get_text
 
 logger = logging.getLogger(__name__)

@@ -25,53 +27,65 @@ def get_model_predictions(input_text: str) -> torch.Tensor:
     return top_3
 
 
-def handle_guess(text: str) -> str:
+def guess_is_correct(text: str, next_token: int) -> bool:
     """
-    …
-    …
+    We check whether the predicted token, or the corresponding one with a leading
+    whitespace, matches the next token
     """
-    current_tokens = all_tokens[:STATE.current_word_index]
+    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
+    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
+    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
+    return next_token in (predicted_token_start, predicted_token_whitespace)
+
+
+def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
+    predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
+    predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
+    return predicted_token_start, predicted_token_whitespace
+
+
+def handle_guess(
+    text: str,
+    remaining_attempts: int,
+) -> str:
+    """
+    *
+    * Retrieves model predictions and compares the top 3 predicted tokens
+    """
+    logger.debug(f"Params:\ntext = {text}\nremaining_attempts = {remaining_attempts}\n")
+    logger.debug(f"Initial STATE:\n{STATE}")
+
+    current_tokens = all_tokens[: STATE.current_word_index]
     current_text = tokenizer.decode(current_tokens)
     player_guesses = ""
     lm_guesses = ""
-    remaining_attempts = …
+    remaining_attempts -= 1
 
     if not text:
-        return (
-            current_text,
-            player_guesses,
-            lm_guesses,
-            remaining_attempts
-        )
+        logger.debug("Returning early")
+        return (current_text, player_guesses, lm_guesses, remaining_attempts)
 
     next_token = all_tokens[STATE.current_word_index]
-    predicted_token_start = tokenizer.encode(text, add_special_tokens=False)[0]
-    predicted_token_whitespace = tokenizer.encode(". " + text, add_special_tokens=False)[1]
-    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
-    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
 
-    guess_is_correct = next_token in (predicted_token_start, predicted_token_whitespace)
+    if guess_is_correct(text, next_token):
+        STATE.correct_guess()
 
-    if guess_is_correct:
-        STATE.current_word_index += 1
-        current_tokens = all_tokens[:STATE.current_word_index]
-        remaining_attempts = 3
-        STATE.player_guesses = []
-        STATE.lm_guesses = []
+    if remaining_attempts == 0:
+        STATE.next_word()
+        current_tokens = all_tokens[: STATE.current_word_index]
+        remaining_attempts = MAX_ATTEMPTS
     else:
-        …
-        STATE.player_guesses.append(tokenizer.decode([predicted_token_whitespace]))
+        STATE.player_guesses.append(text)
 
     # FIXME: unoptimized, computing all three every time
-    STATE.lm_guesses = get_model_predictions(tokenizer.decode(current_tokens))[:3-remaining_attempts]
-    logger.debug(f"lm_guesses: {tokenizer.decode(lm_guesses)}")
-
-    player_guesses = "\n".join(STATE.player_guesses)
     current_text = tokenizer.decode(current_tokens)
+    STATE.lm_guesses = get_model_predictions(current_text)[: 3 - remaining_attempts]
+    logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
 
+    logger.debug(f"Pre-return STATE:\n{STATE}")
     return (
         current_text,
-        player_guesses,
-        lm_guesses,
-        remaining_attempts
-    )
+        STATE.player_guess_str,
+        STATE.get_lm_guess_display(remaining_attempts),
+        remaining_attempts,
+    )
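
The whitespace handling above is the subtle part of this change: GPT-2's BPE vocabulary assigns different ids to a word with and without a leading space, so a player's guess has to be checked against both forms. A minimal sketch (not part of the commit) of what get_start_and_whitespace_tokens relies on, using the same gpt2 tokenizer:

    # "the" and " the" are distinct GPT-2 tokens; encoding ". " + word and
    # taking index [1] extracts the leading-space variant of the word.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    word = "the"
    start_id = tokenizer.encode(word, add_special_tokens=False)[0]         # "the"
    space_id = tokenizer.encode(". " + word, add_special_tokens=False)[1]  # " the"

    print(tokenizer.convert_ids_to_tokens([start_id, space_id]))
    # e.g. ['the', 'Ġthe']; 'Ġ' marks a leading space in GPT-2's vocabulary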
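get_model_predictions itself is outside the hunk; only its closing return top_3 is visible as context. A hypothetical sketch of a function matching that contract (the real body may differ), reusing the module-level model and tokenizer:

    import torch

    def get_model_predictions(input_text: str) -> torch.Tensor:
        # Ids of the three likeliest next tokens after input_text.
        input_ids = tokenizer.encode(input_text, return_tensors="pt")
        with torch.no_grad():
            logits = model(input_ids).logits       # shape: (1, seq_len, vocab_size)
        top_3 = torch.topk(logits[0, -1], k=3).indices
        return top_3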
src/interface.py (CHANGED)

@@ -1,5 +1,10 @@
 import gradio as gr
+
+from src.constants import MAX_ATTEMPTS
+from src.handler import all_tokens
 from src.handler import handle_guess
+from src.state import STATE
+from src.state import tokenizer
 
 
 def build_demo():

@@ -14,7 +19,11 @@ def build_demo():
             "The one with the fewest guesses for a given word gets a point."
         )
         with gr.Row():
-            prompt_text = gr.Textbox(…)
+            prompt_text = gr.Textbox(
+                value=tokenizer.decode(all_tokens[: STATE.current_word_index]),
+                label="Context",
+                interactive=False,
+            )
         with gr.Row():
             with gr.Column():
                 player_points = gr.Number(label="your points", interactive=False)

@@ -22,9 +31,13 @@ def build_demo():
                 lm_points = gr.Number(label="LM points", interactive=False)
         with gr.Row():
             with gr.Column():
-                remaining_attempts = gr.Number(…)
+                remaining_attempts = gr.Number(
+                    value=MAX_ATTEMPTS,
+                    label="Remaining attempts",
+                    precision=0,
+                )
                 current_guesses = gr.Textbox(label="Your guesses")
-        with gr.Column():
+            with gr.Column():
                 lm_guesses = gr.Textbox(label="LM guesses")
 
         with gr.Row():

@@ -37,7 +50,10 @@ def build_demo():
 
         guess_button.click(
             handle_guess,
-            inputs=…,
+            inputs=[
+                guess,
+                remaining_attempts,
+            ],
             outputs=[
                 prompt_text,
                 current_guesses,

@@ -63,6 +79,3 @@ def get_demo(wip=False):
         return wip_sign()
     else:
         return build_demo()
-
-
-
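Note the round trip this wiring creates: gradio reads the current value of the remaining_attempts gr.Number, passes it to handle_guess, and writes the returned (decremented or reset) value back as the fourth output, so the attempt counter lives in the UI component rather than in ProgramState. A minimal sketch of the same pattern with hypothetical names, not the app's real layout:

    import gradio as gr

    MAX_ATTEMPTS = 3

    def decrement(n: int) -> int:
        # Count down, then reset once attempts are exhausted.
        return MAX_ATTEMPTS if n <= 1 else n - 1

    with gr.Blocks() as demo:
        counter = gr.Number(value=MAX_ATTEMPTS, precision=0, label="Remaining attempts")
        button = gr.Button("Guess")
        # The component appears as both input and output, closing the loop.
        button.click(decrement, inputs=[counter], outputs=[counter])

    demo.launch()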
src/state.py (CHANGED)

@@ -1,9 +1,10 @@
 from dataclasses import dataclass
-from transformers import AutoTokenizer
+
 from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
 
+from src.constants import MAX_ATTEMPTS
 
-from dataclasses import dataclass
 
 @dataclass
 class ProgramState:

@@ -13,6 +14,23 @@ class ProgramState:
     lm_guesses: list
     lm_points: int
 
+    def correct_guess(self):
+        # FIXME: not 1 for every point
+        self.player_points += 1
+        self.next_word()
+
+    def next_word(self):
+        self.current_word_index += 1
+        self.player_guesses = []
+        self.lm_guesses = []
+
+    @property
+    def player_guess_str(self):
+        return "\n".join(self.player_guesses)
+
+    def get_lm_guess_display(self, remaining_attempts: int) -> str:
+        return "\n".join(map(tokenizer.decode, self.lm_guesses[: MAX_ATTEMPTS - remaining_attempts]))
+
 
 STATE = ProgramState(
     current_word_index=20,

@@ -24,4 +42,4 @@ STATE = ProgramState(
 
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 model = AutoModelForCausalLM.from_pretrained("gpt2")
-model.eval()
\ No newline at end of file
+model.eval()
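Under these additions a correct guess awards a point and advances to the next word, while next_word alone (used when attempts run out) advances without scoring, clearing both guess lists either way. An illustration of the intended transitions, assuming keyword construction with the fields shown above:

    # Not in the repo: walks through the new ProgramState transitions.
    state = ProgramState(
        current_word_index=20,
        player_guesses=["foo"],
        player_points=0,
        lm_guesses=[],
        lm_points=0,
    )
    state.correct_guess()  # player_points -> 1, current_word_index -> 21, guesses cleared
    state.next_word()      # current_word_index -> 22, no point awarded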
src/text.py (CHANGED)

@@ -13,4 +13,4 @@ Nine months later, he was named the 2009 Nobel Peace Prize laureate, a decision
 
 
 def get_text():
-    return target_text
\ No newline at end of file
+    return target_text