# LTX-Video-Playground/scripts/to_safetensors.py
# Last commit by Sapir: "Examples: update and fix scripts." (e46ff5e)
# File size: 4.31 kB
import argparse
from pathlib import Path
from typing import Any, Dict
import safetensors.torch
import torch
import json
import shutil
def load_text_encoder(index_path: Path) -> Dict:
    """Load every shard referenced by a sharded-checkpoint index file.

    Args:
        index_path: Path to a JSON index whose optional ``"weight_map"`` entry
            maps tensor names to shard file names. Shard files are resolved
            relative to the directory containing the index.

    Returns:
        A single dict merging the tensors of all referenced shards, loaded
        on the CPU. Empty when the index has no ``"weight_map"``.
    """
    with open(index_path, 'r') as index_file:
        shard_index: Dict = json.load(index_file)

    # Many tensors share one shard, so de-duplicate the file names first.
    shard_names = set(shard_index.get("weight_map", {}).values())

    merged = {}
    for shard_name in shard_names:
        shard = safetensors.torch.load_file(index_path.parent / shard_name, device='cpu')
        merged.update(shard)
    return merged
def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Optionally namespace UNet keys under ``model.diffusion_model.``.

    Args:
        unet: Mapping of parameter names to values.
        add_prefix: When false, *unet* is returned unchanged (same object).

    Returns:
        A new dict with prefixed keys, or *unet* itself when no prefix is
        requested.
    """
    if not add_prefix:
        return unet

    prefix = "model.diffusion_model."
    return {prefix + name: tensor for name, tensor in unet.items()}
def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Build a flat VAE state dict from a checkpoint directory.

    Reads ``autoencoder.pth`` from *vae_path* and, when the file
    ``per_channel_statistics.json`` is present, adds one tensor per column
    of its table under the name ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory containing ``autoencoder.pth`` and, optionally,
            ``per_channel_statistics.json`` (a JSON object with ``"columns"``
            and row-major ``"data"`` entries).
        add_prefix: When true, every key (weights and statistics alike) is
            prefixed with ``"vae."``.

    Returns:
        Mapping from tensor name to tensor; statistics entries come after
        the checkpoint weights and override them on a name clash.
    """
    prefix = "vae." if add_prefix else ""

    # map_location="cpu" lets a checkpoint saved from a GPU be converted on
    # a machine without one.
    state_dict = torch.load(vae_path / "autoencoder.pth", map_location="cpu", weights_only=True)
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # "data" is stored row by row; transpose it so that each column's
        # values end up together in a single tensor.
        for col, vals in zip(data["columns"], zip(*data["data"])):
            result[f"{prefix}per_channel_statistics.{col}"] = torch.tensor(vals)

    return result
def convert_encoder(encoder: Dict) -> Dict:
    """Return a copy of *encoder* with every key namespaced for the T5-XXL text encoder.

    Each key is prefixed with ``text_encoders.t5xxl.transformer.``; values
    are carried over untouched.
    """
    prefix = "text_encoders.t5xxl.transformer."
    renamed = {}
    for name, tensor in encoder.items():
        renamed[prefix + name] = tensor
    return renamed
def save_config(config_src: str, config_dst: str):
    """Copy the config file at *config_src* to *config_dst* via ``shutil.copy``."""
    shutil.copy(src=config_src, dst=config_dst)
def load_vae_config(vae_path: Path) -> str:
    """Locate ``config.json`` inside the VAE directory.

    Args:
        vae_path: Directory expected to contain ``config.json``.

    Returns:
        The path of the config file, as a string.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in *vae_path*.
    """
    config_path = vae_path / "config.json"
    if config_path.exists():
        return str(config_path)
    raise FileNotFoundError(f"VAE config file {config_path} not found.")
def main(unet_path: str, vae_path: str, out_path: str, mode: str,
         unet_config_path: str = None, scheduler_config_path: str = None) -> None:
    """Convert a UNet checkpoint and a VAE directory to safetensors.

    Args:
        unet_path: Path to the UNet checkpoint readable by ``torch.load``.
        vae_path: Directory holding the VAE checkpoint (and, for
            ``'separate'`` mode, its ``config.json``).
        out_path: Output file (``'single'`` mode) or output directory
            (``'separate'`` mode).
        mode: ``'single'`` writes one safetensors file with prefixed UNet and
            VAE keys; ``'separate'`` writes ``unet/``, ``vae/`` and
            ``scheduler/`` sub-directories under *out_path*.
        unet_config_path: Optional UNet config to copy to
            ``unet/config.json`` (``'separate'`` mode only).
        scheduler_config_path: Optional scheduler config to copy to
            ``scheduler/scheduler_config.json`` (``'separate'`` mode only).

    Raises:
        ValueError: If *mode* is neither ``'single'`` nor ``'separate'``.
        FileNotFoundError: In ``'separate'`` mode, if the VAE directory has
            no ``config.json``.
    """
    # Reject an unknown mode before the expensive loads, rather than loading
    # everything and then silently writing nothing.
    if mode not in ('single', 'separate'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'single' or 'separate'.")

    single = mode == 'single'

    # map_location='cpu' lets a checkpoint saved from a GPU be converted on
    # a machine without one.
    unet = convert_unet(torch.load(unet_path, map_location='cpu', weights_only=True), add_prefix=single)
    vae = convert_vae(Path(vae_path), add_prefix=single)

    if single:
        # The VAE config is not part of the single-file output, so it is not
        # required here.
        safetensors.torch.save_file({**unet, **vae}, out_path)
        return

    # Fail on a missing VAE config before anything is written to disk.
    vae_config_path = load_vae_config(Path(vae_path))

    out_dir = Path(out_path)
    unet_dir = out_dir / 'unet'
    vae_dir = out_dir / 'vae'
    scheduler_dir = out_dir / 'scheduler'
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Both weight files use the name diffusion_pytorch_model.safetensors.
    safetensors.torch.save_file(unet, unet_dir / 'diffusion_pytorch_model.safetensors')
    safetensors.torch.save_file(vae, vae_dir / 'diffusion_pytorch_model.safetensors')

    # UNet and scheduler configs are optional; the VAE config always exists
    # at this point because load_vae_config raised otherwise.
    if unet_config_path:
        save_config(unet_config_path, unet_dir / 'config.json')
    save_config(vae_config_path, vae_dir / 'config.json')
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / 'scheduler_config.json')
if __name__ == '__main__':
    # (flags, keyword options) for each CLI argument, in help-output order.
    option_specs = [
        (('--unet_path', '-u'), {'type': str, 'default': 'unet/ema-002.pt'}),
        (('--vae_path', '-v'), {'type': str, 'default': 'vae/'}),
        (('--out_path', '-o'), {'type': str, 'default': 'xora.safetensors'}),
        (('--mode', '-m'), {
            'type': str,
            'choices': ['single', 'separate'],
            'default': 'single',
            'help': "Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
        }),
        (('--unet_config_path',), {
            'type': str,
            'help': "Path to the UNet config file (for separate mode)",
        }),
        (('--scheduler_config_path',), {
            'type': str,
            'help': "Path to the Scheduler config file (for separate mode)",
        }),
    ]

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    # The parsed attribute names match main()'s parameter names one-to-one.
    main(**vars(cli.parse_args()))