# File size: 2,460 Bytes
# Blame revisions (source-viewer residue): 01e655b ec6a7d0 01e655b ec6a7d0 01e655b ec6a7d0 01e655b b44532e 01e655b b44532e 01e655b 0129fb6 01e655b
import torch
import random
import numpy as np
from modules.utils import rng
def deterministic(seed=0):
    """Seed the Python, NumPy and torch RNGs and force deterministic cuDNN.

    ``random`` and ``numpy`` receive *seed* unchanged; torch (CPU and, when
    CUDA is available, all CUDA devices) receives the value produced by
    ``rng.convert_np_to_torch(seed)``. The cuDNN ``deterministic`` flag is
    set and ``benchmark`` is cleared; these are process-global flags and
    this function does not restore them.
    """
    # Python and NumPy are seeded first, in this order, before the torch seed
    # is derived (the converter may depend on NumPy state — not visible here).
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    torch_seed = rng.convert_np_to_torch(seed)
    torch.manual_seed(torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(torch_seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def is_numeric(obj):
    """Return True when *obj* can be treated as a number.

    Strings qualify when ``float()`` can parse them (so ``"12"``, ``"-1.5"``
    and ``"nan"`` all count). Python ``int``/``float`` values (and therefore
    ``bool``) and NumPy integer/floating scalars qualify directly. Anything
    else — lists, ``None``, complex numbers, ... — does not.
    """
    if isinstance(obj, str):
        try:
            float(obj)
        except ValueError:
            return False
        return True
    # np.integer already covers both signed and unsigned NumPy integer types.
    return isinstance(obj, (np.integer, np.floating, int, float))
class SeedContext:
    """Temporarily seed every RNG and restore the previous state afterwards.

    Usage::

        with SeedContext(42):
            ...  # random / numpy / torch draws are reproducible here

    On entry, the following are saved:

    * the CPU torch RNG state;
    * the CUDA RNG states (only when CUDA is available);
    * the ``random`` and ``numpy`` global states;
    * the cuDNN ``deterministic``/``benchmark`` flags.

    Then ``deterministic()`` seeds everything with ``self.seed``. On exit, or
    if seeding fails, all saved state is put back. Code outside the ``with``
    block is therefore unaffected.

    Attributes:
        seed: the normalised ``int`` seed actually used.
        state: ``None`` until ``__enter__``. Afterwards a tuple
            ``(torch_cpu_state, random_state, numpy_state, cuda_states_or_None,
            (cudnn_deterministic, cudnn_benchmark))``. The first three entries
            keep their original positions.
    """

    # Largest seed accepted by np.random.seed.
    _MAX_SEED = 2**32 - 1

    def __init__(self, seed):
        """Validate and normalise *seed* to a plain ``int``.

        Args:
            seed: int, float, NumPy integer/floating scalar, or an
                integer-valued string. Floats are truncated by ``int()``.
                The result is clipped to ``[-1, 2**32 - 1]``. ``-1`` (which
                every negative seed clips to) selects a random seed in
                ``[0, 2**32 - 1]``.

        Raises:
            AssertionError: if ``is_numeric(seed)`` is false.
            ValueError: if ``int(seed)`` fails (e.g. ``"12.5"``, ``"nan"``,
                ``float("inf")``).
        """
        assert is_numeric(seed), "Seed must be an number."
        try:
            value = int(seed)
        except (TypeError, ValueError, OverflowError) as e:
            raise ValueError(f"Seed must be an integer, but: {type(seed)}") from e
        # Clip in pure Python: arbitrarily large ints cannot overflow a
        # fixed-width dtype. The result is a builtin int, which random.seed
        # accepts (NumPy scalars are rejected by it on Python >= 3.11).
        self.seed = min(max(value, -1), self._MAX_SEED)
        self.state = None
        if self.seed == -1:
            self.seed = random.randint(0, self._MAX_SEED)

    def __enter__(self):
        """Save all RNG state, seed everything with ``self.seed``, return self.

        Raises:
            ValueError: if ``deterministic(self.seed)`` raises. The saved
                state is restored before re-raising.
        """
        cudnn = torch.backends.cudnn
        self.state = (
            torch.get_rng_state(),
            random.getstate(),
            np.random.get_state(),
            torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
            (cudnn.deterministic, cudnn.benchmark),
        )
        try:
            deterministic(self.seed)
        except Exception as e:
            # Some generators may already have been reseeded; undo that.
            self._restore()
            raise ValueError(
                f"Seed must be an integer, but: <{type(self.seed)}> {self.seed}"
            ) from e
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the state saved by ``__enter__``.

        Returns None, so exceptions raised in the ``with`` body propagate.
        """
        self._restore()

    def _restore(self):
        # Put back every piece of global state captured in __enter__.
        torch_state, py_state, np_state, cuda_states, cudnn_flags = self.state
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        np.random.set_state(np_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)
        cudnn = torch.backends.cudnn
        cudnn.deterministic, cudnn.benchmark = cudnn_flags
if __name__ == "__main__":
    # Manual smoke check for is_numeric; the expected printout follows each value.
    _samples = (
        "1234",  # True
        "12.34",  # True
        "-1234",  # True
        "abc123",  # False
        np.int32(10),  # True
        np.float64(10.5),  # True
        10,  # True
        10.5,  # True
        np.int8(10),  # True
        np.uint64(10),  # True
        np.float16(10.5),  # True
        [1, 2, 3],  # False
    )
    for _sample in _samples:
        print(is_numeric(_sample))