"""Game state for a word-guessing game: a human player versus a language model."""

from dataclasses import dataclass

from transformers import AutoModelForCausalLM, AutoTokenizer

from src.constants import MAX_ATTEMPTS


@dataclass
class ProgramState:
    """Mutable session state: current word plus per-side guesses and points."""

    current_word_index: int
    player_guesses: list  # raw guess strings entered by the player this round
    player_points: int
    lm_guesses: list  # LM guesses as token ids; decoded for display
    lm_points: int

    def correct_guess(self) -> None:
        """Award the player a point and advance to the next word.

        FIXME: not 1 for every point — the scoring scheme is still provisional.
        """
        self.player_points += 1
        self.next_word()

    def next_word(self) -> None:
        """Advance to the next word and clear both sides' guesses."""
        self.current_word_index += 1
        self.player_guesses = []
        self.lm_guesses = []

    @property
    def player_guess_str(self) -> str:
        """All player guesses joined one per line for display."""
        return "\n".join(self.player_guesses)

    def get_lm_guess_display(self, remaining_attempts: int) -> str:
        """Return the LM guesses revealed so far, decoded, one per line.

        Reveals the first ``MAX_ATTEMPTS - remaining_attempts`` guesses.
        NOTE(review): relies on the module-level ``tokenizer`` defined *below*
        this class; it exists by the time this method is called at runtime,
        but consider passing the tokenizer in explicitly.
        """
        revealed = self.lm_guesses[: MAX_ATTEMPTS - remaining_attempts]
        return "\n".join(map(tokenizer.decode, revealed))


# Initial game state; word index 20 is the configured starting word.
STATE = ProgramState(
    current_word_index=20,
    player_guesses=[],
    lm_guesses=[],
    player_points=0,
    lm_points=0,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()  # inference only; disable dropout etc.