Spaces:
Runtime error
Runtime error
# from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM | |
from transformers import AutoTokenizer | |
# Passage shown to the player, one sentence per line. The leading backslash
# suppresses the newline after the opening quotes so the text starts at "Obama".
target_text = """\
Obama was born in Honolulu, Hawaii.
After graduating from Columbia University in 1983, he worked as a community organizer in Chicago.
In 1988, he enrolled in Harvard Law School, where he was the first black president of the Harvard Law Review.
After graduating, he became a civil rights attorney and an academic, teaching constitutional law at the University of Chicago Law School from 1992 to 2004.
Turning to elective politics, he represented the 13th district in the Illinois Senate from 1997 until 2004, when he ran for the U.S.
Senate.
Obama received national attention in 2004 with his March Senate primary win, his well-received July Democratic National Convention keynote address, and his landslide November election to the Senate.
In 2008, after a close primary campaign against Hillary Clinton, he was nominated by the Democratic Party for president and chose Joe Biden as his running mate.
Obama was elected over Republican nominee John McCain in the presidential election and was inaugurated on January 20, 2009.
Nine months later, he was named the 2009 Nobel Peace Prize laureate, a decision that drew a mixture of praise and criticism.
"""
# Checkpoint shared by the tokenizer and the model so the two always match.
_CHECKPOINT = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
# eval() returns the module itself, so the model is bound already in inference mode.
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT).eval()
def get_next_word(text: str) -> tuple[str, str]:
    """Extend *text* with the language model's single most likely next token.

    Args:
        text: Prompt to continue.

    Returns:
        A 2-tuple: the placeholder string ``"!!!"`` and *text* re-decoded
        with the predicted token appended. For empty input the second
        element is ``""`` and the model is not called.
    """
    if not text:
        # Nothing to condition on: there would be no last position to read
        # logits from, so skip tokenization and the forward pass entirely.
        return "!!!", ""

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Logits at the last position of the only batch row score the next token.
    next_token_id = int(torch.argmax(logits[0, -1]))

    # Index the batch row instead of calling squeeze(): squeeze() collapses a
    # one-token prompt to a 0-d tensor, which cannot be turned into a list.
    token_ids = inputs.input_ids[0].tolist()
    token_ids.append(next_token_id)
    return "!!!", tokenizer.decode(token_ids)
def build_demo():
    """Assemble the interactive guessing-game interface.

    Returns:
        The ``gr.Blocks`` app containing the title, the target passage, a
        guess textbox with its button, and a textbox for the model's guess.
        Clicking the button calls ``get_next_word`` with the guess.
    """
    with gr.Blocks() as app:
        gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
        with gr.Row():
            passage = gr.Markdown(target_text)
        with gr.Row():
            with gr.Column():
                user_guess = gr.Textbox(label="Guess!")
                submit = gr.Button(value="Guess!")
            with gr.Column():
                model_guess = gr.Textbox(label="LM guess")
        # The callback's two return values fill the passage and the LM box,
        # in that order.
        submit.click(
            fn=get_next_word,
            inputs=user_guess,
            outputs=[passage, model_guess],
            api_name="get_next_word",
        )
        # Example prompts (gr.Examples feeding the guess box) are disabled
        # for now.
    return app
def wip_sign():
    """Build the placeholder page served while the game is unfinished.

    Returns:
        A ``gr.Blocks`` app showing the game title and a work-in-progress
        notice.
    """
    title_html = "<h1><center>Can you beat a language model?</center></h1>"
    # NOTE(review): the characters before "Work in progress" look like a
    # mis-encoded emoji — confirm the intended text before changing it.
    notice_html = "<h1><center>βπ·ββοΈ Work in progress, come back later </center></h1>"

    with gr.Blocks() as placeholder:
        gr.Markdown(title_html)
        with gr.Row():
            gr.Markdown(notice_html)
    return placeholder
def main():
    """Launch the app in debug mode.

    The placeholder page from ``wip_sign`` is served; the full game from
    ``build_demo`` is not wired in here.
    """
    # Swap in build_demo() here once the game is ready.
    wip_sign().launch(debug=True)


if __name__ == "__main__":
    main()