File size: 930 Bytes
3f70f85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import dataclasses
from typing import Optional
@dataclasses.dataclass
class BelleParam:
    """Default configuration parameters for a Belle model.

    Every field has a default, so ``BelleParam()`` yields the stock
    configuration; override individual values with keyword arguments.
    ``__post_init__`` validates ``shared_contexts_ratio``.
    """

    # Model size hyper-parameters.
    num_heads: int = 32
    size_per_head: int = 128
    inter_size: int = 16384
    num_layers: int = 30
    vocab_size: int = 250880
    # Token ids; typed Optional so callers may set them to None.
    # NOTE(review): presumably the start/end-of-sequence token ids — confirm.
    start_id: Optional[int] = 1
    end_id: Optional[int] = 2
    # Parallelism sizes (names suggest tensor / pipeline parallel degrees).
    tensor_para_size: int = 1
    pipeline_para_size: int = 1
    remove_padding: bool = True
    # Must lie in the closed interval [0.0, 1.0]; checked in __post_init__.
    shared_contexts_ratio: float = 1.0
    # NOTE(review): precision tag for the stored weights — confirm which
    # values the consumer accepts besides "fp16".
    weights_data_type: str = "fp16"

    def __post_init__(self) -> None:
        """Validate field values after dataclass initialisation.

        Raises:
            ValueError: if ``shared_contexts_ratio`` is outside [0.0, 1.0].
                NaN is rejected as well, because every comparison with NaN
                is false.
        """
        if not 0.0 <= self.shared_contexts_ratio <= 1.0:
            # Use the real field name so the message matches the keyword
            # argument the caller actually passed.
            raise ValueError(
                'Got an invalid value of shared_contexts_ratio '
                f'{self.shared_contexts_ratio} - range: [0.0, 1.0]')

    def asdict(self) -> dict:
        """Return all fields as a plain ``dict`` via ``dataclasses.asdict``."""
        return dataclasses.asdict(self)


# Module-level instance holding the default configuration.
BELLE_PARAM = BelleParam()
import os
# Absolute directory containing this module; the shared library
# 'libth_transformer.so' is looked up alongside it.
current_dir = os.path.split(os.path.abspath(__file__))[0]
LIB_SO_PATH = os.path.join(
    current_dir,
    'libth_transformer.so',
)
|