Spaces:
Runtime error
Runtime error
import logging | |
from dataclasses import dataclass | |
from transformers import AutoModelForCausalLM | |
from transformers import AutoTokenizer | |
from src.constants import MAX_ATTEMPTS | |
logger = logging.getLogger(__name__) | |
@dataclass
class ProgramState:
    """Mutable state of the guessing game.

    Attributes:
        current_word_index: Index of the word currently being guessed;
            incremented by ``next_word``.
        player_guesses: The player's guesses for the current word.
        player_points: The player's score.
        lm_guesses: The language model's guesses, stored as token ids that
            are decoded with the module-level ``tokenizer``.
        lm_points: The language model's score.
        remaining_attempts: Attempts left; also the number of trailing
            language-model guesses hidden by ``lm_guess_str``.
        button_label: Label shown on the action button.
    """

    current_word_index: int
    player_guesses: list[str]
    player_points: int
    lm_guesses: list[int]
    lm_points: int
    remaining_attempts: int
    button_label: str

    def player_guess_str(self) -> str:
        """Return the player's guesses joined one per line."""
        return "\n".join(self.player_guesses)

    def lm_guess_str(self) -> str:
        """Return the model's guesses one per line, masking the hidden ones.

        Every token id in ``lm_guesses`` is decoded with the module-level
        ``tokenizer``; the last ``remaining_attempts`` entries are then
        replaced by ``"****"``.
        """
        strings = [tokenizer.decode(token_id) for token_id in self.lm_guesses]
        logger.debug(strings)
        # Clamp the mask size to the number of guesses available: indexing
        # strings[-i] past the start raised IndexError whenever lm_guesses
        # was shorter than remaining_attempts (e.g. right after next_word).
        n_censored = min(max(self.remaining_attempts, 0), len(strings))
        first_censored = len(strings) - n_censored
        strings[first_censored:] = ["****"] * n_censored
        logger.debug(strings)
        return "\n".join(strings)

    def next_word(self) -> None:
        """Advance to the next word and clear the per-word guess lists."""
        self.current_word_index += 1
        self.player_guesses = []
        self.lm_guesses = []  # TODO: make guesses?
        self.button_label = "Guess!"
        # NOTE(review): remaining_attempts is not reset here; presumably the
        # caller handles that -- confirm.

    def get_tuple(
        self,
        prompt_text=None,
        player_points=None,
        lm_points=None,
        player_guess_str=None,
        lm_guess_str=None,
        remaining_attempts=None,
        text_field=None,
        button_label=None,
        bottom_html=None,
    ) -> tuple:
        """Build the 9-tuple of output values, filling gaps from the state.

        Each argument overrides the matching slot of the result. An argument
        left as ``None`` falls back to the value stored on this object, or to
        an empty string for the slots the state does not track. Only ``None``
        triggers the fallback, so explicit ``0`` and ``""`` are honoured.

        Returns:
            ``(prompt_text, player_points, lm_points, player_guess_str,
            lm_guess_str, remaining_attempts, text_field, button_label,
            bottom_html)``.
        """

        def pick(value, fallback):
            # ``value or fallback`` would discard legitimate falsy overrides.
            return fallback if value is None else value

        # The guess strings are rendered lazily, and by *calling* the
        # methods: the previous code returned the bound method objects.
        if player_guess_str is None:
            player_guess_str = self.player_guess_str()
        if lm_guess_str is None:
            lm_guess_str = self.lm_guess_str()

        return (
            pick(prompt_text, ""),  # FIXME: not tracked in the state
            pick(player_points, self.player_points),
            pick(lm_points, self.lm_points),
            player_guess_str,
            lm_guess_str,
            pick(remaining_attempts, self.remaining_attempts),
            pick(text_field, ""),  # FIXME: not tracked in the state
            pick(button_label, self.button_label),
            pick(bottom_html, ""),  # FIXME: not tracked in the state
        )
# Module-level game state: starts at word index 20 with empty guess lists,
# zeroed scores and the full MAX_ATTEMPTS budget.
STATE = ProgramState(
    current_word_index=20,
    remaining_attempts=MAX_ATTEMPTS,
    button_label="Guess!",
    player_points=0,
    lm_points=0,
    player_guesses=[],
    lm_guesses=[],
)

# GPT-2 tokenizer and causal language model, loaded once at import time.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# eval() returns the module itself, so the call can be chained onto the load.
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()