reach-vb's picture
reach-vb HF staff
Stereo demo update (#60)
5325fcc
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility functions for SLURM configuration and cluster settings.
"""
from enum import Enum
import os
import socket
import typing as tp
import omegaconf
class ClusterType(Enum):
AWS = "aws"
FAIR = "fair"
RSC = "rsc"
LOCAL_DARWIN = "darwin"
DEFAULT = "default" # used for any other cluster.
def _guess_cluster_type() -> ClusterType:
uname = os.uname()
fqdn = socket.getfqdn()
if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
return ClusterType.AWS
if fqdn.endswith(".fair"):
return ClusterType.FAIR
if fqdn.endswith(".facebook.com"):
return ClusterType.RSC
if uname.sysname == "Darwin":
return ClusterType.LOCAL_DARWIN
return ClusterType.DEFAULT
def get_cluster_type(
cluster_type: tp.Optional[ClusterType] = None,
) -> tp.Optional[ClusterType]:
if cluster_type is None:
return _guess_cluster_type()
return cluster_type
def get_slurm_parameters(
cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
"""Update SLURM parameters in configuration based on cluster type.
If the cluster type is not specify, it infers it automatically.
"""
from ..environment import AudioCraftEnvironment
cluster_type = get_cluster_type(cluster_type)
# apply cluster-specific adjustments
if cluster_type == ClusterType.AWS:
cfg["mem_per_gpu"] = None
cfg["constraint"] = None
cfg["setup"] = []
elif cluster_type == ClusterType.RSC:
cfg["mem_per_gpu"] = None
cfg["setup"] = []
cfg["constraint"] = None
cfg["partition"] = "learn"
slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
if slurm_exclude is not None:
cfg["exclude"] = slurm_exclude
return cfg