File size: 4,415 Bytes
bebbcd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
from pathlib import Path
from typing import Any, Dict
import safetensors.torch
import torch
import json
import shutil


def load_text_encoder(index_path: Path) -> Dict:
    """Load every tensor referenced by a sharded safetensors JSON index.

    The index's ``weight_map`` maps tensor names to shard file names, resolved
    relative to the index's own directory. Each distinct shard is loaded once
    onto the CPU, and all of its tensors are merged into one flat dict.

    Args:
        index_path: Path to the JSON index file.

    Returns:
        Dict mapping tensor name to tensor. It is empty when the index has no
        (or an empty) ``weight_map``.
    """
    with open(index_path, 'r', encoding='utf-8') as f:
        index: Dict = json.load(f)

    loaded_tensors: Dict[str, Any] = {}
    # Sort shard names so they are always read in the same order. Iterating a
    # set of strings depends on the per-process hash seed, which would make the
    # result (key order, and which shard wins on a duplicate name) vary by run.
    for part_file in sorted(set(index.get("weight_map", {}).values())):
        loaded_tensors.update(
            safetensors.torch.load_file(index_path.parent / part_file, device='cpu')
        )
    return loaded_tensors


def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Return the UNet state dict, optionally namespaced for a single-file checkpoint.

    When ``add_prefix`` is true, a new dict is built in which every key gains the
    ``model.diffusion_model.`` prefix. Otherwise the very same input dict is
    returned untouched.
    """
    if not add_prefix:
        return unet
    prefix = "model.diffusion_model."
    return {prefix + name: tensor for name, tensor in unet.items()}


def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Load the VAE weights, plus optional per-channel statistics, from a directory.

    Reads the mapping stored in ``autoencoder.pth`` inside ``vae_path``. If
    ``per_channel_statistics.json`` is also present, its table (a ``columns``
    list of names plus ``data`` rows) is transposed. Each column then becomes
    one tensor stored under ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory containing the VAE files.
        add_prefix: When true, prefix every key with ``vae.``.

    Returns:
        Flat dict with the weight entries first and the statistics entries
        after them. A statistics entry overrides a weight entry with the same key.
    """
    prefix = "vae." if add_prefix else ""
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines,
    # matching the CPU loading used for the text encoder shards.
    state_dict = torch.load(vae_path / "autoencoder.pth", map_location="cpu", weights_only=True)
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # zip(*rows) turns the row-major table into one tuple of values per column.
        columns = zip(*data["data"])
        result.update({
            f"{prefix}per_channel_statistics.{col}": torch.tensor(vals)
            for col, vals in zip(data["columns"], columns)
        })
    return result


def convert_encoder(encoder: Dict) -> Dict:
    """Return a copy of ``encoder`` with every key placed under the t5xxl transformer namespace.

    Each key is prefixed with ``text_encoders.t5xxl.transformer.``. Values are
    carried over as-is.
    """
    namespace = "text_encoders.t5xxl.transformer."
    renamed = {}
    for name, tensor in encoder.items():
        renamed[namespace + name] = tensor
    return renamed


def save_config(config_src: str, config_dst: str):
    """Copy the config file at ``config_src`` to ``config_dst``.

    Uses ``shutil.copy``, which copies file contents and permission bits and
    overwrites an existing destination file.
    """
    shutil.copy(src=config_src, dst=config_dst)


def load_vae_config(vae_path: Path) -> str:
    """Locate the VAE's ``config.json`` and return its path as a string.

    Args:
        vae_path: Directory expected to contain ``config.json``.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in ``vae_path``.
    """
    config_path = vae_path / "config.json"
    if config_path.exists():
        return str(config_path)
    raise FileNotFoundError(f"VAE config file {config_path} not found.")


def main(unet_path: str, vae_path: str, t5_path: str, out_path: str, mode: str,
         unet_config_path: str = None, scheduler_config_path: str = None) -> None:
    """Convert the UNet and VAE checkpoints into safetensors output.

    Args:
        unet_path: UNet checkpoint, read with ``torch.load``.
        vae_path: Directory holding ``autoencoder.pth``, optionally
            ``per_channel_statistics.json``, and ``config.json``. The
            ``config.json`` is only required in ``separate`` mode.
        t5_path: Accepted for CLI compatibility. The text encoder is not
            converted by this function.
        out_path: Output ``.safetensors`` file in ``single`` mode, or output
            root directory in ``separate`` mode.
        mode: ``'single'`` writes one file with prefixed keys. ``'separate'``
            writes unprefixed weights into ``unet/``, ``vae/`` and
            ``scheduler/`` sub-directories.
        unet_config_path: Optional file copied to ``unet/config.json``
            (separate mode only).
        scheduler_config_path: Optional file copied to
            ``scheduler/scheduler_config.json`` (separate mode only).

    Raises:
        ValueError: If ``mode`` is neither ``'single'`` nor ``'separate'``.
        FileNotFoundError: In ``separate`` mode, if the VAE ``config.json`` is missing.
    """
    # Reject bad modes before doing any expensive checkpoint loading.
    if mode not in ('single', 'separate'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'single' or 'separate'.")
    single = mode == 'single'

    # The VAE config is only needed (and only required) in separate mode. Look
    # it up first so a missing file fails fast, before the weights are loaded.
    vae_config_path = None if single else load_vae_config(Path(vae_path))

    # map_location='cpu' lets GPU-saved checkpoints be converted on CPU-only hosts.
    unet = convert_unet(torch.load(unet_path, map_location='cpu', weights_only=True),
                        add_prefix=single)
    vae = convert_vae(Path(vae_path), add_prefix=single)

    if single:
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
        safetensors.torch.save_file({**unet, **vae}, out_path)
        return

    # Separate mode: one sub-directory per component.
    out_root = Path(out_path)
    unet_dir = out_root / 'unet'
    vae_dir = out_root / 'vae'
    scheduler_dir = out_root / 'scheduler'
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    safetensors.torch.save_file(unet, unet_dir / 'diffusion_pytorch_model.safetensors')
    safetensors.torch.save_file(vae, vae_dir / 'diffusion_pytorch_model.safetensors')

    if unet_config_path:
        save_config(unet_config_path, unet_dir / 'config.json')
    save_config(vae_config_path, vae_dir / 'config.json')
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / 'scheduler_config.json')


if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    # Input checkpoint locations.
    cli.add_argument('--unet_path', '-u', type=str, default='unet/ema-002.pt')
    cli.add_argument('--vae_path', '-v', type=str, default='vae/')
    cli.add_argument('--t5_path', '-t', type=str, default='t5/PixArt-XL-2-1024-MS/')
    # Output location and layout.
    cli.add_argument('--out_path', '-o', type=str, default='xora.safetensors')
    cli.add_argument(
        '--mode', '-m', type=str, choices=['single', 'separate'], default='single',
        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
    )
    # Extra config files; main() only reads these in separate mode.
    cli.add_argument('--unet_config_path', type=str,
                     help="Path to the UNet config file (for separate mode)")
    cli.add_argument('--scheduler_config_path', type=str,
                     help="Path to the Scheduler config file (for separate mode)")

    main(**vars(cli.parse_args()))