# Third-party dependencies of the conversion script.
import safetensors.torch
import torch
from transformers import T5EncoderModel, T5Tokenizer

# Checkpoint read with torch.load; its "ema" entry supplies the tensors that
# are written under the "model." prefix.
input_diffusion = "mp_rank_00_model_states.pt"

# Text-encoder checkpoint read with torch.load; every key that does not start
# with "visual." is written under "text_encoders.hydit_clip.transformer.".
input_bert = "pytorch_model.bin"

# VAE weights in safetensors format; written under the "vae." prefix.
input_vae = "sdxl_vae.safetensors"

# Destination of the single merged safetensors file.
output = "freeway_animation_demo_hunyuan_dit.safetensors"

# Load the mT5-XL encoder and its tokenizer by model id.
mt5 = T5EncoderModel.from_pretrained("google/mt5-xl")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-xl")

# Pack the tokenizer's serialized SentencePiece model into a uint8 tensor so
# it can be stored in the same safetensors file as the weights.
sp_model = torch.ByteTensor(list(tokenizer.sp_model.serialized_model_proto()))
t5_sd = mt5.state_dict()

# Merged state dict; the later sections of this script add their own
# prefixed entries to it before it is saved.
out_sd = {}
out_sd["text_encoders.mt5xl.spiece_model"] = sp_model

# Copy every encoder tensor, cast to float16, under the mt5xl prefix.
for name, tensor in t5_sd.items():
    out_sd[f"text_encoders.mt5xl.transformer.{name}"] = tensor.half()

# Load the text-encoder checkpoint. map_location="cpu" lets a checkpoint that
# was saved from a GPU load on any machine; weights_only=True restricts
# unpickling to tensors and plain containers.
bert_sd = torch.load(input_bert, map_location="cpu", weights_only=True)

# Keep everything except the "visual." entries, cast to float16.
# NOTE(review): .half() is applied to every kept tensor regardless of dtype —
# confirm the checkpoint holds no integer buffers that should stay integral.
for name, tensor in bert_sd.items():
    if name.startswith("visual."):
        continue
    out_sd[f"text_encoders.hydit_clip.transformer.{name}"] = tensor.half()

# Drop the references to the source model and state dicts before the next
# checkpoint is loaded (presumably to keep peak memory down).
del bert_sd, mt5, t5_sd

# NOTE(security): weights_only=False performs a full pickle load, which can
# execute arbitrary code embedded in the file. Only run this on a checkpoint
# from a trusted source.
# map_location="cpu" keeps the load independent of the device the checkpoint
# was saved from. Only the "ema" entry of the checkpoint is used.
hydit = torch.load(input_diffusion, map_location="cpu", weights_only=False)["ema"]

# Copy the diffusion weights, cast to float16, under the "model." prefix.
for name, tensor in hydit.items():
    out_sd[f"model.{name}"] = tensor.half()

# Load the VAE weights and copy them, cast to float16, under the "vae." prefix.
vae_sd = safetensors.torch.load_file(input_vae)
for name, tensor in vae_sd.items():
    out_sd[f"vae.{name}"] = tensor.half()

# Write the merged state dict to the output path as one safetensors file.
safetensors.torch.save_file(out_sd, output)