xxxpo13 committed · verified
Commit c38f042 · 1 Parent(s): fee70e0

Update utils.py

Files changed (1):
  utils.py  +453 -92
utils.py CHANGED
@@ -1,96 +1,457 @@
  import os
  import torch
  import torch.distributed as dist
- import torch.nn as nn
- from torch.utils.data import DataLoader, DistributedSampler
- from torchvision import datasets, transforms
- from torch.nn.parallel import DistributedDataParallel as DDP
-
- # Set your model class here (for demonstration, we'll create a simple CNN)
- class SimpleCNN(nn.Module):
-     def __init__(self):
-         super(SimpleCNN, self).__init__()
-         self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
-         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
-         self.fc1 = nn.Linear(64 * 7 * 7, 128)
-         self.fc2 = nn.Linear(128, 10)
-
-     def forward(self, x):
-         x = nn.ReLU()(self.conv1(x))
-         x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
-         x = nn.ReLU()(self.conv2(x))
-         x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
-         x = x.view(x.size(0), -1)
-         x = nn.ReLU()(self.fc1(x))
-         x = self.fc2(x)
-         return x
-
- def init_distributed_mode():
-     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-         rank = int(os.environ['RANK'])
-         world_size = int(os.environ['WORLD_SIZE'])
-         dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
-         torch.cuda.set_device(rank % torch.cuda.device_count())
-         print(f"Initialized distributed mode: rank {rank}, world size {world_size}")
      else:
-         print("Not using distributed mode")
-         rank = 0
-         world_size = 1
-     return rank, world_size
-
- def main():
-     # Initialize the distributed mode
-     rank, world_size = init_distributed_mode()
-
-     # Set up data transformations
-     transform = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Normalize((0.5,), (0.5,))
-     ])
-
-     # Load dataset
-     train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
-     train_sampler = DistributedSampler(train_dataset)
-     train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
-
-     # Initialize model
-     model = SimpleCNN()
-     device = torch.device(f'cuda:{rank % torch.cuda.device_count()}')
-     model.to(device)
-
-     # Wrap the model with DDP
-     if world_size > 1:
-         model = DDP(model, device_ids=[rank], output_device=rank)
-
-     # Set up the optimizer and loss function
-     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-     criterion = nn.CrossEntropyLoss()
-
-     # Training loop
-     for epoch in range(10):  # Train for 10 epochs
-         train_sampler.set_epoch(epoch)  # Shuffle data every epoch
-         running_loss = 0.0
-
-         for inputs, targets in train_loader:
-             inputs, targets = inputs.to(device), targets.to(device)
-
-             # Forward pass
-             outputs = model(inputs)
-             loss = criterion(outputs, targets)
-
-             # Backward pass and optimization
-             optimizer.zero_grad()
-             loss.backward()
-             optimizer.step()
-
-             running_loss += loss.item()
-
-         if rank == 0:  # Only print from the main process
-             print(f'Epoch [{epoch + 1}/10], Loss: {running_loss / len(train_loader):.4f}')
-
-     # Clean up distributed training
-     if world_size > 1:
-         dist.destroy_process_group()
-
- if __name__ == '__main__':
-     main()
  import os
  import torch
+ import PIL.Image
+ import numpy as np
+ from torch import nn
  import torch.distributed as dist
+ import timm.models.hub as timm_hub
+
+ """Modified from https://github.com/CompVis/taming-transformers.git"""
+
+ import hashlib
+ import requests
+ from tqdm import tqdm
+ try:
+     import piq
+ except ImportError:  # piq is optional; tolerate its absence
+     pass
+
+ _CONTEXT_PARALLEL_GROUP = None
+ _CONTEXT_PARALLEL_SIZE = None
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def is_context_parallel_initialized():
+     if _CONTEXT_PARALLEL_GROUP is None:
+         return False
+     else:
+         return True
+
+
+ def set_context_parallel_group(size, group):
+     global _CONTEXT_PARALLEL_GROUP
+     global _CONTEXT_PARALLEL_SIZE
+     _CONTEXT_PARALLEL_GROUP = group
+     _CONTEXT_PARALLEL_SIZE = size
+
+
+ def initialize_context_parallel(context_parallel_size):
+     global _CONTEXT_PARALLEL_GROUP
+     global _CONTEXT_PARALLEL_SIZE
+
+     assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
+     _CONTEXT_PARALLEL_SIZE = context_parallel_size
+
+     rank = torch.distributed.get_rank()
+     world_size = torch.distributed.get_world_size()
+
+     # Partition the world into contiguous groups of `context_parallel_size`
+     # ranks and keep the group that contains this rank.
+     for i in range(0, world_size, context_parallel_size):
+         ranks = range(i, i + context_parallel_size)
+         group = torch.distributed.new_group(ranks)
+         if rank in ranks:
+             _CONTEXT_PARALLEL_GROUP = group
+             break
+
+
+ def get_context_parallel_group():
+     assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
+
+     return _CONTEXT_PARALLEL_GROUP
+
+
+ def get_context_parallel_world_size():
+     assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+     return _CONTEXT_PARALLEL_SIZE
+
+
+ def get_context_parallel_rank():
+     assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+     rank = get_rank()
+     cp_rank = rank % _CONTEXT_PARALLEL_SIZE
+     return cp_rank
+
+
+ def get_context_parallel_group_rank():
+     assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+     rank = get_rank()
+     cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
+
+     return cp_group_rank
+
+
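A minimal usage sketch of the context-parallel helpers above, assuming the default process group has already been initialized (e.g. by torchrun), the world size is divisible by the context-parallel size, and the module is importable as `utils` (the file's own name); the size of 2 is illustrative:

    import torch.distributed as dist
    from utils import (initialize_context_parallel, get_context_parallel_group,
                       get_context_parallel_rank, get_context_parallel_group_rank)

    dist.init_process_group(backend="nccl")               # torchrun supplies RANK / WORLD_SIZE
    initialize_context_parallel(context_parallel_size=2)  # e.g. 8 ranks -> 4 groups of 2
    cp_group = get_context_parallel_group()               # this rank's sub-group handle
    print(get_context_parallel_rank(), get_context_parallel_group_rank())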
+ def download_cached_file(url, check_hash=True, progress=False):
+     """
+     Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+     If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+     """
+
+     def get_cached_file_path():
+         # a hack to sync the file path across processes
+         parts = torch.hub.urlparse(url)
+         filename = os.path.basename(parts.path)
+         cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+         return cached_file
+
+     if is_main_process():
+         timm_hub.download_cached_file(url, check_hash, progress)
+
+     if is_dist_avail_and_initialized():
+         dist.barrier()
+
+     return get_cached_file_path()
+
+
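For reference, a sketch of calling download_cached_file from every rank; the URL is a placeholder, not one used by this commit. Rank 0 downloads into timm's cache directory while the other ranks wait at the barrier, then all ranks resolve the same local path:

    import torch
    from utils import download_cached_file

    path = download_cached_file("https://example.com/checkpoint.pth",
                                check_hash=False, progress=True)
    state_dict = torch.load(path, map_location="cpu")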
+ def convert_weights_to_fp16(model: nn.Module):
+     """Convert applicable model parameters to fp16"""
+
+     def _convert_weights_to_fp16(l):
+         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
+             l.weight.data = l.weight.data.to(torch.float16)
+             if l.bias is not None:
+                 l.bias.data = l.bias.data.to(torch.float16)
+
+     model.apply(_convert_weights_to_fp16)
+
+
+ def convert_weights_to_bf16(model: nn.Module):
+     """Convert applicable model parameters to bf16"""
+
+     def _convert_weights_to_bf16(l):
+         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
+             l.weight.data = l.weight.data.to(torch.bfloat16)
+             if l.bias is not None:
+                 l.bias.data = l.bias.data.to(torch.bfloat16)
+
+     model.apply(_convert_weights_to_bf16)
+
+
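A quick sanity check of the dtype converters (the module here is illustrative); Module.apply visits the module itself as well as its children, so even a bare nn.Linear is converted:

    import torch
    from torch import nn
    from utils import convert_weights_to_bf16

    layer = nn.Linear(16, 16)
    convert_weights_to_bf16(layer)
    assert layer.weight.dtype == torch.bfloat16
    assert layer.bias.dtype == torch.bfloat16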
+ def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'):
+     import json
+     import jsonlines
+     print("Dump result")
+
+     # Make the temp dir for saving results
+     if not os.path.exists(result_dir):
+         if is_main_process():
+             os.makedirs(result_dir)
+         if is_dist_avail_and_initialized():
+             torch.distributed.barrier()
+
+     result_file = os.path.join(
+         result_dir, "%s_rank%d.json" % (filename, get_rank())
+     )
+
+     final_result_file = os.path.join(result_dir, f"{filename}.{save_format}")
+
+     json.dump(result, open(result_file, "w"))
+
+     if is_dist_avail_and_initialized():
+         torch.distributed.barrier()
+
+     if is_main_process():
+         # print("rank %d starts merging results." % get_rank())
+         # combine results from all processes
+         result = []
+
+         for rank in range(get_world_size()):
+             result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
+             res = json.load(open(result_file, "r"))
+             result += res
+
+         # print("Remove duplicate")
+         if remove_duplicate:
+             result_new = []
+             id_set = set()
+             for res in result:
+                 if res[remove_duplicate] not in id_set:
+                     id_set.add(res[remove_duplicate])
+                     result_new.append(res)
+             result = result_new
+
+         if save_format == 'json':
+             json.dump(result, open(final_result_file, "w"))
+         else:
+             assert save_format == 'jsonl', "Only support json and jsonl format"
+             with jsonlines.open(final_result_file, "w") as writer:
+                 writer.write_all(result)
+
+     # print("result file saved to %s" % final_result_file)
+
+     return final_result_file
+
+
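A sketch of the per-rank calling convention for save_result, with made-up record fields: every rank passes its own shard, and rank 0 merges all shards and deduplicates on the given key (the jsonlines package must be importable, and is used when save_format='jsonl'):

    from utils import save_result

    records = [{"question_id": 0, "answer": "yes"}]   # this rank's shard
    merged_path = save_result(records, result_dir="./results",
                              filename="eval", remove_duplicate="question_id")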
+ # resizing utils
+ # TODO: clean up later
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
+     h, w = input.shape[-2:]
+     factors = (h / size[0], w / size[1])
+
+     # First, we have to determine sigma
+     # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+     sigmas = (
+         max((factors[0] - 1.0) / 2.0, 0.001),
+         max((factors[1] - 1.0) / 2.0, 0.001),
+     )
+
+     # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+     # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+     # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+     ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+     # Make sure it is odd
+     if (ks[0] % 2) == 0:
+         ks = ks[0] + 1, ks[1]
+
+     if (ks[1] % 2) == 0:
+         ks = ks[0], ks[1] + 1
+
+     input = _gaussian_blur2d(input, ks, sigmas)
+
+     output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+     return output
+
+
+ def _compute_padding(kernel_size):
+     """Compute padding tuple."""
+     # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
+     # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+     if len(kernel_size) < 2:
+         raise AssertionError(kernel_size)
+     computed = [k - 1 for k in kernel_size]
+
+     # for even kernels we need to do asymmetric padding :(
+     out_padding = 2 * len(kernel_size) * [0]
+
+     for i in range(len(kernel_size)):
+         computed_tmp = computed[-(i + 1)]
+
+         pad_front = computed_tmp // 2
+         pad_rear = computed_tmp - pad_front
+
+         out_padding[2 * i + 0] = pad_front
+         out_padding[2 * i + 1] = pad_rear
+
+     return out_padding
+
+
+ def _filter2d(input, kernel):
+     # prepare kernel
+     b, c, h, w = input.shape
+     tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+     tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+     height, width = tmp_kernel.shape[-2:]
+
+     padding_shape: list[int] = _compute_padding([height, width])
+     input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+     # kernel and input tensor reshape to align element-wise or batch-wise params
+     tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+     input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+     # convolve the tensor with the kernel.
+     output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+     out = output.view(b, c, h, w)
+     return out
+
+
+ def _gaussian(window_size: int, sigma):
+     if isinstance(sigma, float):
+         sigma = torch.tensor([[sigma]])
+
+     batch_size = sigma.shape[0]
+
+     x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+     if window_size % 2 == 0:
+         x = x + 0.5
+
+     gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+     return gauss / gauss.sum(-1, keepdim=True)
+
+
+ def _gaussian_blur2d(input, kernel_size, sigma):
+     if isinstance(sigma, tuple):
+         sigma = torch.tensor([sigma], dtype=input.dtype)
+     else:
+         sigma = sigma.to(dtype=input.dtype)
+
+     ky, kx = int(kernel_size[0]), int(kernel_size[1])
+     bs = sigma.shape[0]
+     kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
+     kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
+     out_x = _filter2d(input, kernel_x[..., None, :])
+     out = _filter2d(out_x, kernel_y[..., None])
+
+     return out
+
+
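An illustrative call to the anti-aliased resize above on a random NCHW tensor; the function blurs with a sigma derived from the downscale factor, then interpolates:

    import torch
    from utils import _resize_with_antialiasing

    x = torch.rand(1, 3, 512, 512)
    y = _resize_with_antialiasing(x, (224, 224))
    assert y.shape == (1, 3, 224, 224)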
+ URL_MAP = {
+     "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+ }
+
+ CKPT_MAP = {
+     "vgg_lpips": "vgg.pth"
+ }
+
+ MD5_MAP = {
+     "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+ }
+
+
+ def download(url, local_path, chunk_size=1024):
+     os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+     with requests.get(url, stream=True) as r:
+         total_size = int(r.headers.get("content-length", 0))
+         with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+             with open(local_path, "wb") as f:
+                 for data in r.iter_content(chunk_size=chunk_size):
+                     if data:
+                         f.write(data)
+                         pbar.update(len(data))  # the final chunk may be smaller than chunk_size
+
+
+ def md5_hash(path):
+     with open(path, "rb") as f:
+         content = f.read()
+     return hashlib.md5(content).hexdigest()
+
+
+ def get_ckpt_path(name, root, check=False):
+     assert name in URL_MAP
+     path = os.path.join(root, CKPT_MAP[name])
+     if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+         print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+         download(URL_MAP[name], path)
+         md5 = md5_hash(path)
+         assert md5 == MD5_MAP[name], md5
+     return path
+
+
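A hedged example of fetching the LPIPS VGG weights with the helper above; the cache directory is arbitrary:

    from utils import get_ckpt_path

    # Downloads vgg.pth on first use; check=True re-verifies the MD5 of an existing file.
    vgg_path = get_ckpt_path("vgg_lpips", root="./cache", check=True)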
+ class KeyNotFoundError(Exception):
+     def __init__(self, cause, keys=None, visited=None):
+         self.cause = cause
+         self.keys = keys
+         self.visited = visited
+         messages = list()
+         if keys is not None:
+             messages.append("Key not found: {}".format(keys))
+         if visited is not None:
+             messages.append("Visited: {}".format(visited))
+         messages.append("Cause:\n{}".format(cause))
+         message = "\n".join(messages)
+         super().__init__(message)
+
+
+ def retrieve(
+     list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+ ):
+     """Given a nested list or dict return the desired value at key expanding
+     callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+     is done in-place.
+
+     Parameters
+     ----------
+     list_or_dict : list or dict
+         Possibly nested list or dictionary.
+     key : str
+         key/to/value, path-like string describing all keys necessary to
+         consider to get to the desired value. List indices can also be
+         passed here.
+     splitval : str
+         String that defines the delimiter between keys of the
+         different depth levels in `key`.
+     default : obj
+         Value returned if :attr:`key` is not found.
+     expand : bool
+         Whether to expand callable nodes on the path or not.
+
+     Returns
+     -------
+     The desired value or if :attr:`default` is not ``None`` and the
+     :attr:`key` is not found returns ``default``.
+
+     Raises
+     ------
+     Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+     ``None``.
+     """
+
+     keys = key.split(splitval)
+
+     success = True
+     try:
+         visited = []
+         parent = None
+         last_key = None
+         for key in keys:
+             if callable(list_or_dict):
+                 if not expand:
+                     raise KeyNotFoundError(
+                         ValueError(
+                             "Trying to get past callable node with expand=False."
+                         ),
+                         keys=keys,
+                         visited=visited,
+                     )
+                 list_or_dict = list_or_dict()
+                 parent[last_key] = list_or_dict
+
+             last_key = key
+             parent = list_or_dict
+
+             try:
+                 if isinstance(list_or_dict, dict):
+                     list_or_dict = list_or_dict[key]
+                 else:
+                     list_or_dict = list_or_dict[int(key)]
+             except (KeyError, IndexError, ValueError) as e:
+                 raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+             visited += [key]
+         # final expansion of retrieved value
+         if expand and callable(list_or_dict):
+             list_or_dict = list_or_dict()
+             parent[last_key] = list_or_dict
+     except KeyNotFoundError as e:
+         if default is None:
+             raise e
+         else:
+             list_or_dict = default
+             success = False
+
+     if not pass_success:
+         return list_or_dict
      else:
+         return list_or_dict, success
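Finally, a short sketch of retrieve on a nested config (the dict is invented for illustration): keys are slash-separated, list indices are plain integers in the path, and a non-None default suppresses KeyNotFoundError:

    from utils import retrieve

    cfg = {"model": {"params": {"lr": 1e-4}, "layers": [32, 64]}}
    retrieve(cfg, "model/params/lr")                  # -> 0.0001
    retrieve(cfg, "model/layers/1")                   # -> 64
    retrieve(cfg, "model/params/epochs", default=10)  # -> 10 (falls back to default)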