marksverdhei committed on
Commit
dfbce2c
β€’
1 Parent(s): b1e0f19

Add attempt counts

Browse files
Files changed (5) hide show
  1. src/constants.py +1 -0
  2. src/handler.py +48 -34
  3. src/interface.py +20 -7
  4. src/state.py +21 -3
  5. src/text.py +1 -1
src/constants.py ADDED
@@ -0,0 +1 @@
 
 
1
+ MAX_ATTEMPTS = 3
src/handler.py CHANGED
@@ -1,9 +1,11 @@
1
- import torch
2
  import logging
3
 
 
 
 
4
  from src.state import STATE
5
- from src.state import tokenizer
6
  from src.state import model
 
7
  from src.text import get_text
8
 
9
  logger = logging.getLogger(__name__)
@@ -25,53 +27,65 @@ def get_model_predictions(input_text: str) -> torch.Tensor:
25
  return top_3
26
 
27
 
28
- def handle_guess(text: str) -> str:
29
  """
30
- *
31
- * Retrieves model predictions and compares the top 3 predicted tokens
32
  """
33
- current_tokens = all_tokens[:STATE.current_word_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  current_text = tokenizer.decode(current_tokens)
35
  player_guesses = ""
36
  lm_guesses = ""
37
- remaining_attempts = 3
38
 
39
  if not text:
40
- return (
41
- current_text,
42
- player_guesses,
43
- lm_guesses,
44
- remaining_attempts
45
- )
46
 
47
  next_token = all_tokens[STATE.current_word_index]
48
- predicted_token_start = tokenizer.encode(text, add_special_tokens=False)[0]
49
- predicted_token_whitespace = tokenizer.encode(". " + text, add_special_tokens=False)[1]
50
- logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
51
- logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
52
 
53
- guess_is_correct = next_token in (predicted_token_start, predicted_token_whitespace)
 
54
 
55
- if guess_is_correct or remaining_attempts == 0:
56
- STATE.current_word_index += 1
57
- current_tokens = all_tokens[:STATE.current_word_index]
58
- remaining_attempts = 3
59
- STATE.player_guesses = []
60
- STATE.lm_guesses = []
61
  else:
62
- remaining_attempts -= 1
63
- STATE.player_guesses.append(tokenizer.decode([predicted_token_whitespace]))
64
 
65
  # FIXME: unoptimized, computing all three every time
66
- STATE.lm_guesses = get_model_predictions(tokenizer.decode(current_tokens))[:3-remaining_attempts]
67
- logger.debug(f"lm_guesses: {tokenizer.decode(lm_guesses)}")
68
-
69
- player_guesses = "\n".join(STATE.player_guesses)
70
  current_text = tokenizer.decode(current_tokens)
 
 
71
 
 
72
  return (
73
  current_text,
74
- player_guesses,
75
- lm_guesses,
76
- remaining_attempts
77
- )
 
 
1
  import logging
2
 
3
+ import torch
4
+
5
+ from src.constants import MAX_ATTEMPTS
6
  from src.state import STATE
 
7
  from src.state import model
8
+ from src.state import tokenizer
9
  from src.text import get_text
10
 
11
  logger = logging.getLogger(__name__)
 
27
  return top_3
28
 
29
 
30
+ def guess_is_correct(text: str, next_token: int) -> bool:
31
  """
32
+ We check if the predicted token or a corresponding one with a leading whitespace
33
+ matches that of the next token
34
  """
35
+ logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
36
+ predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
37
+ logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
38
+ return next_token in (predicted_token_start, predicted_token_whitespace)
39
+
40
+
41
+ def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
42
+ predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
43
+ predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
44
+ return predicted_token_start, predicted_token_whitespace
45
+
46
+
47
+ def handle_guess(
48
+ text: str,
49
+ remaining_attempts: int,
50
+ ) -> str:
51
+ """
52
+ *
53
+ * Retrieves model predictions and compares the top 3 predicted tokens
54
+ """
55
+ logger.debug(f"Params:\ntext = {text}\nremaining_attempts = {remaining_attempts}\n")
56
+ logger.debug(f"Initial STATE:\n{STATE}")
57
+
58
+ current_tokens = all_tokens[: STATE.current_word_index]
59
  current_text = tokenizer.decode(current_tokens)
60
  player_guesses = ""
61
  lm_guesses = ""
62
+ remaining_attempts -= 1
63
 
64
  if not text:
65
+ logger.debug("Returning early")
66
+ return (current_text, player_guesses, lm_guesses, remaining_attempts)
 
 
 
 
67
 
68
  next_token = all_tokens[STATE.current_word_index]
 
 
 
 
69
 
70
+ if guess_is_correct(text, next_token):
71
+ STATE.correct_guess()
72
 
73
+ if remaining_attempts == 0:
74
+ STATE.next_word()
75
+ current_tokens = all_tokens[: STATE.current_word_index]
76
+ remaining_attempts = MAX_ATTEMPTS
 
 
77
  else:
78
+ STATE.player_guesses.append(text)
 
79
 
80
  # FIXME: unoptimized, computing all three every time
 
 
 
 
81
  current_text = tokenizer.decode(current_tokens)
82
+ STATE.lm_guesses = get_model_predictions(current_text)[: 3 - remaining_attempts]
83
+ logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
84
 
85
+ logger.debug(f"Pre-return STATE:\n{STATE}")
86
  return (
87
  current_text,
88
+ STATE.player_guess_str,
89
+ STATE.get_lm_guess_display(remaining_attempts),
90
+ remaining_attempts,
91
+ )
src/interface.py CHANGED
@@ -1,5 +1,10 @@
1
  import gradio as gr
 
 
 
2
  from src.handler import handle_guess
 
 
3
 
4
 
5
  def build_demo():
@@ -14,7 +19,11 @@ def build_demo():
14
  "The one with the fewest guesses for a given word gets a point."
15
  )
16
  with gr.Row():
17
- prompt_text = gr.Textbox(label="Context", interactive=False)
 
 
 
 
18
  with gr.Row():
19
  with gr.Column():
20
  player_points = gr.Number(label="your points", interactive=False)
@@ -22,9 +31,13 @@ def build_demo():
22
  lm_points = gr.Number(label="LM points", interactive=False)
23
  with gr.Row():
24
  with gr.Column():
25
- remaining_attempts = gr.Number(label="Remaining attempts")
 
 
 
 
26
  current_guesses = gr.Textbox(label="Your guesses")
27
- with gr.Column():
28
  lm_guesses = gr.Textbox(label="LM guesses")
29
 
30
  with gr.Row():
@@ -37,7 +50,10 @@ def build_demo():
37
 
38
  guess_button.click(
39
  handle_guess,
40
- inputs=guess,
 
 
 
41
  outputs=[
42
  prompt_text,
43
  current_guesses,
@@ -63,6 +79,3 @@ def get_demo(wip=False):
63
  return wip_sign()
64
  else:
65
  return build_demo()
66
-
67
-
68
-
 
1
  import gradio as gr
2
+
3
+ from src.constants import MAX_ATTEMPTS
4
+ from src.handler import all_tokens
5
  from src.handler import handle_guess
6
+ from src.state import STATE
7
+ from src.state import tokenizer
8
 
9
 
10
  def build_demo():
 
19
  "The one with the fewest guesses for a given word gets a point."
20
  )
21
  with gr.Row():
22
+ prompt_text = gr.Textbox(
23
+ value=tokenizer.decode(all_tokens[: STATE.current_word_index]),
24
+ label="Context",
25
+ interactive=False,
26
+ )
27
  with gr.Row():
28
  with gr.Column():
29
  player_points = gr.Number(label="your points", interactive=False)
 
31
  lm_points = gr.Number(label="LM points", interactive=False)
32
  with gr.Row():
33
  with gr.Column():
34
+ remaining_attempts = gr.Number(
35
+ value=MAX_ATTEMPTS,
36
+ label="Remaining attempts",
37
+ precision=0,
38
+ )
39
  current_guesses = gr.Textbox(label="Your guesses")
40
+ with gr.Column():
41
  lm_guesses = gr.Textbox(label="LM guesses")
42
 
43
  with gr.Row():
 
50
 
51
  guess_button.click(
52
  handle_guess,
53
+ inputs=[
54
+ guess,
55
+ remaining_attempts,
56
+ ],
57
  outputs=[
58
  prompt_text,
59
  current_guesses,
 
79
  return wip_sign()
80
  else:
81
  return build_demo()
 
 
 
src/state.py CHANGED
@@ -1,9 +1,10 @@
1
  from dataclasses import dataclass
2
- from transformers import AutoTokenizer
3
  from transformers import AutoModelForCausalLM
 
4
 
 
5
 
6
- from dataclasses import dataclass
7
 
8
  @dataclass
9
  class ProgramState:
@@ -13,6 +14,23 @@ class ProgramState:
13
  lm_guesses: list
14
  lm_points: int
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  STATE = ProgramState(
18
  current_word_index=20,
@@ -24,4 +42,4 @@ STATE = ProgramState(
24
 
25
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
26
  model = AutoModelForCausalLM.from_pretrained("gpt2")
27
- model.eval()
 
1
  from dataclasses import dataclass
2
+
3
  from transformers import AutoModelForCausalLM
4
+ from transformers import AutoTokenizer
5
 
6
+ from src.constants import MAX_ATTEMPTS
7
 
 
8
 
9
  @dataclass
10
  class ProgramState:
 
14
  lm_guesses: list
15
  lm_points: int
16
 
17
+ def correct_guess(self):
18
+ # FIXME: not 1 for every point
19
+ self.player_points += 1
20
+ self.next_word()
21
+
22
+ def next_word(self):
23
+ self.current_word_index += 1
24
+ self.player_guesses = []
25
+ self.lm_guesses = []
26
+
27
+ @property
28
+ def player_guess_str(self):
29
+ return "\n".join(self.player_guesses)
30
+
31
+ def get_lm_guess_display(self, remaining_attempts: int) -> str:
32
+ return "\n".join(map(tokenizer.decode, self.lm_guesses[: MAX_ATTEMPTS - remaining_attempts]))
33
+
34
 
35
  STATE = ProgramState(
36
  current_word_index=20,
 
42
 
43
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
44
  model = AutoModelForCausalLM.from_pretrained("gpt2")
45
+ model.eval()
src/text.py CHANGED
@@ -13,4 +13,4 @@ Nine months later, he was named the 2009 Nobel Peace Prize laureate, a decision
13
 
14
 
15
  def get_text():
16
- return target_text
 
13
 
14
 
15
  def get_text():
16
+ return target_text