import os
import warnings

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def load_model(repo_id, bnb=None, torch_dtype='auto'):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if device.type == 'cuda':
        # cuBLAS requires this workspace config to produce deterministic matmuls.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        # Make PyTorch pick deterministic kernels wherever they exist.
        torch.use_deterministic_algorithms(True)

    # Suppress two known-noisy warnings: the huggingface_hub
    # `resume_download` deprecation and the bitsandbytes int8 cast notice.
    warnings.filterwarnings('ignore', message="`resume_download` is deprecated")
    warnings.filterwarnings('ignore', message="MatMul8bitLt: inputs will be cast from")
    print(f'Loading model "{repo_id}" (bnb = "{bnb}")...')

    # Reuse load_tokenizer() below so the verbosity juggling lives in one place.
    tokenizer = load_tokenizer(repo_id)
    # Optional bitsandbytes quantization: 'nf8' selects 8-bit (LLM.int8()),
    # 'nf4' selects 4-bit NF4 weights.
    bnb_config = None
    if bnb == 'nf8':
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    elif bnb == 'nf4':
        # load_in_4bit alone defaults to FP4; request NF4 explicitly.
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4')

    # 'auto' lets accelerate spread the weights over the available GPU(s);
    # on CPU, leave device_map unset so everything loads on the default device.
    device_map = 'auto'
    if device.type == 'cpu':
        device_map = None
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch_dtype,
        device_map=device_map,
        quantization_config=bnb_config,
    )
    # Inference only: freeze all weights and switch to eval mode
    # (disables dropout and other training-time behavior).
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    print('Done loading model.')
    return model, tokenizer


def load_tokenizer(repo_id):
    # Temporarily raise the logging threshold to errors only, so that
    # from_pretrained() does not spam warnings while loading, then
    # restore the default warning level afterwards.
    transformers.logging.set_verbosity_error()
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    transformers.logging.set_verbosity_warning()
    return tokenizer
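

# Example usage, as a minimal sketch: 'gpt2' is just a stand-in here for
# any causal-LM checkpoint on the Hugging Face Hub.
if __name__ == '__main__':
    model, tokenizer = load_model('gpt2')
    inputs = tokenizer('Hello, world!', return_tensors='pt').to(model.device)
    # Greedy decoding keeps the demo deterministic, matching the
    # determinism settings load_model() applies on CUDA.
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))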