|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
Module for computational graph execution.
|
|
|
|
Classes:
|
|
Task: Abstract base class representing a computational task.
|
|
Executor: Class for scheduling and executing directed acyclic task graphs.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
|
|
|
|
import networkx
|
|
import torch
|
|
import tqdm
|
|
from pydantic import BaseModel
|
|
from typing_extensions import Generic, TypeVar
|
|
|
|
ValueT = TypeVar("ValueT")
|
|
|
|
|
|
class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    """
    Abstract base class representing a task in a computational graph.

    Subclass this to define concrete tasks. A task declares its inputs
    (other tasks) via ``arguments`` and its computation via ``execute``.
    Instances are frozen (immutable) pydantic models, so they are hashable
    and can serve as graph nodes and dictionary keys.

    Type parameter:
        ValueT: The type of the value produced by ``execute``.

    Methods:
        arguments: Declare the task's dependencies.
        execute: Run the task using its dependencies' results.
        priority: Scheduling priority (default 0; higher is more urgent).
        group_label: Optional label for grouping related tasks.
        uses_accelerator: Whether the task benefits from accelerated math.
    """

    @abstractmethod
    def arguments(self) -> Dict[str, "Task"]:
        """
        Return the tasks this task depends on, keyed by argument name.

        Each key must match a keyword argument accepted by ``execute``.
        For example, returning ``{'input1': task_a, 'input2': task_b}``
        means ``execute`` will be invoked as
        ``execute(input1=value_a, input2=value_b)``, where ``value_a`` and
        ``value_b`` are the outputs of ``task_a`` and ``task_b``.

        Returns:
            Dict[str, "Task"]: Mapping of argument names to dependency tasks.
        """
        ...

    @abstractmethod
    def execute(self, **kwargs) -> ValueT:
        """
        Compute this task's value from its dependencies' results.

        The keyword arguments passed here correspond exactly to the keys of
        the dictionary returned by ``arguments``; each value is the result
        of executing the associated dependency task.

        Returns:
            ValueT: The result of the task execution.
        """
        ...

    def priority(self) -> int:
        """
        Return the scheduling priority of this task.

        Higher numbers indicate higher priority. Default is 0.

        Returns:
            int: The priority of the task.
        """
        return 0

    def group_label(self) -> Optional[str]:
        """
        Return an optional label used for grouping tasks together.

        Returns:
            Optional[str]: The group label of the task, or None if ungrouped.
        """
        return None

    def uses_accelerator(self) -> bool:
        """
        Return True if the task can take advantage of matrix operation
        acceleration (such as on a GPU). Default is False.
        """
        return False
|
|
|
|
|
|
class Executor:
    """
    Schedules and executes a set of tasks and their dependencies.

    Handles scheduling, execution, the movement of data between devices,
    and the lifecycle of intermediate results.

    Attributes:
        math_device (torch.device): Device used for tensor computations.
        storage_device (torch.device): Device used for storing intermediate results.
        targets (List[Task]): List of target tasks to be executed.
        schedule (List[Task]): Calculated execution schedule of tasks.
        dependencies (Dict[Task, Set[Task]]): Dependencies of each task.
    """

    math_device: torch.device
    storage_device: torch.device
    targets: List[Task]
    schedule: List[Task]
    dependencies: Dict[Task, Set[Task]]

    def __init__(
        self,
        tasks: List[Task],
        math_device: torch.device = torch.device("cpu"),
        storage_device: torch.device = torch.device("cpu"),
    ):
        """
        Initializes the Executor with a list of tasks and device configurations.

        Args:
            tasks (List[Task]): The list of target tasks to be executed.
            math_device (torch.device, optional): The device for tensor computations. Defaults to CPU.
            storage_device (torch.device, optional): The device for storing results. Defaults to CPU.
        """
        self.math_device = math_device
        self.storage_device = storage_device
        self.schedule = self._make_schedule(tasks)
        # Copy so later mutation of the caller's list cannot change targets.
        self.targets = list(tasks)

    def run(self, quiet: bool = False) -> Iterator[Tuple[Task, Any]]:
        """
        Execute the computed schedule and yield the target values.

        Intermediate results are discarded as soon as no later task in the
        schedule needs them, keeping peak memory usage low.

        Args:
            quiet (bool, optional): If True, suppress the progress bar.

        Yields:
            Tuple[Task, Any]: Each target task paired with its result.
        """
        # Last schedule index at which each task's value is still needed;
        # after that index the cached value may be freed.
        last_use_index: Dict[Task, int] = {}
        for idx, task in reversed(list(enumerate(self.schedule))):
            for dep in self.dependencies[task]:
                if dep not in last_use_index:
                    last_use_index[dep] = idx
            if task not in last_use_index:
                last_use_index[task] = idx

        # Tasks are frozen (hashable) models; set membership is O(1) versus
        # the O(len(targets)) list scan previously done for every task.
        target_set = set(self.targets)

        values: Dict[Task, Any] = {}
        pbar = tqdm.tqdm(
            list(enumerate(self.schedule)),
            disable=quiet,
            desc="Executing graph",
        )
        for idx, task in pbar:
            use_math_device = task.uses_accelerator()

            arguments = {}
            for name, dep in task.arguments().items():
                value = values[dep]

                # Move tensor inputs onto the math device for accelerated
                # tasks. NOTE(review): the dict branch mutates the cached
                # intermediate in place, so its tensors remain on the math
                # device for any later consumers — presumably intentional to
                # avoid repeated transfers; confirm before changing.
                if use_math_device:
                    if (
                        isinstance(value, torch.Tensor)
                        and value.device != self.math_device
                    ):
                        value = value.to(self.math_device)
                    elif isinstance(value, dict):
                        for key in value:
                            if (
                                isinstance(value[key], torch.Tensor)
                                and value[key].device != self.math_device
                            ):
                                value[key] = value[key].to(self.math_device)

                arguments[name] = value
                del value

            res = task.execute(**arguments)
            del arguments

            # Park results on the storage device so math-device memory is
            # free for subsequent tasks.
            if isinstance(res, torch.Tensor) and res.device != self.storage_device:
                res = res.to(self.storage_device)

            values[task] = res
            del res

            if task in target_set:
                yield (task, values[task])

            # Drop every intermediate whose last consumer has now executed.
            expired = [key for key in values if idx >= last_use_index[key]]
            for key in expired:
                del values[key]

        del values
        del pbar

    def execute(self) -> None:
        """
        Execute all tasks and discard the results.
        """
        for _task, _value in self.run():
            pass

    # Sentinel graph node linking every target to a single dummy root so
    # isolated targets still appear in the topological sort.
    DUMMY_TASK_VALUE = "!!DUMMY!!"

    def _make_schedule(self, targets: List[Task]) -> List[Task]:
        """
        Compute a dependency-respecting execution order for ``targets``.

        Also populates ``self.dependencies`` as a side effect.

        Args:
            targets (List[Task]): The target tasks to schedule.

        Returns:
            List[Task]: Tasks in a valid execution order.
        """
        # NOTE: the previous version assigned `self.schedule = []` here,
        # which was dead (the caller stores the return value) and left the
        # attribute wiped-but-unset if this method were called alone.
        self.dependencies = self._build_dependencies(targets)

        # Edges point from dependency to dependent task.
        edge_tups = []
        for node, deps in self.dependencies.items():
            for dependency in deps:
                edge_tups.append((dependency, node))

        # Attach each target to the dummy root so targets with neither
        # dependencies nor dependents are still present in the graph.
        for task in targets:
            edge_tups.append((Executor.DUMMY_TASK_VALUE, task))

        def _compare_key(task: Union[Task, str]):
            # Tie-break the topological sort: dummy first, then by group
            # label, then by descending priority.
            if task == Executor.DUMMY_TASK_VALUE:
                return ("", 0)
            return (
                task.group_label() or "",
                -task.priority(),
            )

        graph = networkx.DiGraph(edge_tups)
        return [
            t
            for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)
            if t != Executor.DUMMY_TASK_VALUE
        ]

    def _build_dependencies(self, targets: List[Task]) -> Dict[Task, Set[Task]]:
        """
        Discover the dependency sets of all tasks reachable from ``targets``.

        Args:
            targets (List[Task]): The target tasks to start traversal from.

        Returns:
            Dict[Task, Set[Task]]: Each reachable task mapped to its direct
            dependencies.
        """
        task_dependencies: Dict[Task, Set[Task]] = {}
        to_process = list(targets)
        while to_process:
            child = to_process.pop()
            if child in task_dependencies:
                continue

            deps = set(child.arguments().values())
            task_dependencies[child] = deps
            to_process.extend(deps)
        return task_dependencies
|
|
|