import bitsandbytes as bnb
import torch
from transformers.trainer_pt_utils import LabelSmoother

# Label id that the loss ignores; LabelSmoother.ignore_index is -100.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


def find_all_linear_names(use_8bit, model):
    """Return the names of all linear layers in `model`, suitable as LoRA target modules.

    When `use_8bit` is true, the model's linear layers are bitsandbytes
    `Linear8bitLt` modules; otherwise they are plain `torch.nn.Linear`.
    """
    cls = bnb.nn.Linear8bitLt if use_8bit else torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            # Keep only the last component of the dotted module path, e.g.
            # "model.layers.0.self_attn.q_proj" -> "q_proj".
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return list(lora_module_names)
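
# Usage sketch: the returned names are typically fed to a PEFT LoraConfig as
# `target_modules`. This assumes `peft` is installed; the LoRA hyperparameters
# below are illustrative, not defined in this file.
#
#     from peft import LoraConfig, get_peft_model
#
#     target_modules = find_all_linear_names(use_8bit, model)
#     config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
#                         target_modules=target_modules, task_type="CAUSAL_LM")
#     model = get_peft_model(model, config)
#
# Note: some recipes (e.g. QLoRA) drop "lm_head" from this list so the output
# head is left un-adapted.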


def load_from_checkpoint(resume_from_checkpoint, model=None):
    # TODO: not implemented yet; see the sketch below for one possible approach.
    pass
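
# One possible implementation (a sketch, not this repo's actual logic):
# assuming `resume_from_checkpoint` is a directory holding PEFT/LoRA adapter
# weights saved during training, resuming could look like the following. It
# would additionally require `import os` and
# `from peft import set_peft_model_state_dict`.
#
#     checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
#     if not os.path.exists(checkpoint_name):
#         # PEFT saves adapter-only weights under this name.
#         checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")
#     if os.path.exists(checkpoint_name):
#         adapters_weights = torch.load(checkpoint_name, map_location="cpu")
#         set_peft_model_state_dict(model, adapters_weights)
#     return model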