File size: 1,710 Bytes
9d0d223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from queue import Queue, Empty
import signal
import sys
import threading
import traceback

logger = logging.getLogger(__name__)


class DeadlockDetect:
    """Watchdog context manager that kills the process when progress stalls.

    While the context is active (and ``use`` is true), a background thread
    waits on an internal queue. Every call to :meth:`update` feeds the queue.
    If no item arrives within ``timeout`` seconds, the thread logs the last
    reported stage, prints the stack of every live thread to stderr, and
    kills the current process.

    Args:
        use: When false, every method is a no-op and no thread is started.
        timeout: Maximum number of seconds allowed between two updates
            (or between entering the context and the first update).
    """

    def __init__(self, use: bool = False, timeout: float = 120.):
        self.use = use
        self.timeout = timeout
        self._queue: Queue = Queue()
        # Detector thread; only set between `__enter__` and `__exit__`.
        self._thread = None

    def update(self, stage: str) -> None:
        """Report that `stage` was reached, resetting the timeout window."""
        if self.use:
            self._queue.put(stage)

    def __enter__(self) -> "DeadlockDetect":
        """Start the detector thread (if enabled) and return this object."""
        if self.use:
            self._thread = threading.Thread(target=self._detector_thread)
            self._thread.start()
        # Returning self makes `with DeadlockDetect(...) as d:` usable.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Ask the detector thread to stop and wait for it to finish.

        Returns None, so an exception raised inside the `with` body is
        never suppressed.
        """
        if not self.use or self._thread is None:
            # Nothing was started, so there is nothing to stop or join.
            return
        # `None` is the sentinel telling the detector loop to exit cleanly.
        self._queue.put(None)
        self._thread.join()
        self._thread = None

    def _detector_thread(self) -> None:
        """Thread body: consume stage updates until stopped or timed out."""
        logger.debug("Deadlock detector started")
        last_stage = "init"
        while True:
            try:
                stage = self._queue.get(timeout=self.timeout)
            except Empty:
                # No update within `timeout` seconds: treat it as a deadlock.
                break
            if stage is None:
                logger.debug("Exiting deadlock detector thread")
                return
            last_stage = stage
        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
        # Take a single snapshot of the frames. A thread listed by
        # `enumerate()` may be absent from it (just started or just
        # finished), so look it up with `.get()` instead of indexing.
        frames = sys._current_frames()
        for th in threading.enumerate():
            print(th, file=sys.stderr)
            frame = frames.get(th.ident)
            if frame is None:
                print("  <no stack frame available>", file=sys.stderr)
            else:
                traceback.print_stack(frame)
            print(file=sys.stderr)
        sys.stdout.flush()
        sys.stderr.flush()
        kill_signal = getattr(signal, "SIGKILL", None)
        if kill_signal is None:
            # Platforms without SIGKILL (e.g. Windows): terminate immediately
            # without running cleanup handlers.
            os._exit(1)
        os.kill(os.getpid(), kill_signal)