PromptDA / promptda /utils /parallel_utils.py
haotongl
initial version
98844c3
from typing import Callable, List, Dict
from multiprocessing.pool import ThreadPool
from tqdm import tqdm
from threading import Thread
import asyncio
from functools import wraps
def async_call_func(func):
    """Wrap a blocking callable so that it can be awaited.

    The returned coroutine function hands ``func`` to the event loop's
    default executor (``run_in_executor(None, ...)``) and resolves to
    whatever ``func`` returns.  Both positional and keyword arguments are
    forwarded to ``func``.

    Args:
        func: The blocking callable to wrap.

    Returns:
        An ``async def`` wrapper carrying ``func``'s metadata (via ``wraps``).
    """
    # Imported locally because the module only pulls ``wraps`` from functools.
    from functools import partial

    @wraps(func)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_event_loop()
        # run_in_executor() forwards positional arguments only; handing it
        # **kwargs directly raises TypeError.  Bind the full call up front.
        bound_call = partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, bound_call)

    return wrapper
def async_call(fn):
    """Turn ``fn`` into a fire-and-forget call that runs on a new thread.

    Every call of the returned wrapper starts a fresh ``Thread`` that executes
    ``fn`` with the given arguments.  The wrapper itself returns ``None``
    right after starting the thread, so ``fn``'s return value is not
    available to the caller.

    Args:
        fn: The callable to run in the background.

    Returns:
        A wrapper with ``fn``'s metadata (via ``wraps``) that returns ``None``.
    """
    # Keep __name__/__doc__ of the decorated function, matching the other
    # decorator in this module.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        Thread(target=fn, args=args, kwargs=kwargs).start()

    return wrapper
def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs):
    """Call ``action`` once per batch item, in a thread pool or sequentially.

    Adapted from EasyVolCap (author: Zhen Xu, https://github.com/dendenxu).

    The batch size is the length of the first ``list`` found among the
    positional arguments, then among the keyword-argument values.  For call
    ``i``, every argument that is a ``list`` of exactly that length
    contributes its ``i``-th element; every other argument (including lists
    of a different length) is passed to ``action`` unchanged.

    Args:
        *args: Positional arguments for ``action`` (distributed as above).
        action: Callable invoked once per batch item.
        num_processes: Number of worker threads given to ``ThreadPool``.
        print_progress: Show a ``tqdm`` progress bar.
        sequential: Run all calls one after another in the calling thread
            instead of using a pool.
        async_return: In pooled mode, return the ``ThreadPool`` right after
            submitting the calls instead of waiting for the results.  The
            pool is returned still open.  Ignored when ``sequential`` is true.
        desc: Label passed to the progress bar.
        **kwargs: Keyword arguments for ``action`` (distributed as above).

    Returns:
        The list of ``action`` results in batch order, or the ``ThreadPool``
        when ``async_return`` is true in pooled mode.

    Raises:
        NotImplementedError: If no positional or keyword argument is a list.
        Exception: Whatever ``action`` raises is propagated to the caller.
    """
    # NOTE: we expect first arg / or kwargs to be distributed
    # NOTE: print_progress arg is reserved

    def get_length(args: List, kwargs: Dict):
        # TODO: Support all types of iterable
        for candidate in (*args, *kwargs.values()):
            if isinstance(candidate, list):
                return len(candidate)
        raise NotImplementedError(
            'parallel_execution needs at least one list argument to distribute')

    def get_action_args(length: int, args: List, kwargs: Dict, i: int):
        def pick(value):
            # Only lists matching the batch size are split per call.
            if isinstance(value, list) and len(value) == length:
                return value[i]
            return value

        return [pick(a) for a in args], {key: pick(v) for key, v in kwargs.items()}

    # Validate the batch before any thread is started, so that an invalid
    # call cannot leave a running pool behind.
    length = get_length(args, kwargs)

    if sequential:
        indices = range(length)
        if print_progress:
            indices = tqdm(indices, desc=desc)
        results = []
        for i in indices:
            call_args, call_kwargs = get_action_args(length, args, kwargs, i)
            results.append(action(*call_args, **call_kwargs))
        return results

    # Create the pool and submit one task per batch item.
    pool = ThreadPool(processes=num_processes)
    pending = []
    for i in range(length):
        call_args, call_kwargs = get_action_args(length, args, kwargs, i)
        pending.append(pool.apply_async(action, call_args, call_kwargs))

    if async_return:
        return pool

    # Skip the progress wrapper entirely when progress reporting is off.
    waiting = tqdm(pending, desc=desc) if print_progress else pending
    try:
        # get() blocks until the corresponding task finished and re-raises
        # any exception the task raised.
        results = [handle.get() for handle in waiting]
    except BaseException:
        # A task failed (or we were interrupted): drop queued work so the
        # pool does not keep running in the background.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    return results