Spaces:
Sleeping
Sleeping
File size: 2,412 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from multiprocessing import Pool
from src.utils.datastruct import ConditionalDataBuffer
class MyAsyncPool:
    """Wrapper around ``multiprocessing.Pool`` that keeps track of async results.

    Every task submitted through :meth:`push` produces an async result object
    which is stored in ``self.res_buffer`` (a ``ConditionalDataBuffer``).
    :meth:`collect` asks that buffer for the entries whose ``ready()`` is true
    and returns their values.
    """

    def __init__(self, n: int) -> None:
        """
        Create the worker pool and an empty result buffer.
        :param n: number of worker processes passed to ``multiprocessing.Pool``
        """
        self.n = n
        self.pool = Pool(n)
        self.res_buffer = ConditionalDataBuffer()

    def push(self, func, *args) -> None:
        """
        Push task(s) into the process pool.
        :param func: function to be executed
        :param args: one argument tuple per task; each tuple is handed to
            ``apply_async`` as the positional arguments of ``func``.
            Any number of tuples is accepted.
        :return: None
        """
        results = [self.pool.apply_async(func, arg) for arg in args]
        self.res_buffer.push(*results)

    def collect(self) -> list:
        """
        Collect the results of all the finished tasks.

        ``get()`` re-raises any exception raised inside a task, so a failed
        task makes this method raise.
        NOTE(review): presumably ``ConditionalDataBuffer.collect`` also removes
        the returned entries from the buffer -- confirm against its definition.
        :return: a list with the return value of every finished task
        """
        roll_outs = self.res_buffer.collect(lambda x: x.ready())
        return [item.get() for item in roll_outs]

    def wait_and_get(self) -> list:
        """
        Wait for every result found in ``self.res_buffer.main``, then collect.
        :return: the list produced by :meth:`collect` once all waits returned
        """
        for res in self.res_buffer.main:
            res.wait()
        return self.collect()

    def run_all(self, func, *args) -> None:
        """
        Placeholder: currently does nothing and returns None.
        :param func: unused
        :param args: unused
        """
        # TODO: not implemented yet; kept as a no-op so existing callers behave the same.

    def close(self) -> None:
        """Close the underlying pool so that no more tasks can be submitted."""
        self.pool.close()

    def get_num_waiting(self) -> int:
        """
        Report how many entries the result buffer currently holds.
        :return: ``len(self.res_buffer)``
        """
        return len(self.res_buffer)

    def terminate(self) -> None:
        """Terminate the underlying pool without waiting for outstanding work."""
        self.pool.terminate()
class TaskLoad:
    """Bundle the positional arguments of a task with extra keyword information.

    ``AsyncPoolWithLoad.push`` reads ``args`` as the task's positional
    arguments and keeps ``info`` next to the task's result.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Store the call arguments unchanged.
        :param args: positional arguments, kept as a tuple in ``self.args``
        :param kwargs: keyword arguments, kept as a dict in ``self.info``
        """
        self.args = args
        self.info = kwargs

    def __repr__(self) -> str:
        """Return a readable representation showing both stored fields."""
        return "{}(args={!r}, info={!r})".format(
            type(self).__name__, self.args, self.info
        )
class AsyncPoolWithLoad(MyAsyncPool):
    """Variant of ``MyAsyncPool`` whose tasks carry an extra info dictionary.

    Each entry pushed to ``self.res_buffer`` is an ``(async_result, info)``
    pair instead of a bare async result.
    """

    def push(self, func, *loads) -> None:
        """
        Push task(s) with additional dictionary information into the process pool.
        :param func: function to be executed
        :param loads: one load object per task; ``load.args`` is handed to
            ``apply_async`` as the positional arguments of ``func`` and
            ``load.info`` is stored alongside the resulting async result.
            Any number of loads is accepted.
        :return: None
        """
        # Submit every task first, then pair each async result with its info.
        results = [self.pool.apply_async(func, load.args) for load in loads]
        pairs = [(res, load.info) for res, load in zip(results, loads)]
        self.res_buffer.push(*pairs)

    def collect(self) -> list:
        """
        Collect the results of all the finished tasks.

        ``get()`` re-raises any exception raised inside a task, so a failed
        task makes this method raise.
        :return: a list of ``(return_value, info)`` pairs, one per finished task
        """
        roll_outs = self.res_buffer.collect(lambda entry: entry[0].ready())
        return [(res.get(), info) for res, info in roll_outs]

    def wait_and_get(self) -> list:
        """
        Wait for every result found in ``self.res_buffer.main``, then collect.
        :return: the list produced by :meth:`collect` once all waits returned
        """
        for pending, _ in self.res_buffer.main:
            pending.wait()
        return self.collect()
|