File size: 3,425 Bytes
d754e91
 
87a0e23
d754e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87a0e23
 
d754e91
4b2400e
d754e91
 
 
87a0e23
d754e91
 
4b2400e
d754e91
 
 
 
 
 
4b2400e
d754e91
 
 
 
 
4b2400e
 
 
 
 
 
 
 
 
 
 
 
d754e91
 
 
 
 
 
 
87a0e23
 
 
d754e91
 
87a0e23
 
d754e91
 
 
 
 
 
87a0e23
 
d754e91
 
 
 
 
 
 
 
0e92a92
c15d0e4
d754e91
87a0e23
b9929ef
 
 
7b14813
87a0e23
9279c83
 
 
 
 
 
 
 
87a0e23
 
 
 
 
 
 
9279c83
87a0e23
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import sys
import gc

import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from .globals import Global


def get_device():
    """Return the name of the torch device to load models onto.

    Preference order is ``"cuda"``, then ``"mps"``, then ``"cpu"``.

    Returns:
        One of the strings ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    # The MPS probe is best-effort: torch builds without ``torch.backends.mps``
    # raise AttributeError here, and any probe failure simply means "use CPU".
    try:
        if torch.backends.mps.is_available():
            return "mps"
    except (AttributeError, RuntimeError):
        pass

    return "cpu"


# Device name resolved once, when this module is imported; the model loaders
# in this module branch on this value.
device = get_device()


def get_base_model():
    """Return the shared base model, triggering a load first.

    Returns:
        Whatever ``Global.loaded_base_model`` holds after ``load_base_model()``
        has run.
    """
    load_base_model()
    base_model = Global.loaded_base_model
    return base_model


def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
    """Attach LoRA weights to the shared base model and prepare it for inference.

    Marks ``Global.model_has_been_used`` before loading.

    Args:
        lora_weights: Identifier handed to ``PeftModel.from_pretrained``.

    Returns:
        The PEFT-wrapped model in eval mode. It is additionally passed through
        ``torch.compile`` when the torch version check succeeds and the
        platform is not ``win32``.
    """
    Global.model_has_been_used = True

    # Only the keyword arguments differ between devices, so build them first.
    if device == "cuda":
        # device_map uses index 0; see https://github.com/tloen/alpaca-lora/issues/21
        load_kwargs = {"torch_dtype": torch.float16, "device_map": {'': 0}}
    elif device == "mps":
        load_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
    else:
        load_kwargs = {"device_map": {"": device}}

    peft_model = PeftModel.from_pretrained(
        get_base_model(), lora_weights, **load_kwargs
    )

    # Fixed special-token ids; the pad id is also written onto the tokenizer.
    peft_model.config.pad_token_id = 0
    get_tokenizer().pad_token_id = 0
    peft_model.config.bos_token_id = 1
    peft_model.config.eos_token_id = 2

    if not Global.load_8bit:
        peft_model.half()  # seems to fix bugs for some users.

    peft_model.eval()

    should_compile = torch.__version__ >= "2" and sys.platform != "win32"
    if should_compile:
        return torch.compile(peft_model)
    return peft_model


def get_tokenizer():
    """Return the shared tokenizer, triggering a load first.

    Returns:
        Whatever ``Global.loaded_tokenizer`` holds after ``load_base_model()``
        has run.
    """
    load_base_model()
    tokenizer = Global.loaded_tokenizer
    return tokenizer


def load_base_model():
    """Fill ``Global.loaded_tokenizer`` and ``Global.loaded_base_model`` if unset.

    Returns immediately when ``Global.ui_dev_mode`` is set. Each of the two
    objects is only loaded while its ``Global`` slot is ``None``. The special
    token ids are written only when the model is freshly loaded.
    """
    if Global.ui_dev_mode:
        return

    if Global.loaded_tokenizer is None:
        tokenizer = LlamaTokenizer.from_pretrained(Global.base_model)
        Global.loaded_tokenizer = tokenizer

    if Global.loaded_base_model is not None:
        return

    # Only the keyword arguments differ between devices.
    if device == "cuda":
        model_kwargs = {
            "load_in_8bit": Global.load_8bit,
            "torch_dtype": torch.float16,
            # device_map uses index 0; see https://github.com/tloen/alpaca-lora/issues/21
            "device_map": {'': 0},
        }
    elif device == "mps":
        model_kwargs = {
            "device_map": {"": device},
            "torch_dtype": torch.float16,
        }
    else:
        model_kwargs = {
            "device_map": {"": device},
            "low_cpu_mem_usage": True,
        }

    Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
        Global.base_model, **model_kwargs
    )

    # The model is stored on Global before get_tokenizer() is called, so the
    # nested load_base_model() call it makes finds both slots already filled.
    model_config = Global.loaded_base_model.config
    model_config.pad_token_id = 0
    get_tokenizer().pad_token_id = 0
    Global.loaded_base_model.config.bos_token_id = 1
    Global.loaded_base_model.config.eos_token_id = 2


def clear_cache():
    """Run the garbage collector, then call ``torch.cuda.empty_cache()``.

    The CUDA call is made inside a ``torch.no_grad()`` context.
    """
    gc.collect()

    release_cuda_cache = torch.cuda.empty_cache
    with torch.no_grad():
        release_cuda_cache()


def unload_models():
    """Drop the cached base model and tokenizer and reset the usage flag.

    Each ``Global`` slot is deleted and then set back to ``None``, base model
    first. Caches are cleared afterwards, and ``Global.model_has_been_used``
    is reset to ``False`` last.
    """
    for slot in ("loaded_base_model", "loaded_tokenizer"):
        delattr(Global, slot)
        setattr(Global, slot, None)

    clear_cache()

    Global.model_has_been_used = False


def unload_models_if_already_used():
    """Call ``unload_models()`` only when ``Global.model_has_been_used`` is set."""
    if not Global.model_has_been_used:
        return
    unload_models()