NCERL-Diverse-PCG / src /smb /asyncsimlt.py
baiyanlali-zhao's picture
添加注释
3582c8a
raw
history blame
6.37 kB
import time
import importlib
import multiprocessing as mp
from src.smb.level import *
from src.smb.proxy import MarioProxy
from queue import Queue, Full as FullExpection
def _simlt_worker(remote, parent_remote, rfunc, resource):
    """Command loop run inside each worker process of ``AsycSimltPool``.

    Receives ``(cmd, data)`` tuples over ``remote``. Every command except
    ``'close'`` sends exactly one reply back.

    Commands:
        ``'evaluate'``: ``data`` is ``(tid, strlvl)``. Runs
            ``simulate_complete`` on the level and replies ``(tid, rewards)``,
            where ``rewards`` comes from the reward function's
            ``get_rewards(segs=..., simlt_res=...)``.
        ``'check_playable'``: ``data`` is ``(strlvl, item)``. Replies
            ``(playable, item)``. If no tile ``lvl[i, 0]`` (``i < lvl.h``) is
            in ``MarioLevel.solidset``, the level is reported unplayable
            without simulating. Otherwise it is playable iff
            ``simulate_game`` reports status ``'WIN'``.
        ``'mnd_item'``: ``data`` is a level string. Replies
            ``(min_hm, min_dtw)``: the smallest non-zero ``hamming_dis`` and
            ``lvl_dtw`` values against the reference levels. Each is
            ``inf`` when no non-zero value exists.
        ``'mpd'``: ``data`` is an iterable of ``(strlvl1, strlvl2)`` pairs.
            Replies ``(hms, None)`` with the pairwise ``hamming_dis`` values.
        ``'close'``: leaves the loop without replying.

    Args:
        remote: this worker's end of the pipe. It is closed when the loop
            exits, whatever the reason.
        parent_remote: the parent's end of the same pipe. It is closed
            immediately because the worker never uses it.
        rfunc: name of a class in ``src.env.rfuncs``. It is instantiated
            once per worker.
        resource: dict of extra options.
            ``'refs'`` is a list of level strings used by ``'mnd_item'``.
            A truthy ``'test'`` passes ``segTimeK=150`` instead of ``100``
            to ``simulate_complete``.

    Raises:
        KeyError: on an unknown command. This terminates the worker.
    """
    rfunc = getattr(importlib.import_module('src.env.rfuncs'), rfunc)()
    simulator = MarioProxy()
    # This end of the pipe belongs to the parent; the worker never uses it.
    parent_remote.close()
    refs = [MarioLevel(lvl) for lvl in resource.get('refs', [])]
    simlt_k = 150 if resource.get('test', False) else 100
    try:
        while True:
            # Only the receive is guarded. An EOFError raised by simulation
            # or reward code should surface, not pass as "pipe closed".
            try:
                cmd, data = remote.recv()
            except EOFError:
                # The parent closed its end of the pipe: shut down quietly.
                break
            if cmd == 'evaluate':
                tid, strlvl = data
                lvl = MarioLevel(strlvl)
                segs = lvl.to_segs()
                simlt_res = MarioProxy.get_seg_infos(
                    simulator.simulate_complete(lvl, segTimeK=simlt_k)
                )
                rewards = rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
                remote.send((tid, rewards))
            elif cmd == 'close':
                break
            elif cmd == 'check_playable':
                strlvl, item = data
                lvl = MarioLevel(strlvl)
                # Only simulate when some tile lvl[i, 0] is solid. This is
                # presumably the leftmost column, where the player starts.
                standable = any(lvl[i, 0] in MarioLevel.solidset for i in range(lvl.h))
                playable = standable and simulator.simulate_game(lvl)['status'] == 'WIN'
                remote.send((playable, item))
            elif cmd == 'mnd_item':
                p = MarioLevel(data)
                min_hm, min_dtw = float('inf'), float('inf')
                for q in refs:
                    vhm = hamming_dis(p, q)
                    vdtw = lvl_dtw(p, q)
                    # Zero distances are skipped. Presumably this excludes a
                    # reference that is identical to the queried level.
                    if vhm > 0:
                        min_hm = min(min_hm, vhm)
                    if vdtw > 0:
                        min_dtw = min(min_dtw, vdtw)
                remote.send((min_hm, min_dtw))
            elif cmd == 'mpd':
                hms = [hamming_dis(MarioLevel(s1), MarioLevel(s2)) for s1, s2 in data]
                remote.send((hms, None))
            else:
                raise KeyError(f'Unknown command for simulation worker: {cmd}')
    finally:
        remote.close()
class AsycSimltPool:
    """Asynchronous pool that runs Mario simulation tasks in worker processes.

    Each worker (``_simlt_worker``) handles one task at a time. Tasks submitted
    while every worker is busy wait in a bounded FIFO queue. Replies are
    appended to a result buffer in the order they are received, and
    ``get`` hands them out. The pool does not tag results itself, so any task
    identifier must be part of the command payload.
    """

    def __init__(self, poolsize, queuesize=None, rfunc_name='default', verbose=True, **rsrc):
        """Start ``poolsize`` daemon worker processes.

        Args:
            poolsize: number of worker processes.
            queuesize: capacity of the waiting queue. Defaults to ``poolsize``.
                NOTE(review): ``queue.Queue`` treats a size <= 0 as unbounded.
            rfunc_name: name of a reward-function class in ``src.env.rfuncs``.
            verbose: if true, ``get`` prints pool occupancy.
            **rsrc: extra resources forwarded to every worker, for example
                ``refs`` or ``test``.
        """
        self.np, self.nq = poolsize, poolsize if queuesize is None else queuesize
        self.waiting_queue = Queue(self.nq)
        self.ready = [True] * poolsize
        # Resolve the reward function before spawning anything, so that an
        # invalid ``rfunc_name`` fails fast in the parent process.
        self.histlen = AsycSimltPool.get_histlen(rfunc_name)
        resource = {'rfunc': 'default', **rsrc}
        self.__init_remotes(rfunc_name, resource)
        self.res_buffer = []
        self.verbose = verbose

    @staticmethod
    def get_histlen(rfunc_name):
        """Return ``get_n()`` of a fresh instance of ``src.env.rfuncs.<rfunc_name>``."""
        rfunc_cls = getattr(importlib.import_module('src.env.rfuncs'), rfunc_name)
        return rfunc_cls().get_n()

    def put(self, cmd, args):
        """Submit a ``(cmd, args)`` task.

        The task is sent to the first idle worker if there is one. Otherwise
        it is queued. If the queue is full, this blocks and keeps collecting
        results until a slot frees up.
        """
        for i, remote in enumerate(self.remotes):
            if self.ready[i]:
                remote.send((cmd, args))
                self.ready[i] = False
                return
        while True:
            try:
                self.waiting_queue.put((cmd, args), timeout=0.01)
                return
            except FullExpection:
                # Queue still full: collect results and dispatch queued work.
                self.refresh()

    def get(self, wait=False):
        """Return and clear all results received so far.

        Args:
            wait: if true, first block until every submitted task, whether
                running or queued, has finished.
        """
        if wait:
            self.__wait()
        self.refresh()
        occp, occq = self.get_occupied()
        if self.verbose:
            print(f'Workers: {occp}/{self.np}, Queue: {occq}/{self.nq}, Buffer: {len(self.res_buffer)}')
        res, self.res_buffer = self.res_buffer, []
        return res

    def get_occupied(self):
        """Return ``(number of busy workers, number of queued tasks)``."""
        process_occupied = sum(not r for r in self.ready)
        return process_occupied, self.waiting_queue.qsize()

    def refresh(self):
        """Buffer any finished results, then hand queued tasks to idle workers."""
        for i, remote in enumerate(self.remotes):
            if remote.poll():
                self.res_buffer.append(remote.recv())
                self.ready[i] = True
        for i, remote in enumerate(self.remotes):
            if self.waiting_queue.empty():
                break
            if self.ready[i]:
                remote.send(self.waiting_queue.get())
                self.ready[i] = False

    def blocking(self):
        """Refresh, then return whether the waiting queue is full (``put`` would block)."""
        self.refresh()
        return self.waiting_queue.full()

    def __init_remotes(self, rfunc, resource):
        """Spawn one daemon worker per pool slot, each with its own pipe.

        Uses the ``forkserver`` start method when the platform offers it and
        falls back to ``spawn`` otherwise.
        """
        start_method = "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
        ctx = mp.get_context(start_method)
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.np)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            args = (work_remote, remote, rfunc, resource)
            # daemon=True: if the main process crashes, workers must not keep it hanging.
            process = ctx.Process(target=_simlt_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            # The worker's end is used only by the child, so the parent closes its copy.
            work_remote.close()

    def __wait(self):
        """Poll every 10 ms until all workers are idle and the queue is drained."""
        while True:
            self.refresh()
            # refresh() hands queued tasks to idle workers, so all-idle also
            # means the waiting queue is empty.
            if all(self.ready):
                return
            time.sleep(0.01)

    def close(self):
        """Wait for outstanding tasks, stop every worker and release the pipes.

        Returns:
            The results of all tasks still pending, as returned by ``get``.
        """
        res = self.get(True)
        for remote, process in zip(self.remotes, self.processes):
            remote.send(('close', None))
            process.join()
            # Release the parent's pipe end explicitly rather than leaving its
            # file descriptor open until the pool is garbage-collected.
            remote.close()
        return res