# Hugging Face Spaces app (page header at capture time: "Spaces: Sleeping").
import functools
import os
import random
import sys
from os import path

import gradio as gr
import torch

# Put this file's directory on sys.path so the `src.` imports below resolve
# regardless of the current working directory.
sys.path.append(path.dirname(path.abspath(__file__)))

from src.olgen.ol_generator import VecOnlineGenerator
from src.olgen.olg_policy import RLGenPolicy
from src.smb.level import save_batch
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet
# Torch device used for policy loading and generation: the first CUDA GPU
# when one is available, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
@functools.lru_cache(maxsize=None)
def _load_policy(policy_path):
    """Load the generator policy at *policy_path* on ``device``, once per path.

    Memoised so repeated "Generate levels" clicks reuse the loaded policy
    instead of re-reading it from disk on every request.
    NOTE(review): this assumes generation does not mutate the policy object;
    confirm before relying on it being shared across concurrent requests.
    """
    return RLGenPolicy.from_path(policy_path, device)


def generate_and_play(num=8, length=10, policy_path='models/example_policy'):
    """Generate levels with the example policy and return their images.

    Called by the Gradio button with no arguments, so the defaults reproduce
    the original demo (generate(8, 10) with 'models/example_policy').

    Args:
        num: first argument passed to ``VecOnlineGenerator.generate``
            (presumably the number of levels -- confirm).
        length: second argument passed to ``VecOnlineGenerator.generate``
            (presumably the level length -- confirm).
        policy_path: policy location, as accepted by
            ``RLGenPolicy.from_path`` and ``getpath``.

    Returns:
        A list holding ``lvl.to_img()`` for each generated level, in the order
        ``generate`` returned them.
    """
    # Generate using the example policy (loaded once, then reused).
    plc = _load_policy(policy_path)
    generator = VecOnlineGenerator(plc, g_device=device)
    # Ensure the directory containing the policy path exists. Carried over
    # from the original script; since the policy was just loaded from there,
    # this is presumably a no-op -- verify and remove if so.
    model_dir, _ = os.path.split(getpath(policy_path))
    os.makedirs(model_dir, exist_ok=True)
    levels = generator.generate(num, length)
    return [lvl.to_img() for lvl in levels]
# Single-page demo UI: an image gallery plus a button that fills it.
with gr.Blocks(title="NCERL Demo") as demo:
    # Gallery that receives the images returned by generate_and_play.
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )
    btn = gr.Button("Generate levels", scale=0)
    # No input components: each click calls generate_and_play() with no
    # arguments and routes its return value into the gallery.
    btn.click(fn=generate_and_play, inputs=None, outputs=gallery)

if __name__ == "__main__":
    demo.launch()