"""Utilities for preparing causal LMs with resized token embeddings: a
PEFT/LoRA-wrapped model, a frozen model with only the embeddings trainable,
and a fully trainable model. Each helper accepts either an instantiated
model or a Hydra DictConfig describing one."""
from peft import (
    LoraConfig,
    PeftModel,
    LoraModel,
    PeftModelForCausalLM,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.utils import PeftConfig, PromptLearningConfig, _set_trainable

import torch
from transformers import LlamaForCausalLM
from omegaconf import DictConfig
import hydra


def get_peft_model_with_resize_embedding(
        model,
        peft_config=None,
        model_id=None,
        vocab_size=None,
        torch_dtype='bf16'
):
    """Build a PEFT model, optionally resizing the base model's token embeddings.

    Exactly one of `peft_config` (create a fresh adapter) or `model_id`
    (load a pretrained adapter) must be given. When `vocab_size` is set, the
    token embeddings are resized to that size before the adapter is attached.
    """
    # Map the string dtype flag to the corresponding torch dtype.
    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # The model may be passed as a Hydra config rather than an instantiated module.
    if isinstance(model, DictConfig):
        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)

    # model.gradient_checkpointing_enable()

    # Exactly one of `peft_config` and `model_id` must be provided.
    assert (peft_config is None) + (model_id is None) == 1, \
        'Provide exactly one of `peft_config` or `model_id`.'

    # print(type(peft_config.target_modules))
    if vocab_size is not None:
        print(f'Resizing token embeddings to tokenizer length: {vocab_size}')
        model.resize_token_embeddings(vocab_size)

    if peft_config is not None:
        print('peft config: ', peft_config)
        peft_model = get_peft_model(model=model, peft_config=peft_config)
        # Keep the (resized) embedding matrices trainable alongside the adapter weights.
        peft_model.get_input_embeddings().requires_grad_(True)
        peft_model.get_output_embeddings().requires_grad_(True)

        peft_model.print_trainable_parameters()

        # param_count = 0
        # if peft_model.modules_to_save is not None:
        #     for name, param in peft_model.named_parameters():
        #         if any(module_name in name for module_name in peft_model.modules_to_save):
        #             param_count += param.numel()
        #             print(name, param.numel())

    else:
        # Load a previously trained adapter from `model_id` onto the base model.
        peft_model = PeftModel.from_pretrained(model=model, model_id=model_id)

    return peft_model


def get_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
    """Freeze the model and, if `vocab_size` is given, resize the token
    embeddings and leave only the embedding matrices trainable."""
    # Map the string dtype flag to the corresponding torch dtype.
    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # The model may be passed as a Hydra config rather than an instantiated module.
    if isinstance(model, DictConfig):
        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)

    # Freeze everything, then unfreeze only the (resized) embeddings.
    model.requires_grad_(False)
    if vocab_size is not None:
        print(f'Resizing token embeddings to tokenizer length: {vocab_size}')
        model.resize_token_embeddings(vocab_size)
        model.get_input_embeddings().requires_grad_(True)
        model.get_output_embeddings().requires_grad_(True)

    return model


def get_full_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
    """Prepare the model for full fine-tuning (all parameters stay trainable),
    resizing its token embeddings if `vocab_size` is given."""
    # Map the string dtype flag to the corresponding torch dtype.
    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # The model may be passed as a Hydra config rather than an instantiated module.
    if isinstance(model, DictConfig):
        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)

    if vocab_size is not None:
        print(f'Resizing token embeddings to tokenizer length: {vocab_size}')
        model.resize_token_embeddings(vocab_size)

    return model
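

# A minimal usage sketch, kept under a main guard so importing this module has
# no side effects. It assumes a local LLaMA checkpoint and a tokenizer extended
# with extra special tokens; the checkpoint path and the LoRA hyperparameters
# below are placeholders, not values taken from any original config.
if __name__ == '__main__':
    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained('path/to/llama-checkpoint')
    base_model = LlamaForCausalLM.from_pretrained(
        'path/to/llama-checkpoint', torch_dtype=torch.bfloat16
    )

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=['q_proj', 'v_proj'],
        modules_to_save=['embed_tokens', 'lm_head'],
        task_type='CAUSAL_LM',
    )

    # New tokens were added to the tokenizer, so resize the embeddings to match.
    peft_model = get_peft_model_with_resize_embedding(
        model=base_model,
        peft_config=lora_config,
        vocab_size=len(tokenizer),
    )

    # Alternatively, resume from an already trained adapter instead of a fresh
    # LoraConfig (the two arguments are mutually exclusive):
    # peft_model = get_peft_model_with_resize_embedding(
    #     model=base_model, model_id='path/to/saved-adapter', vocab_size=len(tokenizer))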