Spaces:
Sleeping
Sleeping
File size: 6,371 Bytes
eaf2e33 3582c8a eaf2e33 3582c8a eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import time
import importlib
import multiprocessing as mp
from src.smb.level import *
from src.smb.proxy import MarioProxy
from queue import Queue, Full as FullExpection
def _evaluate_level(simulator, reward_func, data, simlt_k):
    """Handle ``'evaluate'``: simulate a level and score it.

    :param data: ``(tid, strlvl)`` pair.
    :returns: ``(tid, rewards)`` where ``rewards`` comes from ``reward_func.get_rewards``.
    """
    tid, strlvl = data
    lvl = MarioLevel(strlvl)
    segs = lvl.to_segs()
    simlt_res = MarioProxy.get_seg_infos(simulator.simulate_complete(lvl, segTimeK=simlt_k))
    return tid, reward_func.get_rewards(segs=segs, simlt_res=simlt_res)


def _check_playable(simulator, data):
    """Handle ``'check_playable'``.

    The level is only simulated if at least one tile ``lvl[i, 0]`` for ``i`` in
    ``range(lvl.h)`` (presumably the left-most column) is in ``MarioLevel.solidset``.
    Otherwise it is reported unplayable without simulating.

    :param data: ``(strlvl, item)`` pair; ``item`` is echoed back unchanged.
    :returns: ``(playable, item)``; ``playable`` is True iff ``simulate_game`` reports ``'WIN'``.
    """
    strlvl, item = data
    lvl = MarioLevel(strlvl)
    standable = any(lvl[i, 0] in MarioLevel.solidset for i in range(lvl.h))
    if not standable:
        return False, item
    return simulator.simulate_game(lvl)['status'] == 'WIN', item


def _min_ref_distances(p, refs):
    """Handle ``'mnd_item'``: smallest positive Hamming and DTW distances from ``p`` to ``refs``.

    Zero distances are ignored. A value stays ``float('inf')`` when no reference yields a
    positive distance (including when ``refs`` is empty).
    """
    min_hm, min_dtw = float('inf'), float('inf')
    for q in refs:
        vhm = hamming_dis(p, q)
        vdtw = lvl_dtw(p, q)
        if vhm > 0:
            min_hm = min(min_hm, vhm)
        if vdtw > 0:
            min_dtw = min(min_dtw, vdtw)
    return min_hm, min_dtw


def _pairwise_hamming(pairs):
    """Handle ``'mpd'``: Hamming distance for each ``(strlvl1, strlvl2)`` pair, in order."""
    return [hamming_dis(MarioLevel(strlvl1), MarioLevel(strlvl2)) for strlvl1, strlvl2 in pairs]


def _simlt_worker(remote, parent_remote, rfunc, resource):
    """Worker-process loop serving simulation requests received over a pipe.

    The loop exits on the ``'close'`` command or when the pipe is closed (``EOFError``).

    :param remote: worker-side pipe end. Commands arrive as ``(cmd, data)`` tuples and
        replies are sent back on it.
    :param parent_remote: parent-side end of the same pipe. It is closed immediately
        because this process never uses it.
    :param rfunc: name of a class in ``src.env.rfuncs``. It is instantiated once and its
        ``get_rewards`` scores ``'evaluate'`` requests.
    :param resource: dict of options.
        ``'refs'``: list of level strings compared against by ``'mnd_item'``.
        ``'test'``: when truthy, raises ``segTimeK`` for ``'evaluate'`` from 100 to 150.

    Commands (data -> reply):
      * ``'evaluate'``: ``(tid, strlvl)`` -> ``(tid, rewards)``
      * ``'check_playable'``: ``(strlvl, item)`` -> ``(playable, item)``
      * ``'mnd_item'``: ``strlvl`` -> ``(min_hamming, min_dtw)`` against the refs
      * ``'mpd'``: iterable of ``(strlvl1, strlvl2)`` -> ``(hamming_list, None)``
      * ``'close'``: closes ``remote`` and returns; no reply

    :raises KeyError: on an unknown command. This ends the worker without a reply.
        NOTE(review): ``AsycSimltPool`` then keeps that worker marked busy forever.
    """
    reward_func = getattr(importlib.import_module('src.env.rfuncs'), rfunc)()
    simulator = MarioProxy()
    parent_remote.close()
    refs = [MarioLevel(lvl) for lvl in resource.get('refs', [])]
    simlt_k = 150 if resource.get('test', False) else 100
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == 'evaluate':
                remote.send(_evaluate_level(simulator, reward_func, data, simlt_k))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'check_playable':
                remote.send(_check_playable(simulator, data))
            elif cmd == 'mnd_item':
                remote.send(_min_ref_distances(MarioLevel(data), refs))
            elif cmd == 'mpd':
                # NOTE(review): only Hamming distances are computed; the DTW slot is always None.
                remote.send((_pairwise_hamming(data), None))
            else:
                raise KeyError(f'Unknown command for simulation worker: {cmd}')
        except EOFError:
            # The parent closed its end of the pipe; nothing more will arrive.
            break
class AsycSimltPool:
    """
    Asynchronous pool for multi-process Mario simulation tasks.

    Tasks are ``(cmd, args)`` pairs sent to ``_simlt_worker`` processes over pipes.
    Each worker is given at most one task at a time. Tasks that cannot be dispatched
    immediately wait in a bounded FIFO queue. ``refresh`` collects finished results into
    an internal buffer, and ``get`` hands them out.

    The pool is driven entirely by polling from the calling process.
    """

    def __init__(self, poolsize, queuesize=None, rfunc_name='default', verbose=True, **rsrc):
        """
        :param poolsize: number of worker processes to start.
        :param queuesize: capacity of the waiting queue; defaults to ``poolsize``.
            As with ``queue.Queue``, a value <= 0 makes it unbounded.
        :param rfunc_name: name of the reward-function class in ``src.env.rfuncs``.
            Each worker instantiates it, and it is also used here for ``histlen``.
        :param verbose: if true, ``get`` prints worker/queue/buffer occupancy.
        :param rsrc: extra entries merged into the resource dict given to every worker.
        """
        self.np, self.nq = poolsize, poolsize if queuesize is None else queuesize
        self.waiting_queue = Queue(self.nq)
        # ready[i] is True while worker i has no task in flight.
        self.ready = [True] * poolsize
        # NOTE(review): the 'rfunc' entry is not read by _simlt_worker; the worker gets
        # rfunc_name as a separate argument.
        resource = {'rfunc': 'default'}
        resource.update(rsrc)
        self.__init_remotes(rfunc_name, resource)
        self.res_buffer = []
        self.histlen = AsycSimltPool.get_histlen(rfunc_name)
        self.verbose = verbose

    @staticmethod
    def get_histlen(rfunc_name):
        """Instantiate reward function ``rfunc_name`` from ``src.env.rfuncs`` and return its ``get_n()``."""
        rfunc = getattr(importlib.import_module('src.env.rfuncs'), rfunc_name)()
        return rfunc.get_n()

    def put(self, cmd, args):
        """
        Put a new task into the pool.

        The task goes straight to the first worker currently marked idle. If there is
        none, it is appended to the waiting queue. While the queue is full, this blocks
        and keeps calling ``refresh`` until a slot frees up.

        Idle flags are only updated by ``refresh``. So a worker that just finished is
        not reused here until ``refresh`` has run.
        """
        for i, remote in enumerate(self.remotes):
            if self.ready[i]:
                remote.send((cmd, args))
                self.ready[i] = False
                return
        while True:
            try:
                self.waiting_queue.put((cmd, args), timeout=0.01)
                return
            except FullExpection:
                self.refresh()

    def get(self, wait=False):
        """Return all results collected so far and clear the buffer.

        :param wait: if true, first block until every submitted task has finished.
        :returns: list of worker replies, in the order they were received.
        """
        if wait:
            self.__wait()
        self.refresh()
        if self.verbose:
            occp, occq = self.get_occupied()
            print(f'Workers: {occp}/{self.np}, Queue: {occq}/{self.nq}, Buffer: {len(self.res_buffer)}')
        res, self.res_buffer = self.res_buffer, []
        return res

    def get_occupied(self):
        """Return ``(number of busy workers, number of queued tasks)``."""
        return sum(not r for r in self.ready), self.waiting_queue.qsize()

    def refresh(self):
        """Receive ready results and cache them in the buffer, then assign queued tasks to free workers."""
        for i, remote in enumerate(self.remotes):
            if remote.poll():
                self.res_buffer.append(remote.recv())
                self.ready[i] = True
        for i, remote in enumerate(self.remotes):
            if self.waiting_queue.empty():
                break
            if self.ready[i]:
                remote.send(self.waiting_queue.get())
                self.ready[i] = False

    def blocking(self):
        """Refresh, then return True if the waiting queue is still full (a ``put`` would block)."""
        self.refresh()
        return self.waiting_queue.full()

    def __init_remotes(self, rfunc, resource):
        """Start ``self.np`` worker processes, each connected to the pool by its own pipe.

        Uses the ``forkserver`` start method when it is available, otherwise ``spawn``.
        """
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.np)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            args = (work_remote, remote, rfunc, resource)
            # daemon=True: if the main process crashes, we should not cause things to hang
            # Start a worker process to do the computation asynchronously.
            process = ctx.Process(target=_simlt_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            # The parent keeps only its own end of each pipe.
            work_remote.close()

    def __wait(self):
        """Block until every worker is idle.

        Each ``refresh`` hands queued tasks to idle workers. So once all workers are
        idle after a refresh, the waiting queue is empty as well.
        """
        while True:
            self.refresh()
            if all(self.ready):
                return
            time.sleep(0.01)

    def close(self):
        """Wait for all outstanding tasks, shut the workers down and release the pipes.

        :returns: the results that were still pending, as returned by ``get(True)``.
        """
        res = self.get(True)
        # Signal every worker first so they shut down in parallel, then join them.
        for remote in self.remotes:
            remote.send(('close', None))
        for remote, p in zip(self.remotes, self.processes):
            p.join()
            # Release the parent-side pipe end; previously it was left open (leaked).
            remote.close()
        return res
|