xinyuanc91 commited on
Commit
f550c43
β€’
1 Parent(s): e13c46d

Delete download.py

Browse files
Files changed (1) hide show
  1. download.py +0 -44
download.py DELETED
@@ -1,44 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Functions for downloading pre-trained DiT models
9
- """
10
- from torchvision.datasets.utils import download_url
11
- import torch
12
- import os
13
-
14
-
15
-
16
- def find_model(model_name):
17
-
18
- checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
19
-
20
- if "ema" in checkpoint: # supports checkpoints from train.py
21
- print('Ema existing!')
22
- checkpoint = checkpoint["ema"]
23
- return checkpoint
24
-
25
-
26
- def download_model(model_name):
27
- """
28
- Downloads a pre-trained DiT model from the web.
29
- """
30
- assert model_name in pretrained_models
31
- local_path = f'pretrained_models/{model_name}'
32
- if not os.path.isfile(local_path):
33
- os.makedirs('pretrained_models', exist_ok=True)
34
- web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
35
- download_url(web_path, 'pretrained_models')
36
- model = torch.load(local_path, map_location=lambda storage, loc: storage)
37
- return model
38
-
39
-
40
- if __name__ == "__main__":
41
- # Download all DiT checkpoints
42
- for model in pretrained_models:
43
- download_model(model)
44
- print('Done.')