|
import dataclasses |
|
from typing import Optional |
|
|
|
|
|
@dataclasses.dataclass
class BelleParam:
    """Dataclass bundling model-shape and runtime parameters.

    Every field has a default, so ``BelleParam()`` yields one fixed
    configuration; override individual values with keyword arguments.
    Only ``shared_contexts_ratio`` is validated (in ``__post_init__``).
    """

    num_heads: int = 32
    size_per_head: int = 128
    inter_size: int = 16384
    num_layers: int = 30
    vocab_size: int = 250880
    # Typed Optional, so ``None`` is accepted for either id by this class.
    start_id: Optional[int] = 1
    end_id: Optional[int] = 2
    tensor_para_size: int = 1
    pipeline_para_size: int = 1
    remove_padding: bool = True
    # Must lie in the closed interval [0.0, 1.0]; enforced in __post_init__.
    shared_contexts_ratio: float = 1.0
    weights_data_type: str = "fp16"

    def __post_init__(self):
        """Validate field values after the generated ``__init__`` runs.

        Raises:
            ValueError: if ``shared_contexts_ratio`` is outside [0.0, 1.0].
                NaN also fails the chained comparison and is rejected.
        """
        if not 0.0 <= self.shared_contexts_ratio <= 1.0:
            # Name the real field so the message matches the keyword the
            # caller actually passed.
            raise ValueError(
                'Got an invalid value of shared_contexts_ratio '
                f'{self.shared_contexts_ratio} - range: [0.0, 1.0]')

    def asdict(self):
        """Return all fields as a plain ``dict`` via ``dataclasses.asdict``."""
        return dataclasses.asdict(self)
|
|
|
|
|
# Shared module-level instance built purely from the dataclass defaults.
BELLE_PARAM: BelleParam = BelleParam()
|
import os

# Absolute directory containing this module file (head of the path split).
current_dir = os.path.split(os.path.abspath(__file__))[0]

# Absolute path to `libth_transformer.so` inside that directory; the file's
# existence is not checked here.
LIB_SO_PATH = os.path.join(current_dir, 'libth_transformer.so')
|
|