Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Provides cluster and tools configuration across clusters (slurm, dora, utilities). | |
""" | |
import logging | |
import os | |
from pathlib import Path | |
import re | |
import typing as tp | |
import omegaconf | |
from .utils.cluster import _guess_cluster_type | |
# Module-level logger, registered under this module's import path.
logger = logging.getLogger(name=__name__)
class AudioCraftEnvironment:
    """Environment configuration for teams and clusters.

    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.

    The cluster type is identified automatically, and the team configuration file is read from
    config/teams/<team>.yaml, located two directories above this module. Each top-level key of that
    file is a cluster name. The section under it holds that cluster's settings, such as `dora_dir`,
    `reference_dir` and `partitions`, and optionally `slurm_exclude` and `dataset_mappers`.

    Use the following environment variables to specify the cluster, team or configuration:

        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
            cannot be inferred automatically.
        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the team configuration.
            If not set, configuration is read from config/teams/<team>.yaml.
        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
            Defaults to "default". Cluster configurations are shared across teams to match
            compute allocation; specify your cluster configuration in your team's configuration file.
        AUDIOCRAFT_DORA_DIR (optional): Overrides the configured dora directory.
        AUDIOCRAFT_REFERENCE_DIR (optional): Overrides the configured reference directory.

    All public accessors are classmethods that operate on a lazily created singleton
    (see `instance` and `reset`).
    """

    # Singleton cache; populated by `instance()` and cleared by `reset()`.
    _instance: tp.Optional["AudioCraftEnvironment"] = None
    DEFAULT_TEAM = "default"

    def __init__(self) -> None:
        """Loads configuration.

        Resolves the team and cluster from environment variables (falling back to
        `DEFAULT_TEAM` and the auto-detected cluster type), loads the team YAML
        configuration, and compiles the cluster's dataset mapping rules.
        """
        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
        cluster_type = _guess_cluster_type()
        cluster = os.getenv("AUDIOCRAFT_CLUSTER", cluster_type.value)
        logger.info("Detecting cluster type %s", cluster_type)
        self.cluster: str = cluster

        config_path = os.getenv(
            "AUDIOCRAFT_CONFIG",
            Path(__file__)
            .parent.parent.joinpath("config/teams", self.team)
            .with_suffix(".yaml"),
        )
        self.config = omegaconf.OmegaConf.load(config_path)

        # (compiled pattern, replacement) pairs, kept in configuration order.
        self._dataset_mappers: tp.List[tp.Tuple[re.Pattern, str]] = []
        cluster_config = self._get_cluster_config()
        if "dataset_mappers" in cluster_config:
            for pattern, repl in cluster_config["dataset_mappers"].items():
                regex = re.compile(pattern)
                self._dataset_mappers.append((regex, repl))

    def _get_cluster_config(self) -> omegaconf.DictConfig:
        """Returns the configuration section keyed by the current cluster name."""
        assert isinstance(self.config, omegaconf.DictConfig)
        return self.config[self.cluster]

    @classmethod
    def instance(cls):
        """Returns the shared instance, constructing (and loading configuration) on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """Clears the environment and forces a reload on next invocation."""
        cls._instance = None

    @classmethod
    def get_team(cls) -> str:
        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.

        If not defined, defaults to `DEFAULT_TEAM` ("default").
        """
        return cls.instance().team

    @classmethod
    def get_cluster(cls) -> str:
        """Gets the detected cluster.

        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
        """
        return cls.instance().cluster

    @classmethod
    def get_dora_dir(cls) -> Path:
        """Gets the path to the dora directory for the current team and cluster.

        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
        logger.warning("Dora directory: %s", dora_dir)
        return Path(dora_dir)

    @classmethod
    def get_reference_dir(cls) -> Path:
        """Gets the path to the reference directory for the current team and cluster.

        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))

    @classmethod
    def get_slurm_exclude(cls) -> tp.Optional[str]:
        """Gets the list of nodes to exclude for that cluster, or None if not configured."""
        cluster_config = cls.instance()._get_cluster_config()
        return cluster_config.get("slurm_exclude")

    @classmethod
    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
        """Gets the requested partitions for the current team and cluster as a comma-separated string.

        Args:
            partition_types (list[str], optional): partition types to retrieve. Values must be
                from ['global', 'team']. If not provided, the global partition is returned.
        Returns:
            str: The configured partition names joined with ",", in the requested order.
        """
        if not partition_types:
            partition_types = ["global"]

        cluster_config = cls.instance()._get_cluster_config()
        partitions = [
            cluster_config["partitions"][partition_type]
            for partition_type in partition_types
        ]
        return ",".join(partitions)

    @classmethod
    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
        """Converts reference placeholder in path with configured reference dir to resolve paths.

        Args:
            path (str or Path): Path to resolve. A leading "//reference" is replaced by
                the configured reference directory; other paths are returned unchanged.
        Returns:
            Path: Resolved path.
        Raises:
            AssertionError: If the path uses the placeholder but the reference directory
                does not exist or is not a directory.
        """
        prefix = "//reference"
        path = str(path)
        if path.startswith(prefix):
            reference_dir = cls.get_reference_dir()
            logger.warning("Reference directory: %s", reference_dir)
            assert (
                reference_dir.exists() and reference_dir.is_dir()
            ), f"Reference directory does not exist: {reference_dir}."
            # Plain concatenation rather than re.sub: a regex replacement string treats
            # backslashes as escapes, which corrupts or rejects directories containing "\".
            path = str(reference_dir) + path[len(prefix):]
        return Path(path)

    @classmethod
    def apply_dataset_mappers(cls, path: str) -> str:
        """Applies dataset mapping regex rules as defined in the configuration.

        Rules are applied in configuration order, each to the output of the previous one.
        If no rules are defined, the path is returned as-is.
        """
        instance = cls.instance()
        for pattern, repl in instance._dataset_mappers:
            path = pattern.sub(repl, path)
        return path