marksverdhei committed
Commit 0e7f280 · 1 Parent(s): dfbce2c

Update label to "next word"

Files changed (2)
  1. src/handler.py +37 -12
  2. src/interface.py +7 -6
src/handler.py CHANGED

@@ -47,45 +47,70 @@ def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
 def handle_guess(
     text: str,
     remaining_attempts: int,
+    *args,
+    **kwargs,
 ) -> str:
     """
-    *
     * Retreives model predictions and compares the top 3 predicted tokens
     """
-    logger.debug(f"Params:\ntext = {text}\nremaining_attempts = {remaining_attempts}\n")
+    logger.debug(
+        f"Params:\ntext = {text}\n"
+        f"remaining_attempts = {remaining_attempts}\n"
+        f"args = {args}\n"
+        f"kwargs = {kwargs}\n"
+    )
     logger.debug(f"Initial STATE:\n{STATE}")

-    current_tokens = all_tokens[: STATE.current_word_index]
+    current_tokens = all_tokens[:STATE.current_word_index]
     current_text = tokenizer.decode(current_tokens)
     player_guesses = ""
     lm_guesses = ""
-    remaining_attempts -= 1

     if not text:
         logger.debug("Returning early")
-        return (current_text, player_guesses, lm_guesses, remaining_attempts)
-
-    next_token = all_tokens[STATE.current_word_index]
-
-    if guess_is_correct(text, next_token):
-        STATE.correct_guess()
+        return (
+            current_text,
+            STATE.player_points,
+            STATE.lm_points,
+            STATE.player_guess_str,
+            STATE.get_lm_guess_display(remaining_attempts),
+            remaining_attempts,
+            "",
+            "Guess!"
+        )

     if remaining_attempts == 0:
         STATE.next_word()
         current_tokens = all_tokens[: STATE.current_word_index]
         remaining_attempts = MAX_ATTEMPTS
+
+    remaining_attempts -= 1
+
+    next_token = all_tokens[STATE.current_word_index]
+
+    if guess_is_correct(text, next_token):
+        # STATE.correct_guess()
+        STATE.player_points += 1
+        remaining_attempts = 0
+
     else:
         STATE.player_guesses.append(text)

     # FIXME: unoptimized, computing all three every time
     current_text = tokenizer.decode(current_tokens)
-    STATE.lm_guesses = get_model_predictions(current_text)[: 3 - remaining_attempts]
+    STATE.lm_guesses = get_model_predictions(current_text)[: MAX_ATTEMPTS - remaining_attempts]
     logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
-
     logger.debug(f"Pre-return STATE:\n{STATE}")
+
+    # BUG: if you enter the word guess field when it says next
+    # word, it will guess it as the next
     return (
         current_text,
+        STATE.player_points,
+        STATE.lm_points,
         STATE.player_guess_str,
         STATE.get_lm_guess_display(remaining_attempts),
         remaining_attempts,
+        "",
+        "Guess!" if remaining_attempts else "Next word",
     )
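
The main logic change in this hunk is the reveal slice: the hard-coded 3 becomes MAX_ATTEMPTS, so the number of model predictions shown grows by one for each attempt the player has used up. A quick sketch of that arithmetic, assuming MAX_ATTEMPTS is 3 (the value the old literal implies):

```python
# Sketch only, not repository code: illustrates the slice
#   get_model_predictions(current_text)[: MAX_ATTEMPTS - remaining_attempts]
# after remaining_attempts has been decremented for the current guess.
MAX_ATTEMPTS = 3  # assumed value, implied by the old hard-coded "3"

for remaining_attempts in (2, 1, 0):
    revealed = MAX_ATTEMPTS - remaining_attempts
    print(f"{remaining_attempts} attempts left -> top {revealed} model prediction(s) shown")

# Output:
# 2 attempts left -> top 1 model prediction(s) shown
# 1 attempts left -> top 2 model prediction(s) shown
# 0 attempts left -> top 3 model prediction(s) shown
```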
src/interface.py CHANGED

@@ -7,7 +7,7 @@ from src.state import STATE
 from src.state import tokenizer


-def build_demo():
+def build_demo():
     with gr.Blocks() as demo:
         with gr.Row():
             gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
@@ -42,23 +42,24 @@ def build_demo():

         with gr.Row():
             with gr.Column():
-                guess = gr.Textbox(label="")
+                guess_field = gr.Textbox(label="")
                 guess_button = gr.Button(value="Guess!")

-        with gr.Row():
-            next_word = gr.Button(value="Next word")
-
         guess_button.click(
             handle_guess,
             inputs=[
-                guess,
+                guess_field,
                 remaining_attempts,
             ],
             outputs=[
                 prompt_text,
+                player_points,
+                lm_points,
                 current_guesses,
                 lm_guesses,
                 remaining_attempts,
+                guess_field,
+                guess_button,
             ],
         )

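
The outputs list now has eight components, one for each value in handle_guess's return tuple; Gradio updates the i-th output with the i-th returned value, which is how the handler clears guess_field and relabels guess_button between "Guess!" and "Next word" (replacing the dedicated "Next word" button that this hunk removes). Below is a minimal, self-contained sketch of that relabel pattern; the component names and handler body are illustrative rather than the project's actual code, and it assumes Gradio's behavior of treating a string returned for a Button output as the button's new label:

```python
# Minimal sketch of the return-a-string-to-relabel-a-Button pattern used above.
# All names here are illustrative; only the wiring mirrors the diff.
import gradio as gr

MAX_ATTEMPTS = 3  # assumed, mirroring MAX_ATTEMPTS in src/handler.py

def on_guess(text, remaining):
    remaining = max(int(remaining) - 1, 0)
    # Mirror handle_guess's final two return values: clear the textbox and
    # flip the button label once the attempts are used up.
    return remaining, "", ("Guess!" if remaining else "Next word")

with gr.Blocks() as demo:
    remaining_attempts = gr.Number(value=MAX_ATTEMPTS, label="Remaining attempts")
    guess_field = gr.Textbox(label="")
    guess_button = gr.Button(value="Guess!")
    guess_button.click(
        on_guess,
        inputs=[guess_field, remaining_attempts],
        outputs=[remaining_attempts, guess_field, guess_button],
    )

if __name__ == "__main__":
    demo.launch()
```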