import os
import sys
import json
import requests
from tqdm import tqdm

def download_file(url: str, filename: str, download_dir: str):
    """Download url to download_dir/filename unless a file of matching size already exists."""

    try:
        filepath = os.path.join(download_dir, filename)
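        # requests.head() does not follow redirects by default, so enable them
        # to read the size of the final resource rather than of the redirect.
        head = requests.head(url, allow_redirects=True, timeout=30)
        content_length = int(head.headers.get("content-length", 0))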

        # If file already exists and size matches, skip download
        if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length:
            print(f"{filepath} already exists. Skipping download.")
            return
        if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length:
            print(f"{filepath} already exists but size does not match. Redownloading.")
        else:
            print(f"Downloading {filename} from {url}")

        # Start download, stream=True allows for progress tracking
        response = requests.get(url, stream=True, timeout=30)

        # Check if request was successful
        response.raise_for_status()

        # Create progress bar
        total_size = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,  # scale by 1024 so sizes are shown in binary units
            ncols=70,
            file=sys.stdout
        )

        # Write response content to file
        with open(filepath, 'wb') as f:
            for data in response.iter_content(chunk_size=8192):
                f.write(data)
                progress_bar.update(len(data))  # Update progress bar

        # Close progress bar
        progress_bar.close()

        # Fail with a descriptive message if fewer bytes arrived than advertised
        if total_size != 0 and progress_bar.n != total_size:
            raise IOError(
                f"Incomplete download of {filename}: "
                f"got {progress_bar.n} of {total_size} bytes"
            )

    except Exception as e:
        print(f"An error occurred: {e}")

def main():
    """Main function to download files from URLs in a config file."""
    
    # Get JSON config file path
    script_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(script_dir, "download_models.json")

    # Set download directory
    download_dir = "checkpoints"
    os.makedirs(download_dir, exist_ok=True)

    # Load the {url: filename} mapping from JSON
    with open(config_file_path, "r") as f:
        config = json.load(f)

    # Download each file specified in config
    for url, filename in config.items():
        download_file(url, filename, download_dir)


if __name__ == "__main__":
    main()