Spaces:
Starting
on
A10G
Starting
on
A10G
# coding=utf-8 | |
# Copyright 2023 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" Conversion script for the LDM checkpoints. """ | |
import argparse | |
import json | |
import os | |
import torch | |
from transformers.file_utils import has_file | |
from diffusers import UNet2DConditionModel, UNet2DModel | |
# Migration switches: edit these to choose which conversion step(s) run.
do_only_config = False
do_only_weights = True
do_only_renaming = False

# Legacy config key -> current diffusers config key.
config_parameters_to_change = {
    "image_size": "sample_size",
    "num_res_blocks": "layers_per_block",
    "block_channels": "block_out_channels",
    "down_blocks": "down_block_types",
    "up_blocks": "up_block_types",
    "downscale_freq_shift": "freq_shift",
    "resnet_num_groups": "norm_num_groups",
    "resnet_act_fn": "act_fn",
    "resnet_eps": "norm_eps",
    "num_head_channels": "attention_head_dim",
}

# Legacy first component of a state-dict key -> current module name.
key_parameters_to_change = {
    "time_steps": "time_proj",
    "mid": "mid_block",
    "downsample_blocks": "down_blocks",
    "upsample_blocks": "up_blocks",
}


def rename_config_keys(config, renames):
    """Return a copy of *config* with legacy keys renamed and block type names cleaned.

    Args:
        config: Model config mapping. It must contain ``down_block_types`` and
            ``up_block_types`` after renaming, or a ``KeyError`` is raised.
        renames: Mapping of old key -> new key. Old keys absent from *config*
            are ignored.

    Returns:
        A new dict; *config* itself is not modified. Every occurrence of the
        substring ``"UNetRes"`` is removed from each entry of
        ``down_block_types`` and ``up_block_types``.
    """
    renamed = dict(config)
    for old_key, new_key in renames.items():
        if old_key in renamed:
            renamed[new_key] = renamed.pop(old_key)
    for field in ("down_block_types", "up_block_types"):
        renamed[field] = [name.replace("UNetRes", "") for name in renamed[field]]
    return renamed


def rename_state_dict_keys(state_dict, renames):
    """Return a new state dict with legacy module prefixes renamed.

    Entries whose key ends in ``.op.bias`` or ``.op.weight`` are dropped.
    For every other key, if its first dot-separated component is a key of
    *renames*, that component is replaced by the mapped value. All remaining
    keys are copied unchanged. Values are passed through untouched.

    Args:
        state_dict: Mapping of parameter name -> value.
        renames: Mapping of old first component -> new first component.

    Returns:
        A new dict. *state_dict* is not modified.
    """
    converted = {}
    for param_key, param_value in state_dict.items():
        if param_key.endswith((".op.bias", ".op.weight")):
            continue
        prefix, sep, rest = param_key.partition(".")
        if prefix in renames:
            param_key = renames[prefix] + sep + rest
        converted[param_key] = param_value
    return converted


def main():
    """Convert one local UNet repo according to the ``do_only_*`` flags above."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo_path",
        default=None,
        type=str,
        required=True,
        help="Path to the model repository to convert (config.json at its root or under unet/).",
    )
    parser.add_argument(
        "--dump_path",
        default=None,
        type=str,
        required=False,
        help="Path to the output model. Defaults to overwriting --repo_path in place.",
    )
    args = parser.parse_args()

    # If config.json is not at the repo root, the UNet is expected under "unet/".
    is_flat_repo = has_file(args.repo_path, "config.json")
    subfolder = "" if is_flat_repo else "unet"
    model_dir = os.path.join(args.repo_path, subfolder)
    # Previously --dump_path was required yet ignored; honor it, keeping the subfolder layout.
    output_root = args.repo_path if args.dump_path is None else args.dump_path
    output_dir = os.path.join(output_root, subfolder)

    with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as reader:
        config = json.load(reader)

    if do_only_config:
        # Drop the legacy keys before instantiating the model.
        for key in config_parameters_to_change:
            config.pop(key, None)

    # Only the "ldm-text2im-large-256" repo (with a unet/ subfolder) uses the conditional UNet.
    use_condition_model = not is_flat_repo and "ldm-text2im-large-256" in args.repo_path
    model_cls = UNet2DConditionModel if use_condition_model else UNet2DModel
    model = model_cls(**config)

    if do_only_config:
        model.save_config(output_dir)

    config = dict(model.config)

    if do_only_renaming:
        # NOTE(review): the renamed config is computed but never written anywhere,
        # so this step currently has no persistent effect -- confirm intent.
        config = rename_config_keys(config, config_parameters_to_change)

    if do_only_weights:
        # SECURITY: torch.load unpickles arbitrary objects; only run on trusted checkpoints.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts; the
        # freshly constructed model is not moved to any device in this script.
        state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
        model.load_state_dict(rename_state_dict_keys(state_dict, key_parameters_to_change))
        model.save_pretrained(output_dir)


if __name__ == "__main__":
    main()