marksverdhei
Add more handlers
7eee83c
raw
history blame
2.08 kB
import logging
from dataclasses import dataclass
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from src.constants import MAX_ATTEMPTS
logger = logging.getLogger(__name__)
@dataclass
class ProgramState:
    """Mutable state of one guessing game between the player and the language model.

    Fields:
        current_word_index: Index of the current word; ``next_word`` increments it.
        player_guesses: The player's guesses for the current word, as text.
        player_points: The player's points.
        lm_guesses: The model's guesses for the current word, as token ids
            (decoded with the module-level ``tokenizer`` in ``lm_guess_str``).
        lm_points: The model's points.
        remaining_attempts: Attempts left; also the number of trailing model
            guesses that ``lm_guess_str`` hides.
        button_label: Text for the guess button.
    """

    current_word_index: int
    player_guesses: list[str]
    player_points: int
    lm_guesses: list[int]
    lm_points: int
    remaining_attempts: int
    button_label: str

    @property
    def player_guess_str(self) -> str:
        """Return the player's guesses joined one per line."""
        return "\n".join(self.player_guesses)

    @property
    def lm_guess_str(self) -> str:
        """Return the model's guesses one per line, with unrevealed ones masked.

        Each token id in ``lm_guesses`` is decoded with the module-level
        ``tokenizer``. The last ``remaining_attempts`` entries are replaced by
        ``"****"``. The mask count is clamped to ``[0, len(lm_guesses)]``, so
        having fewer guesses than remaining attempts no longer raises
        ``IndexError``. That happens, for example, right after ``next_word``
        empties the list.
        """
        strings = [tokenizer.decode(token_id) for token_id in self.lm_guesses]
        logger.debug(strings)
        n_censored = max(0, min(self.remaining_attempts, len(strings)))
        if n_censored:
            # The guard is required: strings[-0:] would address the whole list.
            strings[-n_censored:] = ["****"] * n_censored
        logger.debug(strings)
        return "\n".join(strings)

    def next_word(self) -> None:
        """Advance to the next word and clear the per-word guesses.

        Points are kept. ``remaining_attempts`` is not modified here.
        """
        self.current_word_index += 1
        self.player_guesses = []
        self.lm_guesses = []  # TODO: make guesses?
        self.button_label = "Guess!"
        # NOTE(review): remaining_attempts is not reset; confirm callers do it.

    def get_tuple(
        self,
        prompt_text=None,
        player_points=None,
        lm_points=None,
        player_guess_str=None,
        lm_guess_str=None,
        remaining_attempts=None,
        text_field=None,
        button_label=None,
        bottom_html=None,
    ) -> tuple:
        """Build the 9-tuple of display values, applying any explicit overrides.

        Each argument left as ``None`` falls back to the corresponding state
        value, or to ``""`` for fields that have no state yet. Any other value,
        including ``0`` and ``""``, is used as given. The previous ``or``-based
        version replaced explicit zeros with the stored state.

        Returns:
            ``(prompt_text, player_points, lm_points, player_guess_str,
            lm_guess_str, remaining_attempts, text_field, button_label,
            bottom_html)``. Presumably this is mapped onto UI output components
            by the caller; confirm against the handlers.
        """
        # Conditional expressions (not a helper call) keep the property
        # fallbacks lazy: lm_guess_str is only decoded when no override is given.
        return (
            "" if prompt_text is None else prompt_text,  # FIXME: no state-backed default yet
            self.player_points if player_points is None else player_points,
            self.lm_points if lm_points is None else lm_points,
            self.player_guess_str if player_guess_str is None else player_guess_str,
            self.lm_guess_str if lm_guess_str is None else lm_guess_str,
            self.remaining_attempts if remaining_attempts is None else remaining_attempts,
            "" if text_field is None else text_field,  # FIXME: no state-backed default yet
            self.button_label if button_label is None else button_label,
            "" if bottom_html is None else bottom_html,  # FIXME: no state-backed default yet
        )
# The single module-level game state instance, with fields listed in
# ProgramState declaration order.
# NOTE(review): the starting word index of 20 is hard-coded and the reason
# is not visible here; confirm.
STATE = ProgramState(
    current_word_index=20,
    player_guesses=[],
    player_points=0,
    lm_guesses=[],
    lm_points=0,
    remaining_attempts=MAX_ATTEMPTS,
    button_label="Guess!",
)
# GPT-2 tokenizer and causal language model, loaded at import time.
# ProgramState.lm_guess_str decodes the model's guesses with this tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# nn.Module.eval() returns the module itself, so evaluation mode is set inline.
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()