Spaces:
Runtime error
Runtime error
import re | |
import os | |
import signal | |
import logging | |
import sys | |
from time import sleep, time | |
from random import random, randint | |
from multiprocessing import JoinableQueue, Event, Process | |
from queue import Empty | |
from typing import Optional | |
logger = logging.getLogger(__name__) | |
def re_findall(pattern, string): | |
return [m.groupdict() for m in re.finditer(pattern, string)] | |
class Task: | |
def __init__(self, function, *args, **kwargs) -> None: | |
self.function = function | |
self.args = args | |
self.kwargs = kwargs | |
def run(self): | |
return self.function(*self.args, **self.kwargs) | |
class CallbackGenerator: | |
def __init__(self, generator, callback): | |
self.generator = generator | |
self.callback = callback | |
def __iter__(self): | |
if self.callback is not None and callable(self.callback): | |
for t in self.generator: | |
self.callback(t) | |
yield t | |
else: | |
yield from self.generator | |
def start_worker(q: JoinableQueue, stop_event: Event): # TODO make class? | |
logger.info('Starting worker...') | |
while True: | |
if stop_event.is_set(): | |
logger.info('Worker exiting because of stop_event') | |
break | |
# We set a timeout so we loop past 'stop_event' even if the queue is empty | |
try: | |
task = q.get(timeout=.01) | |
except Empty: | |
# Run next iteration of loop | |
continue | |
# Exit if end of queue | |
if task is None: | |
logger.info('Worker exiting because of None on queue') | |
q.task_done() | |
break | |
try: | |
task.run() # Do the task | |
except: # Will also catch KeyboardInterrupt | |
logger.exception(f'Failed to process task {task}', ) | |
# Can implement some kind of retry handling here | |
finally: | |
q.task_done() | |
class InterruptibleTaskPool: | |
# https://the-fonz.gitlab.io/posts/python-multiprocessing/ | |
def __init__(self, | |
tasks=None, | |
num_workers=None, | |
callback=None, # Fired on start | |
max_queue_size=1, | |
grace_period=2, | |
kill_period=30, | |
): | |
self.tasks = CallbackGenerator( | |
[] if tasks is None else tasks, callback) | |
self.num_workers = os.cpu_count() if num_workers is None else num_workers | |
self.max_queue_size = max_queue_size | |
self.grace_period = grace_period | |
self.kill_period = kill_period | |
# The JoinableQueue has an internal counter that increments when an item is put on the queue and | |
# decrements when q.task_done() is called. This allows us to wait until it's empty using .join() | |
self.queue = JoinableQueue(maxsize=self.max_queue_size) | |
# This is a process-safe version of the 'panic' variable shown above | |
self.stop_event = Event() | |
# n_workers: Start this many processes | |
# max_queue_size: If queue exceeds this size, block when putting items on the queue | |
# grace_period: Send SIGINT to processes if they don't exit within this time after SIGINT/SIGTERM | |
# kill_period: Send SIGKILL to processes if they don't exit after this many seconds | |
# self.on_task_complete = on_task_complete | |
# self.raise_after_interrupt = raise_after_interrupt | |
def __enter__(self): | |
self.start() | |
return self | |
def __exit__(self, exc_type, exc_value, exc_traceback): | |
pass | |
def start(self) -> None: | |
def handler(signalname): | |
""" | |
Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed. | |
Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose. | |
""" | |
def f(signal_received, frame): | |
raise KeyboardInterrupt(f'{signalname} received') | |
return f | |
# This will be inherited by the child process if it is forked (not spawned) | |
signal.signal(signal.SIGINT, handler('SIGINT')) | |
signal.signal(signal.SIGTERM, handler('SIGTERM')) | |
procs = [] | |
for i in range(self.num_workers): | |
# Make it a daemon process so it is definitely terminated when this process exits, | |
# might be overkill but is a nice feature. See | |
# https://docs.python.org/3.8/library/multiprocessing.html#multiprocessing.Process.daemon | |
p = Process(name=f'Worker-{i:02d}', daemon=True, | |
target=start_worker, args=(self.queue, self.stop_event)) | |
procs.append(p) | |
p.start() | |
try: | |
# Put tasks on queue | |
for task in self.tasks: | |
logger.info(f'Put task {task} on queue') | |
self.queue.put(task) | |
# Put exit tasks on queue | |
for i in range(self.num_workers): | |
self.queue.put(None) | |
# Wait until all tasks are processed | |
self.queue.join() | |
except KeyboardInterrupt: | |
logger.warning('Caught KeyboardInterrupt! Setting stop event...') | |
# raise # TODO add option | |
finally: | |
self.stop_event.set() | |
t = time() | |
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort | |
# .is_alive() also implicitly joins the process (good practice in linux) | |
while alive_procs := [p for p in procs if p.is_alive()]: | |
if time() > t + self.grace_period: | |
for p in alive_procs: | |
os.kill(p.pid, signal.SIGINT) | |
logger.warning(f'Sending SIGINT to {p}') | |
elif time() > t + self.kill_period: | |
for p in alive_procs: | |
logger.warning(f'Sending SIGKILL to {p}') | |
# Queues and other inter-process communication primitives can break when | |
# process is killed, but we don't care here | |
p.kill() | |
sleep(.01) | |
sleep(.1) | |
for p in procs: | |
logger.info(f'Process status: {p}') | |