import torch
import hydra
from omegaconf import DictConfig
from transformers import LlamaForCausalLM

from peft import (
    LoraConfig,
    PeftModel,
    LoraModel,
    PeftModelForCausalLM,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.utils import PeftConfig, PromptLearningConfig, _set_trainable

def get_peft_model_with_resize_embedding(
    model,
    peft_config=None,
    model_id=None,
    vocab_size=None,
    torch_dtype='bf16',
):
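    """Build a PEFT-wrapped model, optionally resizing its token embeddings.

    Exactly one of ``peft_config`` / ``model_id`` must be given: a config
    creates a fresh adapter via ``get_peft_model``, while ``model_id`` loads
    an existing adapter via ``PeftModel.from_pretrained``. ``model`` may be a
    Hydra ``DictConfig``, in which case it is instantiated with the requested
    ``torch_dtype``. When a new adapter is created, the input and output
    embeddings are left trainable so any newly added token rows can be tuned.
    """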
    # Map the dtype string from the config onto the corresponding torch dtype.
    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # Instantiate the model from its Hydra config if it was not passed in directly.
    if isinstance(model, DictConfig):
        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
    # model.gradient_checkpointing_enable()

    assert (peft_config is None) + (model_id is None) == 1, \
        'exactly one of peft_config and model_id must be provided'
    # print(type(peft_config.target_modules))

    if vocab_size is not None:
        print(f'Length of tokenizer and resize embedding: {vocab_size}')
        model.resize_token_embeddings(vocab_size)

    if peft_config is not None:
        print('peft config: ', peft_config)
        peft_model = get_peft_model(model=model, peft_config=peft_config)
        # Keep the (resized) embedding tables trainable alongside the adapter weights.
        peft_model.get_input_embeddings().requires_grad_(True)
        peft_model.get_output_embeddings().requires_grad_(True)
        peft_model.print_trainable_parameters()

        # param_count = 0
        # if peft_model.modules_to_save is not None:
        #     for name, param in peft_model.named_parameters():
        #         if any(module_name in name for module_name in peft_model.modules_to_save):
        #             param_count += param.numel()
        #             print(name, param.numel())
    else:
        # Resume from an already-trained adapter checkpoint.
        peft_model = PeftModel.from_pretrained(model=model, model_id=model_id)

    return peft_model

def get_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
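    """Freeze the base model and leave only its token embeddings trainable.

    ``model`` may be a Hydra ``DictConfig`` (instantiated with ``torch_dtype``).
    All parameters are frozen, the vocabulary is optionally resized, and the
    input/output embedding layers are then unfrozen.
    """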
    # Map the dtype string from the config onto the corresponding torch dtype.
    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    if isinstance(model, DictConfig):
        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)

    # Freeze the whole backbone; only the embedding tables below stay trainable.
    model.requires_grad_(False)

    if vocab_size is not None:
        print(f'Length of tokenizer and resize embedding: {vocab_size}')
        model.resize_token_embeddings(vocab_size)

    model.get_input_embeddings().requires_grad_(True)
    model.get_output_embeddings().requires_grad_(True)

    return model

def get_full_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
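    """Return a fully trainable model, optionally resizing its token embeddings.

    ``model`` may be a Hydra ``DictConfig`` (instantiated with ``torch_dtype``);
    no parameters are frozen.
    """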
    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    if isinstance(model, DictConfig):
        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)

    if vocab_size is not None:
        print(f'Length of tokenizer and resize embedding: {vocab_size}')
        model.resize_token_embeddings(vocab_size)

    return model
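
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original training pipeline): shows how a LoRA
# adapter could be attached with the helper above. The checkpoint path, LoRA
# hyperparameters, and vocab size are placeholders, not values from this repo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from peft import TaskType

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=['q_proj', 'v_proj'],
    )

    base_model = LlamaForCausalLM.from_pretrained(
        'path/to/llama-checkpoint',   # placeholder path to a local Llama checkpoint
        torch_dtype=torch.bfloat16,
    )

    peft_model = get_peft_model_with_resize_embedding(
        base_model,
        peft_config=lora_config,
        vocab_size=32100,             # placeholder: len(tokenizer) after adding special tokens
    )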