Spaces:
Runtime error
Runtime error
File size: 2,278 Bytes
850b0e4 8d86ca5 850b0e4 8d86ca5 850b0e4 8d86ca5 850b0e4 8d86ca5 7a7833f 850b0e4 8d86ca5 d59d1e6 8d86ca5 d59d1e6 141b1fb d59d1e6 850b0e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png
# Instantiate the MarioGPT language model once, at import time, so every
# request handled by the UI reuses the same instance.
mario_lm = MarioLM()

# Use the GPU when one is present, otherwise fall back to the CPU.
# Hard-coding 'cuda' makes the `.to(device)` call below fail on hosts
# without a CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)

# Directory of tile images handed to convert_level_to_png when rendering.
TILE_DIR = "data/tiles"
def update(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt=""):
    """Sample a level from ``mario_lm`` and return its rendered image.

    Args:
        pipes: Amount descriptor for pipes (e.g. ``"many"``), or None/empty
            when nothing was selected.
        enemies: Amount descriptor for enemies, or None/empty.
        blocks: Amount descriptor for blocks, or None/empty.
        elevation: Descriptor for elevation, or None/empty.
        temperature: Sampling temperature, forwarded to ``mario_lm.sample``.
        level_size: Number of sampling steps; coerced to ``int`` and
            forwarded as ``num_steps``.
        prompt: Free-text prompt. When non-blank it is used verbatim and the
            four descriptors above are ignored.

    Returns:
        The first element of what ``convert_level_to_png`` returns for the
        generated level.

    Raises:
        gr.Error: If ``prompt`` is blank and none of the four descriptors
            was provided, so there is nothing to build a prompt from.
    """
    # Treat None and whitespace-only text the same as an empty prompt; a
    # non-blank prompt is passed through unchanged.
    if not (prompt and prompt.strip()):
        selections = (
            ("pipes", pipes),
            ("enemies", enemies),
            ("blocks", blocks),
            ("elevation", elevation),
        )
        # Leave out descriptors that were not selected instead of
        # interpolating the literal text "None" into the prompt. With all
        # four selected this yields the same text as before:
        # "<pipes> pipes, <enemies> enemies, <blocks> blocks, <elevation> elevation".
        prompt = ", ".join(
            f"{amount} {feature}" for feature, amount in selections if amount
        )
        if not prompt:
            raise gr.Error("Enter a prompt or select at least one level attribute.")

    print(f"Using prompt: {prompt}")
    generated_level = mario_lm.sample(
        prompts=[prompt],
        # Numeric widgets may deliver a float; a step count must be integral.
        num_steps=int(level_size),
        temperature=temperature,
        use_tqdm=True,
    )
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
    return img
# Build the Gradio UI: a free-text prompt, radio selectors for the level
# attributes, sampling controls, and an image output for the rendered level.
with gr.Blocks() as demo:
    gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt). Enter a text prompt or select parameters from below!")

    text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")

    level_image = gr.Image()
    btn = gr.Button("Generate level")

    # Attribute selectors; their values are combined into a prompt by
    # `update` when the text prompt is left empty.
    pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
    enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
    blocks = gr.Radio(["little", "some", "many"], label="blocks")
    elevation = gr.Radio(["low", "high"], label="elevation")

    temperature = gr.Number(value=2.0, label="temperature: Increase these for more stochastic, but lower quality, generations")
    level_size = gr.Number(value=1399, precision=0, label="level_size")

    # The input order must match the positional parameters of `update`.
    btn.click(fn=update, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=level_image)

    # Each example row supplies one value per component in `inputs`.
    # `temperature` is listed so the fifth value of every row has a component
    # to bind to; the first row spells out the 2.0 default so all rows have
    # the same length.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.0],
            ["no", "no", "many", "high", 2.4],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature],
        outputs=level_image,
        fn=update,
        cache_examples=True,
    )

demo.launch()
|