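"""
Converts the original Lumina-Next checkpoints released by Alpha-VLLM
(https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT, https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I)
into the diffusers LuminaNextDiT2DModel / LuminaText2ImgPipeline format.

Example invocation (script name and paths are illustrative):

    python convert_lumina_to_diffusers.py \
        --origin_ckpt_path /path/to/lumina_next.safetensors \
        --dump_path ./lumina-next-diffusers \
        --only_transformer False
"""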
import argparse
import os

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline


def main(args):
    # checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
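    # The checkpoint can also be fetched programmatically with huggingface_hub
    # (the filename below is illustrative; check the model repo for the actual file):
    #   from huggingface_hub import hf_hub_download
    #   ckpt_path = hf_hub_download("Alpha-VLLM/Lumina-Next-SFT", "consolidated.00-of-01.safetensors")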
    all_sd = load_file(args.origin_ckpt_path, device="cpu")
    converted_state_dict = {}
    # pad token
    converted_state_dict["pad_token"] = all_sd["pad_token"]

    # patch embed
    converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
    converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]

    # time and caption embed
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]

    for i in range(24):
        # adaln
        converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]

        # qkv
        converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]

        # caption cross-attention (the query projection and q_norm are shared with attn1)
        converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]

        # output projection (the original wo weight maps to attn2.to_out)
        converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]

        # qk norm (attn2 reuses q_norm for queries; caption keys use ky_norm)
        converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]

        converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]

        converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]

        converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]

        # attention norm
        converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
        converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
        converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]

        # feed forward
        converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]

        # feed forward norm
        converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
        converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]

    # final layer
    converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
    converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]

    converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
    converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]

    # Lumina-Next-SFT 2B
    transformer = LuminaNextDiT2DModel(
        sample_size=128,
        patch_size=2,
        in_channels=4,
        hidden_size=2304,
        num_layers=24,
        num_attention_heads=32,
        num_kv_heads=8,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        learn_sigma=True,
        qk_norm=True,
        cross_attention_dim=2048,
        scaling_factor=1.0,
    )
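    # strict=True makes load_state_dict raise if any key is missing or unexpected,
    # which verifies that the mapping above matches the diffusers model definition.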
    transformer.load_state_dict(converted_state_dict, strict=True)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        scheduler = FlowMatchEulerDiscreteScheduler()

        vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
        text_encoder = AutoModel.from_pretrained("google/gemma-2b")
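        # NOTE: "google/gemma-2b" is a gated model on the Hugging Face Hub; downloading it
        # requires accepting the license and authenticating (e.g. `huggingface-cli login`).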

        pipeline = LuminaText2ImgPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )
        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[256, 512, 1024],
        required=False,
        help="Image size of pretrained model, either 512 or 1024.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument("--only_transformer", default=True, type=bool, required=True)

    args = parser.parse_args()
    main(args)