# Page header captured with the source -- Spaces: Runtime error
from dataclasses import dataclass

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from src.constants import MAX_ATTEMPTS
@dataclass
class ProgramState:
    """Mutable state of the guessing game for the player and the language model.

    The class is a dataclass so that it can be built from keyword arguments,
    which is how the module-level ``STATE`` instance is created.

    Attributes:
        current_word_index: Index of the word currently being guessed.
        player_guesses: The player's guesses for the current word; they are
            joined with newlines, so the items must be strings.
        player_points: Points the player has scored so far.
        lm_guesses: The language model's guesses for the current word; each
            item is passed to ``tokenizer.decode`` (presumably token ids --
            confirm against the code that fills this list).
        lm_points: Points the language model has scored so far.
    """

    current_word_index: int
    player_guesses: list
    player_points: int
    lm_guesses: list
    lm_points: int

    def correct_guess(self) -> None:
        """Credit the player with a point and move on to the next word."""
        # FIXME: not 1 for every point
        self.player_points += 1
        self.next_word()

    def next_word(self) -> None:
        """Advance to the next word and clear both guess lists."""
        self.current_word_index += 1
        # Fresh lists are assigned (rather than cleared in place) so that any
        # outside reference to the previous lists is left untouched.
        self.player_guesses = []
        self.lm_guesses = []

    def player_guess_str(self) -> str:
        """Return the player's guesses for the current word, one per line."""
        return "\n".join(self.player_guesses)

    def get_lm_guess_display(self, remaining_attempts: int) -> str:
        """Return the decoded language-model guesses revealed so far, one per line.

        Args:
            remaining_attempts: Number of attempts still available; the first
                ``MAX_ATTEMPTS - remaining_attempts`` guesses are shown.

        Returns:
            The revealed guesses decoded with the module-level ``tokenizer``
            and joined with newlines.
        """
        revealed = self.lm_guesses[: MAX_ATTEMPTS - remaining_attempts]
        return "\n".join(map(tokenizer.decode, revealed))
# Module-level game state shared by the application: the word index starts
# at 20, no guesses have been made, and both scores are zero.
STATE = ProgramState(
    current_word_index=20,
    player_guesses=[],
    player_points=0,
    lm_guesses=[],
    lm_points=0,
)
# Load the pretrained GPT-2 tokenizer and causal language model once, at
# import time. ProgramState.get_lm_guess_display reads ``tokenizer`` from
# module scope when it is called.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Put the model in evaluation mode (the model is not trained in this module).
model.eval()