import dataclasses
import torch
from transformers.models.opt.configuration_opt import OPTConfig

@dataclasses.dataclass(frozen=True)
class TricksyConfig:
    opt_config: OPTConfig

    # Fraction of weights to keep resident on each device
    # e.g. 30% of each MLP layer on GPU
    min_mlp_sparsity_gpu: float = 0.3
    # e.g. 100% of each MLP layer on CPU
    min_mlp_sparsity_cpu: float = 1.0

    # If True, frees the layer's weights after computing its forward pass
    full_offload: bool = False

    dtype: torch.dtype = torch.float16
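
# Example usage (illustrative sketch; the checkpoint name below is only an
# example and is not prescribed by this file):
#
#   opt_config = OPTConfig.from_pretrained('facebook/opt-1.3b')
#   tricksy_config = TricksyConfig(
#       opt_config=opt_config,
#       min_mlp_sparsity_gpu=0.3,  # keep 30% of each MLP layer on GPU
#       full_offload=False,
#   )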