import argparse

import torch
from transformers import (
    AutoProcessor,
    LlavaForConditionalGeneration,
)


def preprocess_text_encoder_tokenizer(args):
    # Load the full LLaVA model and its processor from the input directory.
    processor = AutoProcessor.from_pretrained(args.input_dir)
    model = LlavaForConditionalGeneration.from_pretrained(
        args.input_dir,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to(0)

    # Save only the language-model weights and the tokenizer; the vision
    # tower is not needed for the text-encoder/tokenizer output.
    model.language_model.save_pretrained(args.output_dir)
    processor.tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="The path to the llava-llama-3-8b-v1_1-transformers checkpoint.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="The output path of the llava-llama-3-8b-text-encoder-tokenizer. "
        "If '', the parent dir of the output will be the same as the input dir.",
    )
    args = parser.parse_args()

    # Default the output directory to the parent directory of input_dir.
    if len(args.output_dir) == 0:
        args.output_dir = "/".join(args.input_dir.split("/")[:-1])
    preprocess_text_encoder_tokenizer(args)
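
Usage sketch (the script filename and checkpoint paths below are assumptions for illustration, not stated in the file itself):

    python preprocess_text_encoder_tokenizer.py \
        --input_dir ./ckpts/llava-llama-3-8b-v1_1-transformers \
        --output_dir ./ckpts/text_encoder

If --output_dir is left empty, the extracted language model and tokenizer are written to the parent directory of --input_dir.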