vardaan123 commited on
Commit
4b4b296
·
1 Parent(s): 01981f0

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +74 -0
utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import csv
4
+ import argparse
5
+ import random
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import torch
9
+ import pandas as pd
10
+ import re
11
+
12
+ from torch.utils.data import DataLoader
13
+
14
+ try:
15
+ from torch_geometric.data import Batch
16
+ except ImportError:
17
+ pass
18
+
19
+ def set_seed(seed):
20
+ """Sets seed"""
21
+ if torch.cuda.is_available():
22
+ torch.cuda.manual_seed(seed)
23
+ torch.manual_seed(seed)
24
+ np.random.seed(seed)
25
+ random.seed(seed)
26
+ torch.backends.cudnn.benchmark = False
27
+ torch.backends.cudnn.deterministic = True
28
+
29
+
30
+ def move_to(obj, device):
31
+ if isinstance(obj, dict):
32
+ return {k: move_to(v, device) for k, v in obj.items()}
33
+ elif isinstance(obj, list):
34
+ return [move_to(v, device) for v in obj]
35
+ elif isinstance(obj, float) or isinstance(obj, int):
36
+ return obj
37
+ else:
38
+ # Assume obj is a Tensor or other type
39
+ # (like Batch, for MolPCBA) that supports .to(device)
40
+ return obj.to(device)
41
+
42
+ def detach_and_clone(obj):
43
+ if torch.is_tensor(obj):
44
+ return obj.detach().clone()
45
+ elif isinstance(obj, dict):
46
+ return {k: detach_and_clone(v) for k, v in obj.items()}
47
+ elif isinstance(obj, list):
48
+ return [detach_and_clone(v) for v in obj]
49
+ elif isinstance(obj, float) or isinstance(obj, int):
50
+ return obj
51
+ else:
52
+ raise TypeError("Invalid type for detach_and_clone")
53
+
54
+ def collate_list(vec):
55
+ """
56
+ If vec is a list of Tensors, it concatenates them all along the first dimension.
57
+
58
+ If vec is a list of lists, it joins these lists together, but does not attempt to
59
+ recursively collate. This allows each element of the list to be, e.g., its own dict.
60
+
61
+ If vec is a list of dicts (with the same keys in each dict), it returns a single dict
62
+ with the same keys. For each key, it recursively collates all entries in the list.
63
+ """
64
+ if not isinstance(vec, list):
65
+ raise TypeError("collate_list must take in a list")
66
+ elem = vec[0]
67
+ if torch.is_tensor(elem):
68
+ return torch.cat(vec)
69
+ elif isinstance(elem, list):
70
+ return [obj for sublist in vec for obj in sublist]
71
+ elif isinstance(elem, dict):
72
+ return {k: collate_list([d[k] for d in vec]) for k in elem}
73
+ else:
74
+ raise TypeError("Elements of the list to collate must be tensors or dicts.")