# Hugging Face Spaces status at capture time: Runtime error
import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

# Build the MarioGPT model once at module import; update() reuses this instance.
mario_lm = MarioLM()
# Use the GPU when one is present, otherwise fall back to CPU. An unconditional
# torch.device('cuda') followed by .to(device) raises on CPU-only hardware.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)
# Directory passed to convert_level_to_png when rendering levels
# (presumably the tile sprite images — confirm against mario_gpt.utils).
TILE_DIR = "data/tiles"
def update(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt=""):
    """Sample a level with MarioGPT and render it as an image.

    The free-text ``prompt`` is used when it contains non-whitespace text;
    otherwise a prompt is assembled from the four attribute selections.

    Args:
        pipes: Pipe amount word (UI choices: "no", "little", "some", "many").
        enemies: Enemy amount word (UI choices: "no", "little", "some", "many").
        blocks: Block amount word (UI choices: "little", "some", "many").
        elevation: Elevation word (UI choices: "low", "high").
        temperature: Sampling temperature forwarded to ``mario_lm.sample``.
        level_size: Forwarded to ``mario_lm.sample`` as ``num_steps``.
        prompt: Optional free-text prompt; takes precedence over the attribute words.

    Returns:
        The first element returned by ``convert_level_to_png`` for the sampled level.

    Raises:
        gr.Error: If no text prompt is given and any attribute is unselected.
    """
    if prompt is None or not prompt.strip():
        # An unselected gr.Radio delivers None, which would otherwise yield a
        # prompt such as "None pipes, ..."; report it in the UI instead.
        if any(value is None for value in (pipes, enemies, blocks, elevation)):
            raise gr.Error(
                "Enter a prompt or select a value for pipes, enemies, blocks and elevation."
            )
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    print(f"Using prompt: {prompt}")
    generated_level = mario_lm.sample(
        prompts=[prompt],
        # gr.Number can deliver floats; the sample length is a whole step count.
        num_steps=int(level_size),
        temperature=float(temperature),
        use_tqdm=True,
    )
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
    return img
with gr.Blocks() as demo:
    gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt). Enter a text prompt or select parameters from below!")
    text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    level_image = gr.Image()
    btn = gr.Button("Generate level")
    pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
    enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
    blocks = gr.Radio(["little", "some", "many"], label="blocks")
    elevation = gr.Radio(["low", "high"], label="elevation")
    temperature = gr.Number(value=2.0, label="temperature: Increase these for more stochastic, but lower quality, generations")
    level_size = gr.Number(value=1399, precision=0, label="level_size")
    # The input order matches update()'s positional parameters.
    btn.click(fn=update, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=level_image)
    # Each row supplies pipes, enemies, blocks, elevation and temperature.
    # temperature must be listed in `inputs`, otherwise Gradio ignores the
    # extra column and every example runs at update()'s default temperature.
    # level_size and the text prompt fall back to update()'s defaults.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.0],
            ["no", "no", "many", "high", 2.4],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature],
        outputs=level_image,
        fn=update,
        cache_examples=True,
    )
demo.launch()