File size: 2,278 Bytes
850b0e4
 
 
 
 
 
 
 
 
 
 
 
 
8d86ca5
 
 
 
 
 
850b0e4
 
 
8d86ca5
 
850b0e4
 
 
8d86ca5
850b0e4
 
8d86ca5
 
 
 
7a7833f
850b0e4
8d86ca5
 
 
 
 
 
 
 
 
d59d1e6
8d86ca5
 
 
 
 
 
 
d59d1e6
 
141b1fb
d59d1e6
850b0e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

# Load the pretrained MarioGPT language model once at startup.
mario_lm = MarioLM()

# Prefer GPU but fall back to CPU so the demo still launches on machines
# without CUDA (the original hard-coded 'cuda' and crashed at .to(device)).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)

# Directory holding the tile sprites used to render generated levels as PNGs.
TILE_DIR = "data/tiles"


def update(pipes, enemies, blocks, elevation, temperature = 2.0, level_size = 1399, prompt = ""):
    """Generate a Mario level image from a free-form prompt or UI selections.

    Args:
        pipes, enemies, blocks, elevation: quantity words chosen via the UI
            radio buttons; used to build a prompt when none is typed.
        temperature: sampling temperature — higher is more stochastic.
        level_size: number of sampling steps (level length in tokens).
        prompt: optional free-form text prompt; when non-empty it overrides
            the radio-button selections.

    Returns:
        The rendered level image (first element of convert_level_to_png).
    """
    # Treat None or whitespace-only text as "no prompt supplied" — the
    # original `prompt == ""` check let a stray space bypass the fallback.
    if not prompt or not prompt.strip():
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    print(f"Using prompt: {prompt}")
    generated_level = mario_lm.sample(
        prompts=[prompt],
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
    )
    # Render the sampled token level to a PNG using the tile sprite sheet.
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
    return img

# Build the Gradio UI: a free-form prompt box, radio buttons for the four
# prompt attributes, sampling controls, and a button that triggers generation.
with gr.Blocks() as demo:
    gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt). Enter a text prompt or select parameters from below!")

    text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")

    level_image = gr.Image()
    btn = gr.Button("Generate level")

    pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
    enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
    blocks = gr.Radio(["little", "some", "many"], label="blocks")
    elevation = gr.Radio(["low", "high"], label="elevation")
    temperature = gr.Number(value=2.0, label="temperature: Increase these for more stochastic, but lower quality, generations")
    level_size = gr.Number(value=1399, precision=0, label="level_size")

    btn.click(fn=update, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=level_image)
    # BUG FIX: the original declared 4 example inputs but gave 5 values in
    # three rows (and 4 in the first) — Gradio requires every example row to
    # match the declared inputs. temperature is now a declared input and every
    # row carries all 5 values.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.0],
            ["no", "no", "many", "high", 2.4],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature],
        outputs=level_image,
        fn=update,
        cache_examples=True,
    )
demo.launch()