# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Chameleon License Agreement.

import hashlib
import subprocess
import sys
from pathlib import Path


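def download_file(url: str, output_path: Path) -> None:
    """Download `url` to `output_path` with wget, resuming a partial download if one exists."""
    print(f"Downloading {output_path}")
    subprocess.check_call(["wget", "--continue", url, "-O", str(output_path)])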


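def validate_checksum(folder: Path) -> None:
    """Check every file listed in checklist.chk against its expected MD5 digest."""
    # checklist.chk holds whitespace-separated "<md5> <filename>" pairs.
    chks_parts = (folder / "checklist.chk").read_text().split()
    for expected_checksum, file in zip(chks_parts[::2], chks_parts[1::2]):
        file_path = folder / file
        # Hash in 1 MiB chunks so large checkpoints are not read into memory at once.
        md5 = hashlib.md5()
        with file_path.open("rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                md5.update(chunk)
        if md5.hexdigest() != expected_checksum:
            print(f"Checksum mismatch for {file_path}")
            sys.exit(1)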


def download_tokenizer(presigned_url: str, target_folder: Path):
    tokenizer_folder = target_folder / "tokenizer"
    tokenizer_folder.mkdir(parents=True, exist_ok=True)

    for filename in [
        "text_tokenizer.json",
        "vqgan.ckpt",
        "vqgan.yaml",
        "checklist.chk",
    ]:
        download_file(
            presigned_url.replace("*", f"tokenizer/{filename}"),
            tokenizer_folder / filename,
        )

    validate_checksum(tokenizer_folder)


def download_model(presigned_url: str, target_folder: Path, model: str):
    download_filenames = ["params.json", "consolidate_params.json", "checklist.chk"]

    if model == "7b":
        download_filenames += ["consolidated.pth"]
    elif model == "30b":
        # Four shards: consolidated.00.pth through consolidated.03.pth.
        download_filenames += [f"consolidated.{i:02}.pth" for i in range(4)]
    else:
        print(f"Unknown model: {model}")
        sys.exit(1)

    # Create the folder only after the model name has been validated.
    model_folder = target_folder / "models" / model
    model_folder.mkdir(parents=True, exist_ok=True)

    for filename in download_filenames:
        download_file(
            presigned_url.replace("*", f"{model}/{filename}"),
            model_folder / filename,
        )

    validate_checksum(model_folder)


def main():
    presigned_url = (
        sys.argv[1] if len(sys.argv) > 1 else input("Enter the URL from email: ")
    )

    target_folder = Path("./data")
    target_folder.mkdir(parents=True, exist_ok=True)

    download_tokenizer(presigned_url, target_folder)

    model_size = input(
        "Enter a comma-separated list of models to download (7B,30B), or press Enter for all: "
    ).strip()
    if not model_size:
        model_size = "7B,30B"

    for model in model_size.split(","):
        model = model.strip().lower()
        if not model:
            continue  # tolerate stray commas such as "7B,"
        download_model(presigned_url, target_folder, model)


if __name__ == "__main__":
    main()