from typing import Callable, List, Dict
from multiprocessing.pool import ThreadPool
from threading import Thread
import asyncio
from functools import partial, wraps

try:
    from tqdm import tqdm
except ImportError:
    # tqdm only draws an optional progress bar (disabled by default), so a
    # missing install should not make this whole module unimportable.
    def tqdm(iterable=None, *args, **kwargs):
        """Pass-through stand-in used when tqdm is not installed."""
        return iterable


def async_call_func(func):
    """Turn the blocking callable ``func`` into an awaitable coroutine function.

    Each call runs ``func`` in the running event loop's default executor and
    the awaited value is whatever ``func`` returns.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional arguments, so keyword
        # arguments must be bound up front with functools.partial.
        call = partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, call)
    return wrapper


def async_call(fn):
    """Decorator: run ``fn`` in a freshly started thread (fire and forget).

    The wrapper returns ``None`` immediately; the return value of ``fn`` is
    discarded and the thread is not joined.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        Thread(target=fn, args=args, kwargs=kwargs).start()
    return wrapper


def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs):
    """Call ``action`` once per element of the list-valued arguments.

    Copy from EasyVolCap
    Author: Zhen Xu https://github.com/dendenxu

    The number of calls is the length of the first ``list`` found among the
    positional arguments, then among the keyword-argument values. For call
    ``i``, every argument that is a ``list`` of exactly that length
    contributes its ``i``-th element; every other argument is passed through
    unchanged.

    NOTE: we expect first arg / or kwargs to be distributed
    NOTE: print_progress arg is reserved

    Args:
        *args: positional arguments for ``action`` (distributed or shared).
        action: the callable to invoke.
        num_processes: number of worker threads in the ``ThreadPool``.
        print_progress: show a tqdm progress bar while collecting results.
        sequential: run the calls one by one in the calling thread instead
            of using a thread pool.
        async_return: (threaded mode only) return the ``ThreadPool`` right
            after submitting the tasks instead of waiting for the results;
            the caller then owns the pool and must close/join it.
        desc: description forwarded to tqdm.
        **kwargs: keyword arguments for ``action`` (distributed or shared).

    Returns:
        The list of return values in input order, or the ``ThreadPool`` when
        ``async_return`` is true in threaded mode.

    Raises:
        NotImplementedError: if no argument is a ``list``.
    """

    def get_length(args: List, kwargs: Dict):
        for a in args:
            if isinstance(a, list):
                return len(a)
        for v in kwargs.values():
            if isinstance(v, list):
                return len(v)
        raise NotImplementedError(
            "parallel_execution needs at least one list argument to distribute")

    def get_action_args(length: int, args: List, kwargs: Dict, i: int):
        # TODO: Support all types of iterable
        def pick(value):
            if isinstance(value, list) and len(value) == length:
                return value[i]
            return value

        action_args = [pick(arg) for arg in args]
        action_kwargs = {key: pick(value) for key, value in kwargs.items()}
        return action_args, action_kwargs

    # Resolve the length before any pool exists so that a failure here
    # cannot leave worker threads behind.
    length = get_length(args, kwargs)

    if sequential:
        results = []
        for i in tqdm(range(length), desc=desc, disable=not print_progress):
            action_args, action_kwargs = get_action_args(length, args, kwargs, i)
            results.append(action(*action_args, **action_kwargs))
        return results

    # Create ThreadPool and spawn one task per element
    pool = ThreadPool(processes=num_processes)
    asyncs = []
    for i in range(length):
        action_args, action_kwargs = get_action_args(length, args, kwargs, i)
        asyncs.append(pool.apply_async(action, action_args, action_kwargs))

    if async_return:
        # The caller takes ownership of the pool.
        return pool

    # Join threads and get return values
    results = []
    try:
        for async_result in tqdm(asyncs, desc=desc, disable=not print_progress):
            # .get() blocks until that task is done and re-raises its exception
            results.append(async_result.get())
    finally:
        # Always stop accepting work so the worker threads can exit, even
        # when one of the tasks raised.
        pool.close()
    pool.join()
    return results