import argparse

import torch

from transformers import CognitivessConfig, CognitivessForCausalLM


def convert_cognitivess_checkpoint_to_hf(model_dir, save_dir):
    config = CognitivessConfig.from_pretrained(model_dir)
    model = CognitivessForCausalLM(config)

    # Load the model weights from the Cognitivess checkpoint
    state_dict = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu")
    model.load_state_dict(state_dict)

    # Save the model in Hugging Face format
    model.save_pretrained(save_dir)
    config.save_pretrained(save_dir)

    print(f"Model converted and saved to {save_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, required=True, help="Path to the Cognitivess model directory")
    parser.add_argument("--save_dir", type=str, required=True, help="Path to the directory to save the converted model")
    args = parser.parse_args()

    convert_cognitivess_checkpoint_to_hf(args.model_dir, args.save_dir)
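
# A minimal usage sketch, assuming this script is saved as
# convert_cognitivess_checkpoint_to_hf.py and that --model_dir points at a
# directory containing both the Cognitivess config and a pytorch_model.bin
# checkpoint. The file name and paths below are illustrative placeholders,
# not taken from the source:
#
#   python convert_cognitivess_checkpoint_to_hf.py \
#       --model_dir /path/to/cognitivess_checkpoint \
#       --save_dir /path/to/hf_output
#
# The converted model can then be reloaded with
# CognitivessForCausalLM.from_pretrained("/path/to/hf_output").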