File size: 3,096 Bytes
01e655b ec6a7d0 374f426 01e655b f83b1b7 01e655b ec6a7d0 01e655b ec6a7d0 f83b1b7 01e655b f83b1b7 01e655b b44532e 01e655b b44532e 01e655b f83b1b7 01e655b f83b1b7 01e655b 0129fb6 f83b1b7 0129fb6 374f426 0129fb6 01e655b f83b1b7 01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import torch
import random
import numpy as np
from modules.utils import rng
import logging
logger = logging.getLogger(__name__)
def deterministic(seed=0, cudnn_deterministic=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to ``random.seed`` and ``np.random.seed``. It is
            also passed through ``rng.convert_np_to_torch`` to obtain the
            value used for the torch generators.
        cudnn_deterministic: When true and CUDA is available, switch cuDNN
            to deterministic mode and disable its benchmark mode.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    torch_seed = rng.convert_np_to_torch(seed)
    torch.manual_seed(torch_seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(torch_seed)

    # NOTE(review): the original indentation was lost in extraction; the
    # cuDNN flags are assumed to live inside the CUDA-available branch.
    # Confirm against the upstream file.
    if cudnn_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
def is_numeric(obj):
    """Return whether *obj* is a number or a string that parses as a float.

    Strings count as numeric when ``float(obj)`` succeeds. Non-string
    objects count when they are instances of ``int``, ``float``,
    ``np.integer`` or ``np.floating``. Everything else is not numeric.

    Args:
        obj: Any object.

    Returns:
        ``True`` if *obj* is considered numeric, otherwise ``False``.
    """
    if isinstance(obj, str):
        try:
            float(obj)
        except ValueError:
            return False
        return True

    # np.integer already covers both the signed and unsigned NumPy integers.
    return isinstance(obj, (np.integer, np.floating, int, float))
class SeedContext:
    """Context manager that seeds the global RNGs and restores them on exit.

    On entry the current Python, NumPy and torch RNG states and the cuDNN
    flags are captured, then ``deterministic`` is called with the normalized
    seed. On exit the captured states and flags are put back.

    Attributes:
        seed: Normalized integer seed in ``[0, 2**32 - 1]``.
        cudnn_deterministic: Forwarded to ``deterministic`` on entry.
        state: Tuple of captured states while the context is active,
            ``None`` before the first entry.
    """

    def __init__(self, seed, cudnn_deterministic=False):
        """Validate and normalize *seed*.

        Args:
            seed: A number or numeric string. Values are converted with
                ``int()`` and clamped to ``[-1, 2**32 - 1]``; a resulting
                ``-1`` (which therefore includes every negative seed) is
                replaced by a random seed.
            cudnn_deterministic: Forwarded to ``deterministic`` on entry.

        Raises:
            ValueError: If *seed* is not numeric or cannot be converted to
                an integer (e.g. ``"12.5"``, NaN or infinity).
        """
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``.
        if not is_numeric(seed):
            raise ValueError(f"Seed must be a number, but: {type(seed)}")

        try:
            value = int(seed)
        except (TypeError, ValueError, OverflowError) as e:
            raise ValueError(
                f"Seed must be an integer, but: {type(seed)}"
            ) from e

        # Keep the normalized value: the raw input must not overwrite it,
        # otherwise strings, NumPy scalars and out-of-range numbers would
        # reach the seeding functions unchanged.
        self.seed = max(-1, min(value, 2**32 - 1))
        self.cudnn_deterministic = cudnn_deterministic
        self.state = None

        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

    def __enter__(self):
        """Capture the current RNG state, then seed everything.

        Returns:
            This context object, so ``with SeedContext(...) as ctx`` works.
        """
        # Only snapshot the CUDA generators when CUDA is already
        # initialized, so entering the context never forces CUDA start-up.
        cuda_state = None
        if torch.cuda.is_available() and torch.cuda.is_initialized():
            cuda_state = torch.cuda.get_rng_state_all()

        self.state = (
            torch.get_rng_state(),
            random.getstate(),
            np.random.get_state(),
            torch.backends.cudnn.deterministic,
            torch.backends.cudnn.benchmark,
            cuda_state,
        )

        try:
            deterministic(self.seed, cudnn_deterministic=self.cudnn_deterministic)
        except Exception:
            # Seeding is best-effort: log the failure (with traceback) and
            # let the body of the ``with`` block run anyway.
            logger.warning(
                "Deterministic failed, with: <%s> %s",
                type(self.seed),
                self.seed,
                exc_info=True,
            )
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the RNG states and cuDNN flags captured on entry.

        Returns ``None``, so an exception raised inside the ``with`` block
        is propagated.
        """
        (
            torch_state,
            python_state,
            numpy_state,
            cudnn_deterministic,
            cudnn_benchmark,
            cuda_state,
        ) = self.state

        torch.set_rng_state(torch_state)
        random.setstate(python_state)
        np.random.set_state(numpy_state)
        torch.backends.cudnn.deterministic = cudnn_deterministic
        torch.backends.cudnn.benchmark = cudnn_benchmark
        if cuda_state is not None:
            torch.cuda.set_rng_state_all(cuda_state)
if __name__ == "__main__":
    # Manual smoke check: print is_numeric() for a mix of inputs.
    # The expected output is noted next to each sample.
    samples = (
        "1234",  # True
        "12.34",  # True
        "-1234",  # True
        "abc123",  # False
        np.int32(10),  # True
        np.float64(10.5),  # True
        10,  # True
        10.5,  # True
        np.int8(10),  # True
        np.uint64(10),  # True
        np.float16(10.5),  # True
        [1, 2, 3],  # False
    )
    for sample in samples:
        print(is_numeric(sample))
|