marksverdhei
Add attempt counts
dfbce2c
raw
history blame
1.07 kB
from dataclasses import dataclass
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from src.constants import MAX_ATTEMPTS
@dataclass
class ProgramState:
    """Mutable state for a guessing session between a player and a language model (LM).

    Holds the index of the word currently in play, the guesses each side has
    made for that word, and each side's running point total.
    """

    # Index of the word currently in play; advanced by ``next_word``.
    current_word_index: int
    # Player's guesses for the current word; joined as strings by
    # ``player_guess_str``.
    player_guesses: list
    # Player's running score.
    player_points: int
    # LM's guesses for the current word. Each item is passed to
    # ``tokenizer.decode`` (presumably token ids -- TODO confirm).
    lm_guesses: list
    # LM's running score.
    lm_points: int

    def correct_guess(self):
        """Award the player one point and move on to the next word."""
        # FIXME: not 1 for every point
        self.player_points += 1
        self.next_word()

    def next_word(self):
        """Advance to the next word and clear both sides' guesses."""
        self.current_word_index += 1
        self.player_guesses = []
        self.lm_guesses = []

    @property
    def player_guess_str(self):
        """The player's guesses for the current word, one per line."""
        return "\n".join(self.player_guesses)

    def get_lm_guess_display(self, remaining_attempts: int) -> str:
        """Return the decoded LM guesses revealed so far, one per line.

        The first ``MAX_ATTEMPTS - remaining_attempts`` entries of
        ``lm_guesses`` are decoded with the module-level ``tokenizer``.

        Args:
            remaining_attempts: Attempts still available for the current word.

        Returns:
            The decoded guesses joined by newlines; an empty string when no
            guesses are revealed.
        """
        # Clamp at zero: a negative slice bound would count from the END of
        # the list and reveal guesses when remaining_attempts > MAX_ATTEMPTS.
        revealed = max(0, MAX_ATTEMPTS - remaining_attempts)
        # ``tokenizer`` is a module-level name resolved at call time.
        return "\n".join(map(tokenizer.decode, self.lm_guesses[:revealed]))
# Module-level state instance: starts at word index 20 with no guesses and
# zero points for both sides.
# NOTE(review): the reason for starting at index 20 is not visible here -- confirm.
STATE = ProgramState(
    current_word_index=20,
    player_guesses=list(),
    player_points=0,
    lm_guesses=list(),
    lm_points=0,
)
# GPT-2 tokenizer and causal LM, loaded when this module is imported.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# eval() returns the module itself, so it can be chained onto the load call.
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()