Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,868 Bytes
98844c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from typing import Callable, List, Dict
from multiprocessing.pool import ThreadPool
from tqdm import tqdm
from threading import Thread
import asyncio
from functools import wraps
def async_call_func(func):
    """Wrap a blocking callable so it can be awaited from asyncio code.

    The returned coroutine function runs ``func`` in the running event loop's
    default executor, so the event loop is not blocked while ``func`` runs.

    Args:
        func: A regular (synchronous) callable.

    Returns:
        An ``async def`` wrapper carrying ``func``'s metadata (via
        ``functools.wraps``). It accepts the same positional and keyword
        arguments and resolves to ``func``'s return value. Exceptions raised
        by ``func`` propagate to whoever awaits the wrapper.
    """
    from functools import partial

    @wraps(func)
    async def wrapper(*args, **kwargs):
        # Inside a coroutine the asyncio docs recommend get_running_loop().
        loop = asyncio.get_running_loop()
        # run_in_executor() only forwards positional arguments, so keyword
        # arguments must be bound with functools.partial; passing **kwargs
        # straight to run_in_executor() raises TypeError.
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))

    return wrapper
def async_call(fn):
    """Decorate ``fn`` so each call runs it on a new background thread.

    The call returns immediately. ``fn``'s return value is discarded. Any
    exception ``fn`` raises is handled by ``threading.excepthook`` and is not
    propagated to the caller.

    Args:
        fn: The callable to run in the background.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``). The
        wrapper starts a fresh ``threading.Thread`` running
        ``fn(*args, **kwargs)`` and returns that started thread, so callers can
        ``join()`` it when they need to wait for completion.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        thread = Thread(target=fn, args=args, kwargs=kwargs)
        thread.start()
        # Returning the handle is optional for callers; previously None was
        # returned, so code ignoring the result is unaffected.
        return thread

    return wrapper
def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs):
    """Call ``action`` once per element of the distributed list arguments.

    Copied from EasyVolCap (author: Zhen Xu, https://github.com/dendenxu).

    The number of calls ``N`` is the length of the first ``list`` found in
    ``args``, or else among ``kwargs`` values. For call ``i``, every positional
    or keyword argument that is a ``list`` of length ``N`` is replaced by its
    ``i``-th element. Every other argument (non-lists, and lists of a different
    length) is passed unchanged to all calls.

    Args:
        *args: Positional arguments for ``action``, distributed as described.
        action: The callable to invoke.
        num_processes: Number of worker threads in the ``ThreadPool``.
            Despite the name, threads are used, not processes.
        print_progress: Show a ``tqdm`` progress bar over the calls.
        sequential: Run the calls one after another in the current thread
            instead of using a thread pool.
        async_return: In parallel mode, return the ``ThreadPool`` right after
            submitting the work instead of waiting for results. The pool is
            neither closed nor joined; the caller owns it. Ignored when
            ``sequential`` is true.
        desc: Description label for the ``tqdm`` progress bar.
        **kwargs: Keyword arguments for ``action``, distributed as described.

    Returns:
        A list of ``action``'s return values in input order, or the
        ``ThreadPool`` when ``async_return`` is true in parallel mode.

    Raises:
        NotImplementedError: If no argument is a ``list``.
        Exception: Whatever ``action`` raises is re-raised. In parallel mode
            this is the first failure in input order; the pool is terminated
            before the exception propagates.
    """
    # TODO: Support all types of iterable (currently only ``list`` is distributed).
    def get_length(args: List, kwargs: Dict):
        # Length of the first list argument, positional before keyword.
        for a in args:
            if isinstance(a, list):
                return len(a)
        for v in kwargs.values():
            if isinstance(v, list):
                return len(v)
        raise NotImplementedError

    def get_action_args(length: int, args: List, kwargs: Dict, i: int):
        # Take element ``i`` from every list of the distributed length and
        # pass everything else through as-is.
        action_args = [(arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args]
        action_kwargs = {key: (value[i] if isinstance(value, list) and len(value) == length else value) for key, value in kwargs.items()}
        return action_args, action_kwargs

    # Resolve the call count before any pool exists, so a missing list
    # argument cannot leave an orphaned ThreadPool behind.
    length = get_length(args, kwargs)

    if sequential:
        results = []
        for i in tqdm(range(length), desc=desc, disable=not print_progress):
            action_args, action_kwargs = get_action_args(length, args, kwargs, i)
            results.append(action(*action_args, **action_kwargs))
        return results

    pool = ThreadPool(processes=num_processes)
    try:
        asyncs = []
        for i in range(length):
            action_args, action_kwargs = get_action_args(length, args, kwargs, i)
            asyncs.append(pool.apply_async(action, action_args, action_kwargs))
        if async_return:
            # The caller takes ownership of the pool, including its shutdown.
            return pool
        # Collect in submission order. AsyncResult.get() blocks until that
        # task finishes and re-raises any exception the task raised.
        results = [async_result.get() for async_result in tqdm(asyncs, desc=desc, disable=not print_progress)]
    except Exception:
        # Stop dispatching the remaining queued tasks and release the pool
        # instead of leaking it when a task (or submission) fails.
        pool.terminate()
        raise
    pool.close()
    pool.join()
    return results
|