|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Count the remaining (non-zero) weights in the encoder (i.e. the transformer layers) of a fine-pruned model.

Sparsity and remaining-weight levels are complementary: sparsity % = 100 - remaining weights %.
|
""" |
|
import argparse |
|
import os |
|
|
|
import torch |
|
from emmental.modules import ThresholdBinarizer, TopKBinarizer |
|
|
|
|
|
def _count_mask_ones(mask_scores, pruning_method, threshold):
    """Return how many entries of ``mask_scores`` are kept after binarization.

    Args:
        mask_scores: the ``mask_scores`` tensor of one prunable weight.
        pruning_method: ``"topK"``, ``"sigmoied_threshold"`` or ``"l0"``.
        threshold: binarization threshold (ignored for ``"l0"``).

    Returns:
        The number of kept entries, as an int.

    Raises:
        ValueError: if ``pruning_method`` is not one of the supported methods.
    """
    if pruning_method == "topK":
        return int(TopKBinarizer.apply(mask_scores, threshold).sum().item())
    if pruning_method == "sigmoied_threshold":
        return int(ThresholdBinarizer.apply(mask_scores, threshold, True).sum().item())
    if pruning_method == "l0":
        # Stretch the sigmoid scores to (l, r) = (-0.1, 1.1), then clamp to [0, 1];
        # an entry is kept when its clamped value is strictly positive.
        l, r = -0.1, 1.1
        s_bar = torch.sigmoid(mask_scores) * (r - l) + l
        mask = s_bar.clamp(min=0.0, max=1.0)
        return int((mask > 0.0).sum().item())
    raise ValueError("Unknown pruning method")


def main(args):
    """Print per-matrix and global remaining-weight percentages of a fine-pruned encoder.

    Loads ``pytorch_model.bin`` from ``args.serialization_dir`` and walks every
    parameter whose name contains ``"encoder"``:

    * ``mask_scores`` tensors are binarized with the chosen pruning method, and the
      number of kept entries is added to the remaining count (one line is printed
      per tensor);
    * every other encoder tensor is added to the total encoder count; biases and
      LayerNorm parameters are additionally counted as fully remaining.

    Args:
        args: namespace with ``serialization_dir`` (str), ``pruning_method``
            (``"l0"``, ``"topK"`` or ``"sigmoied_threshold"``) and ``threshold``
            (float; needed for ``topK`` and ``sigmoied_threshold``, unused for ``l0``).

    Returns:
        The global remaining-weight percentage of the encoder.

    Raises:
        ValueError: if ``threshold`` is missing for a method that needs it, if the
            pruning method is unknown, or if the checkpoint holds no encoder parameters.
    """
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold

    # Fail early with a clear message instead of a deep error inside the binarizers.
    if pruning_method in ("topK", "sigmoied_threshold") and threshold is None:
        raise ValueError(f"A threshold is required for pruning_method={pruning_method!r}")

    # NOTE(review): torch.load unpickles the checkpoint; only run this on trusted files
    # (or pass weights_only=True on torch versions that support it).
    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")

    remaining_count = 0  # encoder parameters that survive pruning
    encoder_count = 0  # all encoder parameters, excluding the mask_scores tensors themselves

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if "mask_scores" in name:
            mask_ones = _count_mask_ones(param, pruning_method, threshold)
            remaining_count += mask_ones
            print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
        else:
            encoder_count += param.numel()
            # Only mask_scores-backed weights are treated as prunable here; biases and
            # LayerNorm parameters are counted as fully remaining.
            if "bias" in name or "LayerNorm" in name:
                remaining_count += param.numel()

    if encoder_count == 0:
        raise ValueError(f"No encoder parameters found in {os.path.join(serialization_dir, 'pytorch_model.bin')}")

    remaining_pct = 100 * remaining_count / encoder_count
    print("")
    print("Remaining Weights (global) %: ", remaining_pct)
    return remaining_pct
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Count remaining (non-pruned) weights in the encoder of a fine-pruned model."
    )

    parser.add_argument(
        "--pruning_method",
        choices=["l0", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help=(
            "Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement"
            " pruning)"
        ),
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        help=(
            # argparse applies %-formatting to help strings, so a literal percent sign
            # must be written as "%%" (a bare "%)" makes help formatting raise ValueError).
            "For `topK`, it is the level of remaining weights (in %%) in the fine-pruned model. "
            "For `sigmoied_threshold`, it is the threshold tau against which the (sigmoied) scores are compared. "
            "Not needed for `l0`."
        ),
    )
    parser.add_argument(
        "--serialization_dir",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )

    args = parser.parse_args()

    # --threshold is optional at the parser level but needed by topK / sigmoied_threshold;
    # report it as a usage error rather than letting it fail inside the binarizers.
    if args.pruning_method != "l0" and args.threshold is None:
        parser.error(f"--threshold is required when --pruning_method is {args.pruning_method}")

    main(args)
|
|