import time
import importlib
import multiprocessing as mp
from src.smb.level import *
from src.smb.proxy import MarioProxy
from queue import Queue, Full as FullExpection


def _simlt_worker(remote, parent_remote, rfunc, resource):
    # Look up the reward-function class by name and instantiate it
    rfunc = getattr(importlib.import_module('src.env.rfuncs'), rfunc)()
    # W = MarioLevel.seg_width
    simulator = MarioProxy()
    parent_remote.close()
    refs = [MarioLevel(lvl) for lvl in resource.get('refs', [])]
    simlt_k = 150 if resource.get('test', False) else 100
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == 'evaluate':
                tid, strlvl = data
                lvl = MarioLevel(strlvl)
                segs = lvl.to_segs()
                simlt_res = MarioProxy.get_seg_infos(simulator.simulate_complete(lvl, segTimeK=simlt_k))
                rewards = rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
                remote.send((tid, rewards))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'check_playable':
                strlvl, item = data
                lvl = MarioLevel(strlvl)
                # The level is only simulated if its first column has a solid tile to stand on
                standable = any(lvl[i, 0] in MarioLevel.solidset for i in range(lvl.h))
                if standable:
                    simlt_res = simulator.simulate_game(lvl)
                    playable = simlt_res['status'] == 'WIN'
                    remote.send((playable, item))
                else:
                    remote.send((False, item))
            elif cmd == 'mnd_item':
                # strlvl = data
                p = MarioLevel(data)
                min_hm, min_dtw = float('inf'), float('inf')
                for q in refs:
                    vhm = hamming_dis(p, q)
                    vdtw = lvl_dtw(p, q)
                    if vhm > 0:
                        min_hm = min(min_hm, vhm)
                    if vdtw > 0:
                        min_dtw = min(min_dtw, vdtw)
                remote.send((min_hm, min_dtw))
            elif cmd == 'mpd':
                # Only Hamming distances are computed for pairs; the DTW slot is sent as None
                hms = []
                for strlvl1, strlvl2 in data:
                    lvl1, lvl2 = MarioLevel(strlvl1), MarioLevel(strlvl2)
                    hms.append(hamming_dis(lvl1, lvl2))
                remote.send((hms, None))
            else:
                raise KeyError(f'Unknown command for simulation worker: {cmd}')
        except EOFError:
            break


class AsycSimltPool:
    """
    Asynchronous pool for multi-process Mario simulation tasks
    """
    def __init__(self, poolsize, queuesize=None, rfunc_name='default', verbose=True, **rsrc):
        self.np = poolsize
        # The queue size defaults to the pool size
        self.nq = poolsize if queuesize is None else queuesize
        self.waiting_queue = Queue(self.nq)
        self.ready = [True] * poolsize
        resource = {'rfunc': 'default'}
        resource.update(rsrc)
        self.__init_remotes(rfunc_name, resource)
        self.res_buffer = []
        self.histlen = AsycSimltPool.get_histlen(rfunc_name)
        self.verbose = verbose

    @staticmethod
    def get_histlen(rfunc_name):
        rfunc = getattr(importlib.import_module('src.env.rfuncs'), rfunc_name)()
        return rfunc.get_n()

    def put(self, cmd, args):
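        """
            Put a new task into the pool. If a worker is free, the task is sent to it
            directly; otherwise it is placed in the waiting queue. If both the pool and
            the waiting queue are full, block until a worker becomes free.
        """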
        putted = False
        for i, remote in enumerate(self.remotes):
            if self.ready[i]:
                remote.send((cmd, args))
                self.ready[i] = False
                putted = True
                break
        while not putted:
            try:
                self.waiting_queue.put((cmd, args), timeout=0.01)
                putted = True
            except FullExpection:
                self.refresh()

    def get(self, wait=False):
        if wait:
            self.__wait()
        self.refresh()
        occp, occq = self.get_occupied()
        if self.verbose:
            print(f'Workers: {occp}/{self.np}, Queue: {occq}/{self.nq}, Buffer: {len(self.res_buffer)}')
        res = self.res_buffer
        self.res_buffer = []
        return res

    def get_occupied(self):
        process_occupied = sum(0 if r else 1 for r in self.ready)
        return process_occupied, self.waiting_queue.qsize()

    def refresh(self):
        """ Receive ready results and cache them in the buffer, then assign tasks in the waiting queue to free workers """
        for i, remote in enumerate(self.remotes):
            if remote.poll():
                self.res_buffer.append(remote.recv())
                self.ready[i] = True
        for i, remote in enumerate(self.remotes):
            if self.waiting_queue.empty():
                break
            if self.ready[i]:
                cmd, args = self.waiting_queue.get()
                remote.send((cmd, args))
                self.ready[i] = False

    def blocking(self):
        self.refresh()
        return self.waiting_queue.full()

    def __init_remotes(self, rfunc, resource):
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)

        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.np)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            args = (work_remote, remote, rfunc, resource)
            # daemon=True: if the main process crashes, we should not cause things to hang
            # Start the worker processes that perform the asynchronous computation
            process = ctx.Process(target=_simlt_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            work_remote.close()

    def __wait(self):
        finish = False
        while not finish:
            self.refresh()
            finish = all(self.ready)
            time.sleep(0.01)

    def close(self):
        res = self.get(True)
        for remote, p in zip(self.remotes, self.processes):
            remote.send(('close', None))
            p.join()
        return res