import torch
import warnings
from config import *
from peft import LoraConfig
from transformers import BitsAndBytesConfig

warnings.filterwarnings(action='ignore')


def load_trol(link):
    """
    Select and load the requested TroL model variant together with its tokenizer.
    """
    # Each branch picks the architecture, checkpoint path, quantization bit-width,
    # and the modules to keep out of quantization.
    if link == 'TroL-1.8B':
        from .arch_internlm2.modeling_trol import TroLForCausalLM
        from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
        bits = 4
        path = TROL_1_8B
        bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]

    elif link == 'TroL-3.8B':
        from trol.arch_phi3.modeling_trol import TroLForCausalLM
        from transformers import LlamaTokenizerFast as TroLTokenizer
        bits = 8
        path = TROL_3_8B
        bit_quant_skip = ["vision_model", "mlp1", "lm_head"]

    elif link == 'TroL-7B':
        from .arch_internlm2.modeling_trol import TroLForCausalLM
        from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
        bits = 4
        path = TROL_7B
        bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]

    else:
        raise ValueError(f"Unsupported link: {link}")

    huggingface_config = {}

    if bits in [4, 8]:
        # Quantized loading: fp16 compute with a BitsAndBytes 4-bit/8-bit config,
        # keeping the modules listed in bit_quant_skip un-quantized.
        huggingface_config.update(dict(
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=bits == 4,
                load_in_8bit=bits == 8,
                llm_int8_skip_modules=bit_quant_skip,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            )
        ))
    else:
        # Full-precision (fp16) loading without quantization.
        huggingface_config.update(dict(
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
        ))

    tok_trol = TroLTokenizer.from_pretrained(path, padding_side='left')
    try:
        trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
    except Exception:
        # Fall back to the default attention implementation when
        # FlashAttention-2 is unavailable on this machine.
        del huggingface_config["attn_implementation"]
        trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
    return trol, tok_trol
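

# Minimal usage sketch (assumptions, not part of the original file: this module is
# importable from the `trol` package as `load_trol`, and `config.py` defines the
# TROL_1_8B / TROL_3_8B / TROL_7B checkpoint paths; how inference is driven
# afterwards depends on the rest of the TroL codebase and is only indicative):
#
#     from trol.load_trol import load_trol
#
#     model, tokenizer = load_trol('TroL-1.8B')   # or 'TroL-3.8B' / 'TroL-7B'
#     model.eval()
#     # ...build multimodal inputs with `tokenizer` and run generation on `model`
#     # using the project's own inference utilities...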