Spaces:
Sleeping
Sleeping
import time | |
import importlib | |
import multiprocessing as mp | |
from src.smb.level import * | |
from src.smb.proxy import MarioProxy | |
from queue import Queue, Full as FullExpection | |
def _simlt_worker(remote, parent_remote, rfunc, resource):
    """
    Serve simulation requests inside a worker process until told to stop.

    The worker repeatedly receives ``(cmd, data)`` tuples from ``remote`` and
    answers each one with exactly one ``remote.send(...)``:

    * ``'evaluate'``: ``data`` is ``(tid, strlvl)``; replies ``(tid, rewards)``.
    * ``'check_playable'``: ``data`` is ``(strlvl, item)``; replies
      ``(playable, item)``. A level with no solid tile in column 0 is reported
      as not playable without running the simulator.
    * ``'mnd_item'``: ``data`` is a level string; replies the smallest strictly
      positive Hamming and DTW distances to the reference levels (``inf`` when
      there is none).
    * ``'mpd'``: ``data`` is an iterable of level-string pairs; replies
      ``(hamming_distances, None)``.
    * ``'close'``: closes ``remote`` and leaves the loop.

    :param remote: the worker's end of the pipe.
    :param parent_remote: the parent's end of the pipe, closed immediately.
    :param rfunc: name of the attribute in ``src.env.rfuncs`` that is called
        with no arguments to build the reward function.
    :param resource: dict of settings; ``'refs'`` (level strings, default
        empty) and ``'test'`` (default False) are read here.
    :raises KeyError: if an unknown command is received.
    """
    rfunc_module = importlib.import_module('src.env.rfuncs')
    rfunc = getattr(rfunc_module, rfunc)()
    simulator = MarioProxy()
    # The parent's end of the pipe is not used inside the worker.
    parent_remote.close()
    refs = [MarioLevel(ref) for ref in resource.get('refs', [])]
    # Value passed as segTimeK to the simulator; larger when 'test' is set.
    simlt_k = 150 if resource.get('test', False) else 100

    while True:
        try:
            cmd, data = remote.recv()
            if cmd == 'close':
                remote.close()
                break

            if cmd == 'evaluate':
                tid, strlvl = data
                lvl = MarioLevel(strlvl)
                segs = lvl.to_segs()
                trace = simulator.simulate_complete(lvl, segTimeK=simlt_k)
                simlt_res = MarioProxy.get_seg_infos(trace)
                rewards = rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
                remote.send((tid, rewards))
            elif cmd == 'check_playable':
                strlvl, item = data
                lvl = MarioLevel(strlvl)
                # Only simulate when the first column has a solid tile.
                standable = any(
                    lvl[row, 0] in MarioLevel.solidset for row in range(lvl.h)
                )
                playable = standable and simulator.simulate_game(lvl)['status'] == 'WIN'
                remote.send((playable, item))
            elif cmd == 'mnd_item':
                query = MarioLevel(data)
                min_hm = min_dtw = float('inf')
                for ref in refs:
                    vhm = hamming_dis(query, ref)
                    vdtw = lvl_dtw(query, ref)
                    # Zero distances are skipped, only positive ones count.
                    if vhm > 0:
                        min_hm = min(min_hm, vhm)
                    if vdtw > 0:
                        min_dtw = min(min_dtw, vdtw)
                remote.send((min_hm, min_dtw))
            elif cmd == 'mpd':
                hms = [
                    hamming_dis(MarioLevel(strlvl1), MarioLevel(strlvl2))
                    for strlvl1, strlvl2 in data
                ]
                # Only Hamming distances are computed; the second slot is None.
                remote.send((hms, None))
            else:
                raise KeyError(f'Unknown command for simulation worker: {cmd}')
        except EOFError:
            # The other end of the pipe was closed: stop serving.
            break
class AsycSimltPool:
    """
    Asynchronous pool for multi-process Mario simulation tasks.

    Every worker is a separate process running ``_simlt_worker`` and is
    reached through its own pipe. A worker holds at most one task at a time;
    tasks submitted while all workers are busy wait in a bounded queue, and
    finished results are collected into a buffer that ``get`` hands out.
    Results are buffered in the order they are received, which is not
    necessarily the order in which the tasks were submitted.
    """

    def __init__(self, poolsize, queuesize=None, rfunc_name='default', verbose=True, **rsrc):
        """
        Start ``poolsize`` worker processes.

        :param poolsize: number of worker processes; must be at least 1.
        :param queuesize: capacity of the waiting queue; defaults to ``poolsize``.
        :param rfunc_name: name of the attribute in ``src.env.rfuncs`` that is
            instantiated as the reward function.
        :param verbose: if True, ``get`` prints the pool occupancy.
        :param rsrc: extra entries merged into the ``resource`` dict that is
            passed to every worker.
        :raises ValueError: if ``poolsize`` is smaller than 1.
        """
        if poolsize < 1:
            # Without this check the failure would be an opaque unpacking
            # ValueError raised while creating the pipes.
            raise ValueError(f'poolsize must be at least 1, got {poolsize}')
        self.np = poolsize
        self.nq = poolsize if queuesize is None else queuesize
        self.waiting_queue = Queue(self.nq)
        # ready[i] is True while worker i has no outstanding task.
        self.ready = [True] * poolsize
        resource = {'rfunc': 'default'}
        resource.update(rsrc)
        self.__init_remotes(rfunc_name, resource)
        self.res_buffer = []
        self.histlen = AsycSimltPool.get_histlen(rfunc_name)
        self.verbose = verbose

    @staticmethod
    def get_histlen(rfunc_name):
        """
        Return ``get_n()`` of the reward function named ``rfunc_name``.

        :param rfunc_name: name of the attribute in ``src.env.rfuncs`` to instantiate.
        """
        rfunc = getattr(importlib.import_module('src.env.rfuncs'), rfunc_name)()
        return rfunc.get_n()

    def put(self, cmd, args):
        """
        Put a new task into the pool.

        The task goes straight to the first idle worker. If every worker is
        busy it is placed in the waiting queue; if the queue is full as well,
        this call keeps refreshing the pool until a slot becomes free.

        :param cmd: command name sent to the worker.
        :param args: payload sent along with the command.
        """
        for i, remote in enumerate(self.remotes):
            if self.ready[i]:
                remote.send((cmd, args))
                self.ready[i] = False
                return
        while True:
            try:
                self.waiting_queue.put((cmd, args), timeout=0.01)
                return
            except FullExpection:
                # Collect finished results so that queued tasks can move on.
                self.refresh()

    def get(self, wait=False):
        """
        Return all buffered results and empty the buffer.

        :param wait: if True, first block until every worker is idle.
        :return: list of the results received from the workers so far.
        """
        if wait:
            self.__wait()
        self.refresh()
        occp, occq = self.get_occupied()
        if self.verbose:
            print(f'Workers: {occp}/{self.np}, Queue: {occq}/{self.nq}, Buffer: {len(self.res_buffer)}')
        res = self.res_buffer
        self.res_buffer = []
        return res

    def get_occupied(self):
        """Return ``(number of busy workers, number of queued tasks)``."""
        process_occupied = sum(not r for r in self.ready)
        return process_occupied, self.waiting_queue.qsize()

    def refresh(self):
        """ Receive ready results and cache them in buffer, then assign tasks in waiting queue to free workers """
        for i, remote in enumerate(self.remotes):
            if remote.poll():
                self.res_buffer.append(remote.recv())
                self.ready[i] = True
        for i, remote in enumerate(self.remotes):
            if self.waiting_queue.empty():
                break
            if self.ready[i]:
                cmd, args = self.waiting_queue.get()
                remote.send((cmd, args))
                self.ready[i] = False

    def blocking(self):
        """Refresh the pool and return True if the waiting queue is full."""
        self.refresh()
        return self.waiting_queue.full()

    def __init_remotes(self, rfunc, resource):
        """
        Create one pipe and one daemon worker process per pool slot.

        :param rfunc: reward-function name forwarded to each worker.
        :param resource: resource dict forwarded to each worker.
        """
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.np)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            args = (work_remote, remote, rfunc, resource)
            # daemon=True: if the main process crashes, we should not cause things to hang
            # Launch the worker processes that carry out the asynchronous computation
            process = ctx.Process(target=_simlt_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            # The worker's end of the pipe is not used in the parent.
            work_remote.close()

    def __wait(self):
        """Block until every worker is idle, polling every 10 ms."""
        while True:
            self.refresh()
            if all(self.ready):
                return
            time.sleep(0.01)

    def close(self):
        """
        Finish all pending tasks, shut the workers down and release the pipes.

        :return: the results that were still pending or buffered.
        """
        res = self.get(True)
        # Signal every worker first so that they can exit concurrently.
        for remote in self.remotes:
            remote.send(('close', None))
        for remote, process in zip(self.remotes, self.processes):
            process.join()
            remote.close()
        return res