Spaces:
Runtime error
Runtime error
File size: 1,074 Bytes
614d543 dfbce2c 614d543 dfbce2c 614d543 dfbce2c 614d543 dfbce2c 614d543 dfbce2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from dataclasses import dataclass
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from src.constants import MAX_ATTEMPTS
@dataclass
class ProgramState:
    """Mutable state of one guessing-game session.

    Tracks which word is current, the guesses made so far for that word by
    the player and by the language model, and each side's score.
    """

    # Index of the word currently being guessed; incremented by next_word().
    current_word_index: int
    # Player guesses for the current word. They are joined with "\n" for
    # display, so entries must be strings.
    player_guesses: list
    player_points: int
    # LM guesses for the current word. Each entry is passed to
    # tokenizer.decode for display (presumably token ids — TODO confirm).
    lm_guesses: list
    # NOTE(review): nothing in this class updates lm_points.
    lm_points: int

    def correct_guess(self):
        """Credit the player with one point and advance to the next word."""
        # FIXME: not 1 for every point
        self.player_points += 1
        self.next_word()

    def next_word(self):
        """Advance to the next word and reset both sides' guess lists."""
        self.current_word_index += 1
        # Rebind to fresh lists (rather than clearing in place) so any list
        # previously handed out by reference is left unchanged.
        self.player_guesses = []
        self.lm_guesses = []

    @property
    def player_guess_str(self):
        """The player's guesses for the current word, one per line."""
        return "\n".join(self.player_guesses)

    def get_lm_guess_display(self, remaining_attempts: int) -> str:
        """Return the LM guesses revealed so far, decoded, one per line.

        Args:
            remaining_attempts: Attempts the player still has. The number of
                LM guesses revealed is ``MAX_ATTEMPTS - remaining_attempts``,
                floored at 0.

        Returns:
            The decoded guesses joined with newlines, or "" when none are
            revealed.
        """
        # Floor at 0: a negative slice bound would count from the end and
        # reveal all but the last few guesses instead of none.
        revealed = max(MAX_ATTEMPTS - remaining_attempts, 0)
        # `tokenizer` is the module-level tokenizer, looked up at call time.
        return "\n".join(map(tokenizer.decode, self.lm_guesses[:revealed]))
# Module-level game state: starts at word index 20 (reason not visible here —
# TODO confirm), with empty guess lists and zero points for both the player
# and the LM. Arguments follow the dataclass field order.
STATE = ProgramState(
    current_word_index=20,
    player_guesses=[],
    player_points=0,
    lm_guesses=[],
    lm_points=0,
)
# Load the "gpt2" tokenizer and causal LM at import time. nn.Module.eval()
# returns the module itself, so switching to eval mode can be chained here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
|