|
import argparse |
|
import torch |
|
from transformers import ( |
|
AutoProcessor, |
|
LlavaForConditionalGeneration, |
|
) |
|
|
|
|
|
def preprocess_text_encoder_tokenizer(args):
    """Extract the text encoder and tokenizer from a LLaVA checkpoint.

    Loads the full ``LlavaForConditionalGeneration`` model from
    ``args.input_dir``, then saves only its language-model sub-module and the
    processor's tokenizer to ``args.output_dir`` so they can be used as a
    standalone text encoder.

    Args:
        args: Parsed CLI namespace with ``input_dir`` (path to the LLaVA
            checkpoint) and ``output_dir`` (destination directory) attributes.
    """
    processor = AutoProcessor.from_pretrained(args.input_dir)

    # Only move the model to a GPU when one is available: the previous
    # unconditional ``.to(0)`` crashed on CPU-only machines, and device
    # placement is irrelevant when we merely re-save the weights.
    device = 0 if torch.cuda.is_available() else "cpu"
    model = LlavaForConditionalGeneration.from_pretrained(
        args.input_dir,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to(device)

    model.language_model.save_pretrained(args.output_dir)
    processor.tokenizer.save_pretrained(args.output_dir)
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="The path to the llava-llama-3-8b-v1_1-transformers.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="The output path of the llava-llama-3-8b-text-encoder-tokenizer. "
        "If '', the parent dir of output will be the same as input dir.",
    )
    args = parser.parse_args()

    # Default the output directory to the parent of the input directory.
    # ``os.path.dirname`` is portable across path separators, and stripping a
    # trailing "/" first avoids the old bug where "path/to/model/" resolved
    # to itself instead of its parent.
    if not args.output_dir:
        args.output_dir = os.path.dirname(args.input_dir.rstrip("/"))

    preprocess_text_encoder_tokenizer(args)
|
|