import time
import warnings

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, GPT2TokenizerFast

warnings.filterwarnings("ignore")

device = "cuda"

tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

# Concatenate the WikiText-2 validation split into one long token stream.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")


def run_experiment(model):
    print(f"Memory usage of model alone = {model.get_memory_footprint() / 10**6} MB")

    max_length = model.config.n_positions  # 1024 for GPT-2
    stride = 512
    seq_len = encodings.input_ids.size(1)

    nlls = []
    start_time = time.time()
    prev_end_loc = 0
    # Sliding-window perplexity: consecutive windows overlap, and only the
    # tokens not scored by a previous window (trg_len of them) contribute
    # to the loss of the current window.
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may differ from stride on the last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # mask tokens already scored in earlier windows

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # loss is computed with CrossEntropyLoss, averaged over valid (non -100) labels
            neg_log_likelihood = outputs.loss

        if begin_loc == 0:
            print(f"Memory usage at forward pass = {torch.cuda.memory_allocated(0) / 10**6} MB")

        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    print(f"Perplexity = {ppl.item()}")
    print(f"Time taken: {time.time() - start_time}")


## 4-bit (default FP4 quantization)
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print("4 bit model")
run_experiment(model)
torch.save(model, "bnb-4.pth")
print()

## 8-bit
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print("8 bit model")
run_experiment(model)
torch.save(model, "bnb-8.pth")
print()

## 4-bit NF4
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
print("4 bit nf4 model")
run_experiment(model)
torch.save(model, "bnb-nf4.pth")
print()
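
## For a reference point, one might also run the same sweep on an unquantized
## baseline. This is a minimal sketch, not one of the configurations measured
## above: loading in torch.float16 is an assumed choice of baseline dtype, and
## it reuses the run_experiment helper defined earlier.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to(device)
print("fp16 baseline model")
run_experiment(model)
print()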