# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch
from peft import PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          CLIPImageProcessor, CLIPVisionModel)

from xtuner.model.utils import LoadWoInit


def parse_args():
    """Parse the command-line arguments of the merge script.

    Returns:
        argparse.Namespace: Namespace with the attributes
            ``model_name_or_path``, ``adapter_name_or_path``, ``save_dir``,
            ``max_shard_size``, ``is_clip``, ``safe_serialization`` and
            ``device``.
    """
    parser = argparse.ArgumentParser(
        description='Merge a HuggingFace adapter to base model')
    parser.add_argument('model_name_or_path', help='model name or path')
    parser.add_argument('adapter_name_or_path', help='adapter name or path')
    parser.add_argument(
        'save_dir', help='the directory to save the merged model')
    parser.add_argument(
        '--max-shard-size',
        type=str,
        default='2GB',
        help='Only applicable for LLM. The maximum size for '
        'each sharded checkpoint.')
    parser.add_argument(
        '--is-clip',
        action='store_true',
        help='Indicate if the model is a clip model')
    parser.add_argument(
        '--safe-serialization',
        action='store_true',
        help='Indicate if using `safe_serialization`')
    parser.add_argument(
        '--device',
        default='cuda',
        choices=('cuda', 'cpu', 'auto'),
        help='Indicate the device')
    args = parser.parse_args()
    return args


def _load_base_model_and_processor(args):
    """Load the base model and its companion processor.

    Args:
        args (argparse.Namespace): Parsed arguments; ``is_clip``,
            ``model_name_or_path`` and ``device`` are read.

    Returns:
        tuple: ``(model, processor)``. For ``--is-clip`` this is a
            ``CLIPVisionModel`` with a ``CLIPImageProcessor``; otherwise a
            causal LM loaded in float16 with its tokenizer.
    """
    # NOTE(review): LoadWoInit is a project helper; presumably it skips
    # weight initialization while the checkpoint is loaded -- confirm
    # against xtuner.model.utils.
    if args.is_clip:
        with LoadWoInit():
            model = CLIPVisionModel.from_pretrained(
                args.model_name_or_path, device_map=args.device)
        processor = CLIPImageProcessor.from_pretrained(
            args.model_name_or_path)
    else:
        with LoadWoInit():
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name_or_path,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                device_map=args.device,
                trust_remote_code=True)
        processor = AutoTokenizer.from_pretrained(
            args.model_name_or_path, trust_remote_code=True)
    return model, processor


def main():
    """Merge an adapter into its base model and save the result.

    Loads the base model (CLIP vision model or causal LM, depending on
    ``--is-clip``), wraps it with the adapter via ``PeftModel``, merges the
    adapter with ``merge_and_unload`` and writes both the merged model and
    the processor/tokenizer to ``save_dir``.
    """
    args = parse_args()
    model, processor = _load_base_model_and_processor(args)

    # The adapter is only loaded to be merged, hence ``is_trainable=False``.
    model_unmerged = PeftModel.from_pretrained(
        model,
        args.adapter_name_or_path,
        device_map=args.device,
        is_trainable=False,
        trust_remote_code=True)
    model_merged = model_unmerged.merge_and_unload()

    print(f'Saving to {args.save_dir}...')
    model_merged.save_pretrained(
        args.save_dir,
        safe_serialization=args.safe_serialization,
        max_shard_size=args.max_shard_size)
    # Save the processor/tokenizer into the same directory as the model.
    processor.save_pretrained(args.save_dir)
    print('All done!')


if __name__ == '__main__':
    main()