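"""
Precompute and cache running statistics (e.g., second moments and means) of MLP
activations for ROME, collected over a Wikipedia or WikiText corpus.
"""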
import os
from pathlib import Path
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from ...util.globals import *
from ...util.nethook import Trace, set_requires_grad
from ...util.runningstats import CombinedStat, Mean, NormMean, SecondMoment, tally
from .tok_dataset import (
    TokenizedDataset,
    dict_to_,
    flatten_masked_batch,
    length_collation,
)
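
# Map each statistic name to the running-statistics estimator that computes it.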
STAT_TYPES = {
    "mom2": SecondMoment,
    "mean": Mean,
    "norm_mean": NormMean,
}
def main():
    """
    Command-line utility to precompute cached stats.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ROME Statistics Collector")

    def aa(*args, **kwargs):
        parser.add_argument(*args, **kwargs)

    aa("--model_name", default="gpt2-xl", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"])
    aa("--dataset", default="wikipedia", choices=["wikitext", "wikipedia"])
    aa("--layers", default=[17], type=lambda x: list(map(int, x.split(","))))
    aa("--to_collect", default=["mom2"], type=lambda x: x.split(","))
    aa("--sample_size", default=100000, type=lambda x: None if x == "all" else int(x))
    aa("--batch_tokens", default=None, type=lambda x: None if x == "any" else int(x))
    aa("--precision", default="float32", choices=["float64", "float32", "float16"])
    aa("--stats_dir", default=STATS_DIR)
    aa("--download", default=1, type=int, choices=[0, 1])
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(args.model_name).eval().cuda()
    set_requires_grad(False, model)

    for layer_num in args.layers:
        print(
            f"Computing stats for layer {layer_num} of {args.model_name} "
            f'over {args.sample_size or "all"} samples of {args.dataset}. '
            "Note, the statistics are collected over the inputs to the second MLP layer, "
            "or equivalently the outputs of the first MLP layer."
        )
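        # The second MLP layer corresponds to the down-projection module:
        # "c_proj" for GPT-2 models, "fc_out" for GPT-J.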
        proj_layer_name = "c_proj" if "gpt2" in args.model_name else "fc_out"
        layer_name = f"transformer.h.{layer_num}.mlp.{proj_layer_name}"
        layer_stats(
            model,
            tokenizer,
            layer_name,
            args.stats_dir,
            args.dataset,
            args.to_collect,
            sample_size=args.sample_size,
            precision=args.precision,
            batch_tokens=args.batch_tokens,
            download=args.download,
        )

def layer_stats(
    model,
    tokenizer,
    layer_name,
    stats_dir,
    ds_name,
    to_collect,
    model_name=None,
    sample_size=None,
    precision=None,
    batch_tokens=None,
    download=True,
    progress=tqdm,
    force_recompute=False,
    hparams=None,
):
"""
Function to load or compute cached stats.
"""
def get_ds():
# Load_From_File
# from datasets import Dataset
# raw_ds = Dataset.from_file('XXX/XXX/wikipedia-train.arrow')
# raw_ds = {'train': raw_ds}
raw_ds = load_dataset(
ds_name,
dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name]
)
if hasattr(model.config, 'n_positions'):
maxlen = model.config.n_positions
elif hasattr(model.config, 'max_sequence_length'):
maxlen = model.config.max_sequence_length
elif hasattr(model.config, 'max_position_embeddings'):
maxlen = model.config.max_position_embeddings
elif hasattr(model.config,'seq_length'):
maxlen = model.config.seq_length
else:
raise NotImplementedError
if hasattr(model.config, 'model_type') and 'mistral' in model.config.model_type:
if hasattr(model.config, 'sliding_window') and model.config.sliding_window:
maxlen = model.config.sliding_window or 4096
else:
maxlen = 4096
if batch_tokens is not None and batch_tokens < maxlen:
maxlen = batch_tokens
return TokenizedDataset(raw_ds["train"], tokenizer, maxlen=maxlen)

    # Continue with computation of statistics
    batch_size = 100  # Examine this many dataset texts at once
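    # Determine the model's maximum context length from whichever config
    # attribute this architecture exposes.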
    if hasattr(model.config, 'n_positions'):
        npos = model.config.n_positions
    elif hasattr(model.config, 'max_sequence_length'):
        npos = model.config.max_sequence_length
    elif hasattr(model.config, 'max_position_embeddings'):
        npos = model.config.max_position_embeddings
    elif hasattr(model.config, 'seq_length'):
        npos = model.config.seq_length
    else:
        raise NotImplementedError
    if hasattr(model.config, 'model_type') and 'mistral' in model.config.model_type:
        if hasattr(model.config, 'sliding_window') and model.config.sliding_window:
            npos = model.config.sliding_window or 4096
        else:
            npos = 4096

    if batch_tokens is None:
        batch_tokens = npos * 3  # Sort and divide into batches with this many tokens
    if precision is None:
        precision = "float64"
    dtype = getattr(torch, precision)
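    # Build the cache file path from the model, dataset, layer, precision,
    # collected statistics, and sample size.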
    size_suffix = "" if sample_size is None else f"_{sample_size}"
    if batch_tokens < npos:
        size_suffix = f"_t{batch_tokens}" + size_suffix
    if model_name is None:
        # model_name = model.config._name_or_path.replace("/", "_")
        model_name = model.config._name_or_path.rsplit("/")[-1]

    stats_dir = Path(stats_dir)
    file_extension = f"{model_name}/{ds_name}_stats/{layer_name}_{precision}_{'-'.join(sorted(to_collect))}{size_suffix}.npz"
    filename = stats_dir / file_extension

    print("Computing Cov locally....")

    ds = get_ds() if not filename.exists() else None

    if progress is None:
        progress = lambda x: x
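    # Set up the running estimators; tally() wraps the tokenized dataset in a
    # loader and caches the resulting statistics to `filename` as an .npz file.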
    stat = CombinedStat(**{k: STAT_TYPES[k]() for k in to_collect})
    loader = tally(
        stat,
        ds,
        cache=(filename if not force_recompute else None),
        sample_size=sample_size,
        batch_size=batch_size,
        collate_fn=length_collation(batch_tokens),
        pin_memory=True,
        random_sample=1,
        num_workers=2,
    )
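    # Run the model over each batch, capturing the target layer's input
    # activations with a Trace hook and feeding them to the estimators.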
    batch_count = -(-(sample_size or len(ds)) // batch_size)  # Ceiling division
    with torch.no_grad():
        for batch_group in progress(loader, total=batch_count):
            for batch in batch_group:
                # Fall back to the default CUDA device when no hparams are
                # provided (e.g., when invoked from main() above).
                device = f"cuda:{hparams.device}" if hparams is not None else "cuda"
                batch = dict_to_(batch, device)
                with Trace(
                    model, layer_name, retain_input=True, retain_output=False, stop=True
                ) as tr:
                    model(**batch)
                feats = flatten_masked_batch(tr.input, batch["attention_mask"])
                # feats = flatten_masked_batch(tr.output, batch["attention_mask"])
                feats = feats.to(dtype=dtype)
                stat.add(feats)

    return stat

if __name__ == "__main__":
    main()