SAELens

gemma-2-2b layer 20 width-65k SAE seems very off

#8
by charlieoneill - opened

I have been evaluating gemma-2-2b SAEs on a dataset of medical text. Looking at the 16k-width SAEs on layer 20, the metrics seem to be about what I'd expect (a minimal loading sketch follows the JSON):

{
    "l0_139": {
      "l2_loss": 148.585,
      "l1_loss": 2728.04,
      "l0": 183.4201708984375,
      "frac_variance_explained": -0.40267578125,
      "cossim": 0.92232421875,
      "l2_ratio": 0.93435546875,
      "relative_reconstruction_bias": 1.97640625,
      "loss_original": 1.8141272115707396,
      "loss_reconstructed": 2.10102525472641,
      "loss_zero": 12.452932243347169,
      "frac_recovered": 0.9730078125,
      "frac_alive": 0.9940185546875,
      "hyperparameters": {
        "n_inputs": 200,
        "context_length": 1024,
        "l0": 139,
        "layer": 20,
        "width": "16k"
      }
    },
    "l0_22": {
      "l2_loss": 284.485,
      "l1_loss": 2719.16,
      "l0": 55.3458203125,
      "frac_variance_explained": -31.495,
      "cossim": 0.87486328125,
      "l2_ratio": 0.91640625,
      "relative_reconstruction_bias": 15.23140625,
      "loss_original": 1.8141272115707396,
      "loss_reconstructed": 2.4616863882541655,
      "loss_zero": 12.452932243347169,
      "frac_recovered": 0.9390829825401306,
      "frac_alive": 0.82476806640625,
      "hyperparameters": {
        "n_inputs": 200,
        "context_length": 1024,
        "l0": 22,
        "layer": 20,
        "width": "16k"
      }
    },
    "l0_294": {
      "l2_loss": 130.0175,
      "l1_loss": 3763.92,
      "l0": 352.6443994140625,
      "frac_variance_explained": -0.01845703125,
      "cossim": 0.9406640625,
      "l2_ratio": 0.94486328125,
      "relative_reconstruction_bias": 1.71236328125,
      "loss_original": 1.8141272115707396,
      "loss_reconstructed": 2.0600193762779235,
      "loss_zero": 12.452932243347169,
      "frac_recovered": 0.9768525409698486,
      "frac_alive": 0.99761962890625,
      "hyperparameters": {
        "n_inputs": 200,
        "context_length": 1024,
        "l0": 294,
        "layer": 20,
        "width": "16k"
      }
    },
    "l0_38": {
      "l2_loss": 251.58,
      "l1_loss": 2645.76,
      "l0": 73.9402734375,
      "frac_variance_explained": -20.34841796875,
      "cossim": 0.889765625,
      "l2_ratio": 0.9233984375,
      "relative_reconstruction_bias": 11.4728125,
      "loss_original": 1.8141272115707396,
      "loss_reconstructed": 2.3733639335632324,
      "loss_zero": 12.452932243347169,
      "frac_recovered": 0.947366454899311,
      "frac_alive": 0.89910888671875,
      "hyperparameters": {
        "n_inputs": 200,
        "context_length": 1024,
        "l0": 38,
        "layer": 20,
        "width": "16k"
      }
    },
    "l0_71": {
      "l2_loss": 189.87,
      "l1_loss": 2500.32,
      "l0": 109.7097705078125,
      "frac_variance_explained": -4.80037109375,
      "cossim": 0.90638671875,
      "l2_ratio": 0.92884765625,
      "relative_reconstruction_bias": 4.6397265625,
      "loss_original": 1.8141272115707396,
      "loss_reconstructed": 2.1981925880908966,
      "loss_zero": 12.452932243347169,
      "frac_recovered": 0.9638544994592667,
      "frac_alive": 0.96929931640625,
      "hyperparameters": {
        "n_inputs": 200,
        "context_length": 1024,
        "l0": 71,
        "layer": 20,
        "width": "16k"
      }
    }
}
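
For context, here is roughly how these SAEs can be loaded and sanity-checked at the activation level (L0, cosine similarity). This is a minimal SAELens sketch rather than the dictionary_learning code I actually ran, and the release / sae_id strings are my best guess at the Gemma Scope naming scheme, so treat those as assumptions:

```python
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE

device = "cuda" if torch.cuda.is_available() else "cpu"

model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)

# Release / sae_id are assumed names for the Gemma Scope residual-stream SAEs;
# check them against the SAELens pretrained-SAE directory before relying on this.
sae, cfg, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_20/width_16k/average_l0_71",
    device=device,
)

tokens = model.to_tokens("The patient presented with acute chest pain and dyspnea.")
_, cache = model.run_with_cache(tokens)
acts = cache["blocks.20.hook_resid_post"]   # (batch, seq, d_model)

feats = sae.encode(acts)
recon = sae.decode(feats)

l0 = (feats > 0).float().sum(-1).mean()
cossim = torch.nn.functional.cosine_similarity(acts, recon, dim=-1).mean()
print(f"L0 = {l0:.1f}, mean cosine sim = {cossim:.3f}")
```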

However, the 65k-width SAEs for layer 20 have really weird metrics, including very poor loss recovered (i.e. Equation 10 from the gated SAEs paper: https://arxiv.org/pdf/2404.16014), despite having a low L2 loss. I thought it might be a quirk of the dataset, but I have reproduced this somewhat on monology/pile-uncopyrighted (a quick recomputation of frac_recovered from the logged losses follows this JSON):

{
  "l0_114": {
    "l2_loss": 65.14174501419068,
    "l1_loss": 326.4906903076172,
    "l0": 19.7434326171875,
    "frac_variance_explained": -1.1298050680756568,
    "cossim": 0.44833588257431983,
    "l2_ratio": 1.458713674545288,
    "relative_reconstruction_bias": 3.926573168039322,
    "loss_original": 2.151599160730839,
    "loss_reconstructed": 12.79894030570984,
    "loss_zero": 12.452933530807496,
    "frac_recovered": -0.03705257594643627,
    "frac_alive": 0.1755828857421875,
    "hyperparameters": {
      "n_inputs": 200,
      "context_length": 1024,
      "l0": 114,
      "layer": 20,
      "width": "65k"
    }
  },
  "l0_20": {
    "l2_loss": 78.6240915298462,
    "l1_loss": 274.0826930999756,
    "l0": 6.4778857421875,
    "frac_variance_explained": -8.00341603398323,
    "cossim": 0.3754740992188454,
    "l2_ratio": 1.6491711509227753,
    "relative_reconstruction_bias": 8.075071120262146,
    "loss_original": 2.151599160730839,
    "loss_reconstructed": 18.244347710609436,
    "loss_zero": 12.452933530807496,
    "frac_recovered": -0.5657323953509331,
    "frac_alive": 0.02691650390625,
    "hyperparameters": {
      "n_inputs": 200,
      "context_length": 1024,
      "l0": 20,
      "layer": 20,
      "width": "65k"
    }
  },
  "l0_221": {
    "l2_loss": 61.26867036819458,
    "l1_loss": 394.06997283935544,
    "l0": 30.2818212890625,
    "frac_variance_explained": -0.004639597833156586,
    "cossim": 0.47954541400074957,
    "l2_ratio": 1.4224228554964065,
    "relative_reconstruction_bias": 2.8707287490367888,
    "loss_original": 2.151599160730839,
    "loss_reconstructed": 10.630927562713623,
    "loss_zero": 12.452933530807496,
    "frac_recovered": 0.17530182713409886,
    "frac_alive": 0.2276763916015625,
    "hyperparameters": {
      "n_inputs": 200,
      "context_length": 1024,
      "l0": 221,
      "layer": 20,
      "width": "65k"
    }
  },
  "l0_34": {
    "l2_loss": 77.58435577392578,
    "l1_loss": 281.88170654296874,
    "l0": 8.2050439453125,
    "frac_variance_explained": -10.130443168580532,
    "cossim": 0.41340469181537626,
    "l2_ratio": 1.6560911977291106,
    "relative_reconstruction_bias": 9.29422394156456,
    "loss_original": 2.151599160730839,
    "loss_reconstructed": 17.128004446029664,
    "loss_zero": 12.452933530807496,
    "frac_recovered": -0.4539519951120019,
    "frac_alive": 0.065582275390625,
    "hyperparameters": {
      "n_inputs": 200,
      "context_length": 1024,
      "l0": 34,
      "layer": 20,
      "width": "65k"
    }
  },
  "l0_61": {
    "l2_loss": 77.927738571167,
    "l1_loss": 314.41664611816407,
    "l0": 14.6854248046875,
    "frac_variance_explained": -7.959465856552124,
    "cossim": 0.41553613662719724,
    "l2_ratio": 1.6819834589958191,
    "relative_reconstruction_bias": 7.942268486022949,
    "loss_original": 2.151599160730839,
    "loss_reconstructed": 15.694214601516723,
    "loss_zero": 12.452933530807496,
    "frac_recovered": -0.31723272004863245,
    "frac_alive": 0.1244659423828125,
    "hyperparameters": {
      "n_inputs": 200,
      "context_length": 1024,
      "l0": 61,
      "layer": 20,
      "width": "65k"
    }
  }
}
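
As a quick check on what the loss-recovered numbers mean: the Equation 10 quantity works out to (loss_zero − loss_reconstructed) / (loss_zero − loss_original), and recomputing it from the logged losses reproduces the reported frac_recovered values to within batch-averaging noise:

```python
def frac_recovered(loss_original, loss_reconstructed, loss_zero):
    # Loss recovered (Eq. 10 of the gated-SAE paper), rewritten as a single ratio.
    return (loss_zero - loss_reconstructed) / (loss_zero - loss_original)

# 16k, l0_139 entry: prints ~0.973, matching the reported 0.973.
print(frac_recovered(1.8141272115707396, 2.10102525472641, 12.452932243347169))

# 65k, l0_114 entry: prints ~-0.034 vs the reported -0.037; I assume the small gap
# is because frac_recovered is averaged per batch rather than computed from mean losses.
print(frac_recovered(2.151599160730839, 12.79894030570984, 12.452933530807496))
```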

I will evaluate some other SAEs and other Gemma models to see whether this is a problem specific to this SAE, in this model, at this layer. I did all of the evaluation with the dictionary_learning repo (https://github.com/saprmarks/dictionary_learning); a minimal TransformerLens version of the loss-recovered check is sketched below. It would be good if someone could sanity-check me / tell me if I'm missing something.
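
In case anyone wants to reproduce the splicing step without dictionary_learning, here is a minimal TransformerLens sketch of the loss-recovered check. It reuses the `model` and `sae` objects from the earlier snippet (so it inherits the same assumptions about release / sae_id; swap in a width_65k id, which I'm assuming follows the same format, to probe the 65k SAEs):

```python
import torch
from functools import partial

hook_name = "blocks.20.hook_resid_post"

def splice_recon(resid, hook, sae):
    # Replace the layer-20 residual stream with the SAE reconstruction.
    return sae.decode(sae.encode(resid)).to(resid.dtype)

def zero_ablate(resid, hook):
    # Zero-ablation baseline used for loss_zero.
    return torch.zeros_like(resid)

tokens = model.to_tokens("The patient presented with acute chest pain and dyspnea.")

loss_original = model(tokens, return_type="loss")
loss_reconstructed = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(hook_name, partial(splice_recon, sae=sae))],
)
loss_zero = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(hook_name, zero_ablate)],
)

frac = (loss_zero - loss_reconstructed) / (loss_zero - loss_original)
print(loss_original.item(), loss_reconstructed.item(), loss_zero.item(), frac.item())
```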

charlieoneill changed discussion status to closed

Discussion closed? Was this a bug?
