import time
import gc
import warnings

import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GPT2TokenizerFast,
)

warnings.filterwarnings("ignore")

# bitsandbytes quantization requires a CUDA device.
device = "cuda"

tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

# WikiText-2 validation split, concatenated into a single token stream
# for the sliding-window perplexity evaluation below.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")


def run_experiment(model):
    print(f'Memory usage of model alone = {model.get_memory_footprint() / 10**6:.1f} MB')
    max_length = model.config.n_positions  # 1024 for GPT-2
    stride = 512
    seq_len = encodings.input_ids.size(1)

    nlls = []
    start_time = time.time()
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be shorter than stride on the last window
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Mask the overlapping context tokens with -100 so each token's loss
        # is counted exactly once across windows.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

        neg_log_likelihood = outputs.loss

        if begin_loc == 0:
            print(f'Memory usage at forward pass = {torch.cuda.memory_allocated(0) / 10**6:.1f} MB')
        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    # Perplexity is the exponential of the mean negative log-likelihood.
    ppl = torch.exp(torch.stack(nlls).mean())
    print(f'Perplexity = {ppl.item()}')
    print(f'Time taken: {time.time() - start_time:.1f} s')
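

# Added for reference (not one of the original runs): benchmark the
# unquantized fp32 GPT-2 through the same harness so the quantized
# results below have a baseline to compare against.
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
print('fp32 baseline model')
run_experiment(model)
del model
gc.collect()
torch.cuda.empty_cache()
print()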


# Plain 4-bit quantization (bitsandbytes defaults to the fp4 data type
# unless bnb_4bit_quant_type is set).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print('4 bit model')
run_experiment(model)
torch.save(model, 'bnb-4.pth')
print()
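
# Added: free the 4-bit model before loading the next one so the memory
# readings for the following run aren't inflated by the previous model.
del model
gc.collect()
torch.cuda.empty_cache()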

# 8-bit quantization (LLM.int8()).
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print('8 bit model')
run_experiment(model)
torch.save(model, 'bnb-8.pth')
print()
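
# Added: free the 8-bit model before loading the nf4 variant.
del model
gc.collect()
torch.cuda.empty_cache()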

# 4-bit quantization with the nf4 (NormalFloat4) data type, which is
# designed for normally distributed weights.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print('4 bit nf4 model')
run_experiment(model)
torch.save(model, 'bnb-nf4.pth')
print()
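
# Added sketch: compare the serialized checkpoint sizes on disk as a
# rough cross-check of the in-memory footprints reported above.
import os

for path in ('bnb-4.pth', 'bnb-8.pth', 'bnb-nf4.pth'):
    print(f'{path}: {os.path.getsize(path) / 10**6:.1f} MB on disk')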