Spaces:
Running
Running
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os | |
from types import MethodType | |
from typing import TYPE_CHECKING, Any, Dict | |
import torch | |
from peft import PeftModel | |
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available | |
from transformers.integrations import is_deepspeed_zero3_enabled | |
from transformers.modeling_utils import is_fsdp_enabled | |
from ..extras.logging import get_logger | |
from ..extras.misc import infer_optim_dtype | |
from .model_utils.attention import configure_attn_implementation, print_attn_implementation | |
from .model_utils.checkpointing import prepare_model_for_training | |
from .model_utils.embedding import resize_embedding_layer | |
from .model_utils.longlora import configure_longlora | |
from .model_utils.moe import add_z3_leaf_module, configure_moe | |
from .model_utils.quantization import configure_quantization | |
from .model_utils.rope import configure_rope | |
from .model_utils.valuehead import prepare_valuehead_model | |
from .model_utils.visual import autocast_projector_dtype, configure_visual_model | |
if TYPE_CHECKING: | |
from transformers import PretrainedConfig, PreTrainedTokenizer | |
from trl import AutoModelForCausalLMWithValueHead | |
from ..hparams import ModelArguments | |
# Module-level logger shared by the patch functions in this file
# (e.g. the KV-cache info message and the model-tagging warning).
logger = get_logger(__name__)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
    """Force the tokenizer to use `PreTrainedTokenizerBase._pad`.

    The check is done on the repr of the function currently bound as `_pad`, which
    includes its qualified name. When the tokenizer's class overrides `_pad`, the base
    implementation is bound onto this tokenizer instance instead. Otherwise the tokenizer
    is left untouched.

    NOTE(review): presumably this guards against custom (remote-code) `_pad` overrides
    that are incompatible with the caller's padding arguments — confirm.
    """
    bound_pad_func = tokenizer._pad.__func__
    if "PreTrainedTokenizerBase" in str(bound_pad_func):
        return  # already the base implementation (possibly inherited)

    tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def patch_config(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
    is_trainable: bool,
) -> None:
    """Patch the model config and the model-loading kwargs in place.

    Side effects:
        - fills `model_args.compute_dtype` when it is unset;
        - on NPU, sets the torch_npu compile mode from the `JIT_COMPILE` env variable;
        - forwards to the `configure_*` helpers (attention, RoPE, LongLoRA, quantization,
          MoE, visual model), in that order;
        - adjusts `use_cache` and Qwen-specific config fields;
        - writes `low_cpu_mem_usage` into `init_kwargs`, and possibly `torch_dtype`,
          `device_map` and `offload_folder`.
    """
    # Resolve the compute dtype up front: the Qwen branch and the torch_dtype cast read it.
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        model_args.compute_dtype = (
            infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
            if model_args.infer_dtype == "auto"
            else getattr(torch, model_args.infer_dtype)
        )

    if is_torch_npu_available():
        jit_flag = os.environ.get("JIT_COMPILE", "0")
        torch.npu.set_compile_mode(jit_compile=jit_flag.lower() in ("true", "1"))

    # Order matters: configure_quantization may also write into init_kwargs.
    configure_attn_implementation(config, model_args)
    configure_rope(config, model_args, is_trainable)
    configure_longlora(config, model_args, is_trainable)
    configure_quantization(config, tokenizer, model_args, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_visual_model(config)

    if model_args.use_cache and not is_trainable:
        config.use_cache = True
        logger.info("Using KV cache for faster generation.")

    model_type = getattr(config, "model_type", None)
    if model_type == "qwen":
        config.use_flash_attn = model_args.flash_attn == "fa2"
        # Exactly one of these flags is True (the one matching compute_dtype), if any.
        qwen_dtype_flags = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        for flag_name, flag_dtype in qwen_dtype_flags.items():
            setattr(config, flag_name, model_args.compute_dtype == flag_dtype)

    if model_type == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
        config.use_cache = False  # qwen2 does not support use_cache when using flash attn

    zero3_enabled = is_deepspeed_zero3_enabled()
    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and not zero3_enabled

    # Cast the model to compute_dtype only if:
    # 1. neither deepspeed zero3 nor fsdp is enabled (zero3 / fsdp keep the weights in float32), or
    # 2. quantization is requested (e.g. fsdp + qlora).
    if model_args.quantization_bit is None and (zero3_enabled or is_fsdp_enabled()):
        return

    init_kwargs["torch_dtype"] = model_args.compute_dtype

    if not init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
        return

    if "device_map" not in init_kwargs and model_args.device_map:
        init_kwargs["device_map"] = model_args.device_map

    if init_kwargs.get("device_map", None) == "auto":
        init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    is_trainable: bool,
    add_valuehead: bool,
) -> None:
    """Apply post-loading fixes to an instantiated model, in place.

    Args:
        model: The loaded model; mutated in place.
        tokenizer: Passed to `resize_embedding_layer` when `model_args.resize_vocab` is set.
        model_args: Model arguments; `resize_vocab`, `visual_inputs` and `use_unsloth` select
            which patches run, and the object is forwarded to some helpers.
        is_trainable: Whether the model is being prepared for training.
        add_valuehead: Whether to call `prepare_valuehead_model` on the model.
    """
    # Check and fix the generation config. transformers leaves `generation_config` as None
    # for models that cannot generate, so guard before touching its fields.
    gen_config = getattr(model, "generation_config", None)
    if gen_config is not None and not gen_config.do_sample:
        sampling_values = (gen_config.temperature, gen_config.top_p, gen_config.typical_p)
        # Non-default sampling parameters only take effect when sampling, so enable sampling
        # to keep the generation config self-consistent.
        if any(value is not None and value != 1.0 for value in sampling_values):
            gen_config.do_sample = True

    # If the bound `generate` is not GenerationMixin's (i.e. the model class overrides it),
    # re-bind the stock transformers implementation.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    if add_valuehead:
        prepare_valuehead_model(model)

    if model_args.resize_vocab:
        resize_embedding_layer(model, tokenizer)

    if model_args.visual_inputs:
        autocast_projector_dtype(model, model_args)

    if is_trainable:
        prepare_model_for_training(model, model_args)
        add_z3_leaf_module(model)

    if not model_args.use_unsloth:
        print_attn_implementation(model.config)

    # Tagging is best-effort: any failure is logged as a warning instead of aborting loading.
    try:
        model.add_model_tags(["llama-factory"])
    except Exception:
        logger.warning("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    """Attach helper methods and save-exclusion keys to a value-head model, in place.

    Every parameter name containing "pretrained_model" is recorded in
    `_keys_to_ignore_on_save`. Three methods are bound onto the instance:
    `tie_weights`, `get_input_embeddings` and `create_or_update_model_card`. Each forwards
    to `self.pretrained_model` only when it has the expected type (`PreTrainedModel` for the
    first two, `PeftModel` for the last). Otherwise the method does nothing and returns None.
    """

    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        backbone = self.pretrained_model
        if not isinstance(backbone, PreTrainedModel):
            return
        backbone.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        backbone = self.pretrained_model
        if not isinstance(backbone, PreTrainedModel):
            return None  # NOTE(review): e.g. a PeftModel-wrapped backbone falls through here
        return backbone.get_input_embeddings()

    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
        backbone = self.pretrained_model
        if not isinstance(backbone, PeftModel):
            return
        backbone.create_or_update_model_card(output_dir)

    model._keys_to_ignore_on_save = [
        param_name for param_name, _ in model.named_parameters() if "pretrained_model" in param_name
    ]

    # Bind in the same order as before: tie_weights, get_input_embeddings, model card.
    for method_name, method_impl in (
        ("tie_weights", tie_weights),
        ("get_input_embeddings", get_input_embeddings),
        ("create_or_update_model_card", create_or_update_model_card),
    ):
        setattr(model, method_name, MethodType(method_impl, model))