File size: 826 Bytes
07423df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from functools import partial
from typing import Any, List
import bitsandbytes as bnb
from torch import optim
# Public API of this module: only the factory class is exported.
__all__ = [
    "Optimizers",
]
class Optimizers:
    """Optimizers factory: maps an optimizer name to a constructor.

    Each registered value is either an optimizer class or a
    ``functools.partial`` of one with some hyper-parameters preset. Calling it
    (presumably with the model parameters plus any extra keyword arguments)
    builds the optimizer instance.
    """

    # Name -> optimizer constructor.
    # The ``partial`` entries preset hyper-parameters that differ from the
    # PyTorch defaults. Keyword arguments given at call time still override
    # them, because that is how ``functools.partial`` merges keywords.
    _optimizers = {
        "Adam": optim.Adam,
        "AdamW": optim.AdamW,
        "SGD": partial(optim.SGD, momentum=0.9, nesterov=True),
        "RMSprop": partial(optim.RMSprop, momentum=0.9, alpha=0.9),
        "Adadelta": optim.Adadelta,
        # Use bitsandbytes' AdamW variant so this entry matches its name and
        # is the 8-bit counterpart of "AdamW" above. It previously mapped to
        # ``Adam8bit``, which is the plain-Adam variant.
        "AdamW8bit": bnb.optim.AdamW8bit,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return all registered optimizer names, sorted alphabetically."""
        return sorted(cls._optimizers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up an optimizer constructor by name.

        Args:
            name: optimizer name. The lookup is an exact, case-sensitive match
                against one of the values returned by ``names()``.

        Returns:
            A callable that builds the optimizer: an optimizer class or a
            ``functools.partial`` with preset hyper-parameters. Returns
            ``None`` if ``name`` is not registered.
        """
        return cls._optimizers.get(name)
|