File size: 826 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from functools import partial
from typing import Any, List

import bitsandbytes as bnb
from torch import optim

__all__ = ["Optimizers"]


class Optimizers:
    """Optimizers factory.

    Maps a short optimizer name to a callable that builds the optimizer.
    Each entry is either an optimizer class (from ``torch.optim`` or
    ``bitsandbytes.optim``) or a ``functools.partial`` of one with some
    keyword defaults pre-bound, so callers invoke every entry the same way,
    e.g. ``Optimizers.get("SGD")(model.parameters(), lr=0.1)``.
    """

    # name -> optimizer class, or a partial of one with preset keyword
    # defaults (callers may still override those keywords at call time).
    _optimizers = {
        "Adam": optim.Adam,
        "AdamW": optim.AdamW,
        # SGD with Nesterov momentum enabled by default.
        "SGD": partial(optim.SGD, momentum=0.9, nesterov=True),
        # RMSprop with momentum and a smoothing constant (alpha) of 0.9.
        "RMSprop": partial(optim.RMSprop, momentum=0.9, alpha=0.9),
        "Adadelta": optim.Adadelta,
        # Must be the AdamW variant so this key mirrors "AdamW"; it
        # previously pointed at bnb.optim.Adam8bit by mistake.
        "AdamW8bit": bnb.optim.AdamW8bit,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered optimizer names, sorted alphabetically."""
        return sorted(cls._optimizers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Optimizers.

        Args:
            name: optimizer name; one of ``Optimizers.names()``. The lookup
                is case-sensitive.
        Returns:
            A callable that builds the optimizer: either the optimizer class
            itself or a ``functools.partial`` of it with preset keyword
            defaults. Returns ``None`` if ``name`` is not registered.
        """
        return cls._optimizers.get(name)