Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
src_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2.pt" | |
ref_ckpt = "/checkpoint/wnhsu/w2v/hubert_icassp_oss_v3/iter2_km100-400k-grp-L6/oss.km500_p0_1_s334.pmw1_0.puw0_0.grpnorm.ml10.mp0_8.untie.mxsz250000.ufreq1.maxtok1400000.MU100k.s1337.ngpu32/checkpoint_last.pt" | |
new_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2_updated.pt" | |
def update_state(state): | |
state["model"]["label_embs_concat"] = state["model"].pop("label_embs") | |
state["args"].task = "hubert_pretraining" | |
state["args"].labels = f"['{state['args'].labels}']" | |
return state | |
src_state = torch.load(src_ckpt) | |
src_state = update_state(src_state) | |
torch.save(src_state, new_ckpt) | |