leeyunjai commited on
Commit
7263355
1 Parent(s): 7b6efae

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +76 -0
utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch import Tensor
7
+
8
+
9
+ def _max_by_axis(the_list):
10
+ # type: (List[List[int]]) -> List[int]
11
+ maxes = the_list[0]
12
+ for sublist in the_list[1:]:
13
+ for index, item in enumerate(sublist):
14
+ maxes[index] = max(maxes[index], item)
15
+ return maxes
16
+
17
+
18
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
19
+ # TODO make this more general
20
+ if tensor_list[0].ndim == 3:
21
+ # TODO make it support different-sized images
22
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
23
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
24
+ batch_shape = [len(tensor_list)] + max_size
25
+ b, c, h, w = batch_shape
26
+ dtype = tensor_list[0].dtype
27
+ device = tensor_list[0].device
28
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
29
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
30
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
31
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
32
+ m[: img.shape[1], :img.shape[2]] = False
33
+ else:
34
+ raise ValueError('not supported')
35
+ return NestedTensor(tensor, mask)
36
+
37
+
38
+ class NestedTensor(object):
39
+ def __init__(self, tensors, mask: Optional[Tensor]):
40
+ self.tensors = tensors
41
+ self.mask = mask
42
+
43
+ def to(self, device):
44
+ # type: (Device) -> NestedTensor # noqa
45
+ cast_tensor = self.tensors.to(device)
46
+ mask = self.mask
47
+ if mask is not None:
48
+ assert mask is not None
49
+ cast_mask = mask.to(device)
50
+ else:
51
+ cast_mask = None
52
+ return NestedTensor(cast_tensor, cast_mask)
53
+
54
+ def decompose(self):
55
+ return self.tensors, self.mask
56
+
57
+ def __repr__(self):
58
+ return str(self.tensors)
59
+
60
+
61
+ def is_dist_avail_and_initialized():
62
+ if not dist.is_available():
63
+ return False
64
+ if not dist.is_initialized():
65
+ return False
66
+ return True
67
+
68
+
69
+ def get_rank():
70
+ if not is_dist_avail_and_initialized():
71
+ return 0
72
+ return dist.get_rank()
73
+
74
+
75
+ def is_main_process():
76
+ return get_rank() == 0