File size: 3,407 Bytes
d73e1e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import logging
import os
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from .norm import SimpleRMSNorm
# Logging is configured as a side effect of importing this module: records are
# written to stdout with a timestamp, and the threshold is taken from the
# LOGLEVEL environment variable (falling back to INFO when it is unset).
logging.basicConfig(
    stream=sys.stdout,
    level=os.getenv("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)

# Logger used by the helper functions in this module.
logger = logging.getLogger("print_config")

# Step size that convert_to_multiple_of_base rounds up to.
BASE_DIM = 256
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    # Short-circuit: is_initialized() is consulted only if the package is
    # available in this build.
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Tell whether the current process is rank 0."""
    rank = get_rank()
    return rank == 0
def logging_info(string):
    """Log ``string`` at INFO level, but only from the main process."""
    if not is_main_process():
        return
    logger.info(string)
def print_params(**kwargs):
    """Log the given keyword arguments as a configuration dump.

    The keyword arguments are handed to ``print_config`` as a plain dict, so
    the output and requirements are the same: ``kwargs`` must contain a
    ``"__class__"`` entry (a ``KeyError`` is raised on the main process
    otherwise), and the ``"__class__"`` and ``"self"`` entries are not listed
    among the logged values.
    """
    print_config(kwargs)
def print_config(config):
    """Log every entry of ``config`` on the main process.

    ``config`` must support iteration over its keys and ``config[key]``
    lookup, and must contain a ``"__class__"`` entry, which is used in the
    start/end banner lines. The ``"__class__"`` and ``"self"`` entries are
    skipped in the per-key listing. Nothing is logged on non-main processes.
    """
    if not is_main_process():
        return

    hidden_keys = ("__class__", "self")
    logger.info(f"start print config of {config['__class__']}")
    for key in config:
        if key not in hidden_keys:
            logger.info(f"{key}: {config[key]}")
    logger.info(f"end print config of {config['__class__']}")
def print_module(module):
    """Describe the parameters registered directly on ``module``.

    A parameter is listed when the first dotted component of its name is not
    the name of any entry in ``module.named_modules()``; parameters that live
    inside a submodule are therefore left out.

    Args:
        module: object exposing ``named_modules()`` and ``named_parameters()``
            (e.g. a ``torch.nn.Module``).

    Returns:
        One line per listed parameter, in ``named_parameters()`` order, of the
        form ``(name): Tensor(<shape tuple>, requires_grad=<bool>)``, joined
        with newlines and without a trailing newline. An empty string when
        nothing qualifies.
    """
    # Keep the names in a set: membership is tested once per parameter, and a
    # list here would make each test a linear scan.
    submodule_names = {name for name, _ in module.named_modules()}

    # Collect lines and join once instead of growing a string in the loop.
    lines = []
    for param_name, param in module.named_parameters():
        owner = param_name.split('.')[0]
        if owner not in submodule_names:
            lines.append(
                f"({owner}): Tensor({tuple(param.shape)}, "
                f"requires_grad={param.requires_grad})"
            )
    return "\n".join(lines)
def get_activation_fn(activation):
    """Return the callable selected by the name ``activation``.

    Recognised names are ``gelu``, ``relu``, ``elu``, ``sigmoid``, ``exp``,
    ``leak``, ``1+elu``, ``2+elu``, ``silu`` / ``swish`` and ``sine``. Any
    other value does not raise: a message is logged and the identity function
    is returned.

    The requested name is logged on every call.
    """
    logger.info(f"activation: {activation}")

    def shifted_exp(x):
        # Subtract the maximum along the last dimension before exponentiating;
        # the maximum itself is computed without gradient tracking.
        # NOTE(review): the original indentation was lost; the exp is assumed
        # to sit outside the no_grad block so gradients still flow -- confirm.
        with torch.no_grad():
            x_max = torch.max(x, dim=-1, keepdims=True).values
        y = torch.exp(x - x_max)
        return y

    def elu_plus_one(x):
        return 1 + F.elu(x)

    def elu_plus_two(x):
        return 2 + F.elu(x)

    # Ordered (name, callable) pairs, compared with == exactly like the
    # equivalent if/elif chain would.
    known_activations = (
        ("gelu", F.gelu),
        ("relu", F.relu),
        ("elu", F.elu),
        ("sigmoid", F.sigmoid),
        ("exp", shifted_exp),
        ("leak", F.leaky_relu),
        ("1+elu", elu_plus_one),
        ("2+elu", elu_plus_two),
        ("silu", F.silu),
        ("swish", F.silu),
        ("sine", torch.sin),
    )
    for label, fn in known_activations:
        if activation == label:
            return fn

    logger.info(f"activation: does not support {activation}, use Identity!!!")
    return lambda x: x
def get_norm_fn(norm_type):
    """Return the normalization class for ``norm_type``.

    ``"simplermsnorm"`` selects ``SimpleRMSNorm``; every other value selects
    ``nn.LayerNorm``. The class itself is returned, not an instance.
    """
    return SimpleRMSNorm if norm_type == "simplermsnorm" else nn.LayerNorm
def convert_to_multiple_of_base(x):
    """Round ``x`` up to the nearest multiple of ``BASE_DIM``.

    Values that are already a multiple of ``BASE_DIM`` are returned unchanged.
    """
    # Adding BASE_DIM - 1 before the floor division turns it into a ceiling
    # division for integer inputs.
    num_blocks = (x + BASE_DIM - 1) // BASE_DIM
    return BASE_DIM * num_blocks