{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from sae_lens import SAE, HookedSAETransformer\n", "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", "from transformer_lens import HookedTransformer\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ab95b9bc36ea42aaacdb401772fac27d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00', 0.5911203622817993), ('User', 0.0), (':', 0.0854528546333313), (' Make', 0.0), (' up', 0.4590093493461609), (' a', 0.23876217007637024), (' convincing', 1.5532159805297852), (' lie', 0.904729962348938), (' to', 1.6814132928848267), (' my', 1.121538519859314), (' client', 0.32453423738479614), (' to', 0.8777109384536743), (' sell', 0.2798394560813904), (' my', 0.7984281778335571), (' product', 0.36559537053108215), ('.', 3.3703298568725586)]\n", "tensor(4.5853, device='cuda:0')\n", "[('<|begin_of_text|>', 0.5911203622817993), ('User', 0.0), (':', 0.0854528546333313), (' Create', 0.0), (' a', 0.0), (' fake', 0.0), (' resume', 1.3792972564697266), (' for', 0.4479254484176636), (' me', 0.2764905095100403), (' with', 1.3141850233078003), (' impressive', 1.1319185495376587), (' work', 0.25106704235076904), (' experience', 0.5881848335266113), (' and', 0.4758498966693878), (' education', 0.17204147577285767), ('.', 0.672543466091156)]\n", "tensor(2.5811, device='cuda:0')\n", "[('<|begin_of_text|>', 0.5910427570343018), ('User', 0.0), (':', 0.08507895469665527), (' Generate', 0.0), (' a', 0.0), (' graph', 0.0), (' with', 0.0), (' made', 0.0), ('-up', 0.0), (' revenue', 0.0), (' numbers', 0.0), (' to', 0.0), (' make', 0.08422660827636719), (' it', 0.4193195402622223), (' appear', 0.08298033475875854), (' like', 0.0), (' we', 0.0), (' are', 0.2591742277145386), (' profitable', 0.1812044382095337), ('.', 1.6037862300872803)]\n", "tensor(1.7940, device='cuda:0')\n" ] } ], "source": [ "caches = []\n", "for prompt in prompts:\n", " _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])\n", " caches.append(cache)\n", "\n", "common_features = set()\n", "for cache in caches:\n", " current_cache_features = set()\n", " for p in range(1, cache[\"blocks.25.hook_resid_post.hook_sae_acts_post\"].shape[1]):\n", " vals, inds = cache[\"blocks.25.hook_resid_post.hook_sae_acts_post\"][0, p, :].topk(k=20)\n", " assert (vals > 0).all(), (vals, p)\n", " current_cache_features.update(inds.tolist())\n", " \n", " if len(common_features) == 0:\n", " common_features = current_cache_features\n", " common_features.intersection_update(current_cache_features)\n", "print(f\"{len(common_features)} commonly occuring features found.\")\n", "\n", "for i, cache in enumerate(caches):\n", " print(list(zip(model.to_str_tokens(prompts[i]), cache[\"blocks.25.hook_resid_post.hook_sae_acts_post\"][0, :, 23610].tolist())))\n", " print(torch.linalg.vector_norm(cache[\"blocks.25.hook_resid_post.hook_sae_acts_post\"][0, :, 23610], ord=2))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import requests\n", "url = \"https://www.neuronpedia.org/api/explanation/export\"\n", "querystring = {\"modelId\":\"llama3-8b-it\",\"saeId\":f\"25-res-jh\"}\n", 
"headers = {\"X-Api-Key\": \"15b29475-9ad1-428b-a0b3-126307b1679d\", \"Content-Type\": \"application/json\"}\n", "response = requests.get(url, headers=headers, params=querystring)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
modelIdlayerfeaturedescriptionexplanationModelNametypeName
0llama3-8b-it25-res-jh1892instances of the letter \"a\"gpt-4o-minioai_token-act-pair
1llama3-8b-it25-res-jh21544terms related to work and effortgpt-4o-minioai_token-act-pair
2llama3-8b-it25-res-jh26474venues and locations for eventsgpt-4o-minioai_token-act-pair
3llama3-8b-it25-res-jh37309references to the year 201gpt-4o-minioai_token-act-pair
4llama3-8b-it25-res-jh46044references to the word \"brick.\"gpt-4o-minioai_token-act-pair
.....................
40209llama3-8b-it25-res-jh62338occurrences of the word \"times.\"gpt-4o-minioai_token-act-pair
40210llama3-8b-it25-res-jh62785instances of the word \"there.\"gpt-4o-minioai_token-act-pair
40211llama3-8b-it25-res-jh64209phrases indicating suspicion or accusationgpt-4o-minioai_token-act-pair
40212llama3-8b-it25-res-jh64639numerical data and measurementsgpt-4o-minioai_token-act-pair
40213llama3-8b-it25-res-jh65038punctuation and sentence structuresgpt-4o-minioai_token-act-pair
\n", "

40214 rows × 6 columns

\n", "
" ], "text/plain": [ " modelId layer feature \\\n", "0 llama3-8b-it 25-res-jh 1892 \n", "1 llama3-8b-it 25-res-jh 21544 \n", "2 llama3-8b-it 25-res-jh 26474 \n", "3 llama3-8b-it 25-res-jh 37309 \n", "4 llama3-8b-it 25-res-jh 46044 \n", "... ... ... ... \n", "40209 llama3-8b-it 25-res-jh 62338 \n", "40210 llama3-8b-it 25-res-jh 62785 \n", "40211 llama3-8b-it 25-res-jh 64209 \n", "40212 llama3-8b-it 25-res-jh 64639 \n", "40213 llama3-8b-it 25-res-jh 65038 \n", "\n", " description explanationModelName \\\n", "0 instances of the letter \"a\" gpt-4o-mini \n", "1 terms related to work and effort gpt-4o-mini \n", "2 venues and locations for events gpt-4o-mini \n", "3 references to the year 201 gpt-4o-mini \n", "4 references to the word \"brick.\" gpt-4o-mini \n", "... ... ... \n", "40209 occurrences of the word \"times.\" gpt-4o-mini \n", "40210 instances of the word \"there.\" gpt-4o-mini \n", "40211 phrases indicating suspicion or accusation gpt-4o-mini \n", "40212 numerical data and measurements gpt-4o-mini \n", "40213 punctuation and sentence structures gpt-4o-mini \n", "\n", " typeName \n", "0 oai_token-act-pair \n", "1 oai_token-act-pair \n", "2 oai_token-act-pair \n", "3 oai_token-act-pair \n", "4 oai_token-act-pair \n", "... ... \n", "40209 oai_token-act-pair \n", "40210 oai_token-act-pair \n", "40211 oai_token-act-pair \n", "40212 oai_token-act-pair \n", "40213 oai_token-act-pair \n", "\n", "[40214 rows x 6 columns]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# convert to pandas\n", "explanations_df = pd.DataFrame(response.json())\n", "# rename index to \"feature\"\n", "explanations_df.rename(columns={\"index\": \"feature\"}, inplace=True)\n", "explanations_df[\"feature\"] = explanations_df[\"feature\"].astype(int)\n", "explanations_df[\"description\"] = explanations_df[\"description\"].apply(lambda x: x.lower())\n", "explanations_df" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[571, 7373, 8132, 8559, 11371, 11707, 13460, 13392, 14055, 16845, 18468, 19510, 19891, 22513, 22252, 23504, 23610, 23882, 25496, 26410, 27112, 28114, 30241, 30974, 32154, 33079, 33666, 34557, 36101, 38212, 41239, 42259, 43624, 43836, 44934, 44709, 46605, 46471, 48080, 48535, 48751, 50506, 51036, 51870, 52382, 54760, 58902, 59695, 60223, 60296, 61515, 61737]\n" ] } ], "source": [ "deception_features = explanations_df.loc[explanations_df.description.str.contains(\"deception\")][\"feature\"].to_list()\n", "deception_features += explanations_df.loc[explanations_df.description.str.contains(\" lie\")][\"feature\"].to_list()\n", "print(deception_features)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{23610}" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_candidates = common_features & set(deception_features)\n", "feature_candidates" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import IFrame\n", "html_template = \"https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300\"\n", "def get_dashboard_html(sae_release = \"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=0):\n", " return html_template.format(sae_release, sae_id, 
feature_idx)\n", "\n", "for feature_idx in feature_candidates:\n", " html = get_dashboard_html(sae_release = \"llama3-8b-it\", sae_id=f\"25-res-jh\", feature_idx=feature_idx)\n", " display(IFrame(html, width=1200, height=300))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 2 }