{ "cells": [ { "cell_type": "code", "execution_count": 5, "metadata": { "id": "hS2zWviCGv-j" }, "outputs": [], "source": [ "model_name_or_path = \"mistralai/Mixtral-8x7B-Instruct-v0.1\"#@param {type:\"string\"}\n", "experts_extract_bit = \"10101010\" #@param {type:\"string\"}\n", "num_experts_per_tok = 2 #@param {type:\"integer\"}\n", "\n", "temp_dir = \"/content/drive/MyDrive/tf_models\" #@param {type:\"string\"}\n", "model_name = model_name_or_path.split(\"/\")[-1]\n", "target_dir = f\"{temp_dir}/{model_name}\"\n", "save_dir = \"/content/drive/MyDrive/tf_models/mx4x7b_x3\" #@param {type:\"string\"}\n", "\n", "\n", "experts_indexies = [i for i, bit in enumerate(experts_extract_bit) if bit == '1']\n", "# print( experts_indexies )\n", "\n", "if len(experts_extract_bit) != 8:\n", " raise ValueError(\"experts_extract_bit length must be 8\")\n", "" ] }, { "cell_type": "code", "source": [ "!pip install git+https://github.com/huggingface/transformers --upgrade\n", "!pip install torch accelerate bitsandbytes flash_attn sentencepiece protobuf" ], "metadata": { "id": "gJhESaUCul4-" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0kn65H6HvwXB", "outputId": "0fdbc596-8288-4a1e-b0a4-ee4904b0a32a" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/drive\n" ] } ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "WwnZPGHATsqv" }, "outputs": [], "source": [ "%cd {temp_dir}\n", "save_model_dir = model_name.split('/')[-1]\n", "!mkdir -p {save_model_dir}\n", "\n", "!wget https://huggingface.co/{model_name_or_path}/resolve/main/config.json -O {save_model_dir}/config.json\n", "!wget https://huggingface.co/{model_name_or_path}/resolve/main/model.safetensors.index.json -O {save_model_dir}/model.safetensors.index.json\n", "!wget https://huggingface.co/{model_name_or_path}/resolve/main/generation_config.json -O {save_model_dir}/generation_config.json\n", "\n", "for i in range(1,20):\n", " file_count_str = str(i).zfill(5)\n", " !wget https://huggingface.co/{model_name_or_path}/resolve/main/model-{file_count_str}-of-00019.safetensors?download=true -O {save_model_dir}/model-{file_count_str}-of-00019.safetensors" ] }, { "cell_type": "code", "source": [ "def download_tokenizer_model(save_tokenizer_dir):\n", " !wget https://huggingface.co/{model_name_or_path}/resolve/main/tokenizer.model -O {save_tokenizer_dir}/tokenizer.model\n" ], "metadata": { "id": "SDnbZoAMSEB8" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpHX5HoDPCEM" }, "outputs": [], "source": [ "%cd {temp_dir}\n", "\n", "import json\n", "import re\n", "import torch\n", "from safetensors import safe_open\n", "from safetensors.torch import save_file\n", "\n", "# model-00001-of-00019.safetensors\n", "# model.safetensors.index.json\n", "\n", "# save tokenizer\n", "from transformers import AutoTokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)\n", "tokenizer.save_pretrained(save_dir)\n", "\n", "# save config\n", "config_path = f\"{target_dir}/config.json\"\n", "config = None\n", "with open(config_path, \"r\") as f:\n", " config = json.load(f)\n", " config[\"num_experts_per_tok\"] = num_experts_per_tok if len(experts_indexies) >= num_experts_per_tok else 1\n", " 
config[\"num_local_experts\"] = len(experts_indexies)\n", "\n", "# save config\n", "with open(f\"{save_dir}/config.json\", \"w\") as f:\n", " json.dump(config, f, indent=2)\n", "\n", "\n", "# weight\n", "weight_map = {}\n", "first_weights = [\"lm_head.weight\", \"model.embed_tokens.weight\", \"model.norm.weight\"]\n", "\n", "# load weight map\n", "bin_index_path = f\"{target_dir}/model.safetensors.index.json\"\n", "with open(bin_index_path, \"r\") as f:\n", " weight_map = json.load(f)[\"weight_map\"]\n", "\n", "def tensor_load(file_name, map_location=None):\n", " tensors = {}\n", " with safe_open(file_name, framework=\"pt\") as f:\n", " for k in f.keys():\n", " tensors[k] = f.get_tensor(k)\n", " return tensors\n", "\n", "def get_weight_byte_size(weight):\n", "\n", " if isinstance(weight, torch.Tensor):\n", " weight_byte_size = weight.nelement() * weight.element_size()\n", " else:\n", " weight_byte_size = sum(p.nelement() * p.element_size() for p in weight.parameters())\n", "\n", " return weight_byte_size\n", "\n", "# load weight map\n", "layers = {}\n", "for key in weight_map.keys():\n", " if key in first_weights:\n", " continue\n", "\n", " # keyが\"model.layers.[0-9]+.\"にmatchする場合はlayers_listに追加する\n", " layer_str = re.match(r\"model\\.layers\\.[0-9]+\\.\", key)[0]\n", " if layer_str:\n", " layer_no = re.findall(r\"\\d+\",layer_str)\n", " layer_no = layer_no[0]\n", " if layer_no not in layers.keys():\n", " layers[layer_no] = []\n", "\n", " layers[layer_no].append({ \"key\":key, \"file_name\":weight_map[key] })\n", "\n", "# new weight_map index\n", "new_weight_map = {\n", " \"metadata\": {\n", " \"total_size\": 0\n", " },\n", " \"weight_map\": {\n", " }\n", "}\n", "\n", "# load tensors\n", "total_size = 0\n", "tensor_weights = {}\n", "tensors = {}\n", "current_file_name = \"\"\n", "\n", "file_count = 0\n", "file_count_str = str(file_count).zfill(5)\n", "\n", "for key in first_weights:\n", " file_name = weight_map[key]\n", " if current_file_name != file_name:\n", "\n", " # load safetensor\n", " tensors = tensor_load(f\"{target_dir}/{file_name}\", map_location=\"cpu\")\n", " current_file_name = file_name\n", "\n", " tensor_weights[key] = tensors[key]\n", " new_weight_map[\"weight_map\"][key] = f\"model-{file_count_str}.safetensors\"\n", "\n", " # add weight size\n", " total_size += get_weight_byte_size(tensor_weights[key])\n", "\n", "# save tensor\n", "save_file(tensor_weights, f\"{save_dir}/model-{file_count_str}.safetensors\", metadata={\"format\":\"pt\"})\n", "file_count += 1\n", "\n", "layer_keys = sorted([ int(k) for k in layers.keys()])\n", "\n", "for layer_no in layer_keys:\n", " print(\"starting layer:\",layer_no)\n", " file_count_str = str(file_count).zfill(5)\n", " tensor_weights = {}\n", "\n", " stock_expert_weights = {}\n", "\n", " current_file_name = \"\"\n", " for info in layers[str(layer_no)]:\n", " file_name = info[\"file_name\"]\n", " if current_file_name != file_name:\n", " print(\"Loading Tensors \", file_name)\n", " tensors = tensor_load(f\"{target_dir}/{file_name}\", map_location=\"cpu\")\n", " current_file_name = file_name\n", "\n", " layer_key = info[\"key\"]\n", " layer_weights = tensors[layer_key]\n", "\n", " if 'experts' in layer_key:\n", "\n", " lk = re.findall(r\"block_sparse_moe[.]experts[.][0-9]+.w\", layer_key)[0]\n", " exp_index = int( re.findall(r\"\\d+\",lk)[0] )\n", "\n", " # select target experts\n", " if exp_index in experts_indexies:\n", " new_layer_key = re.sub(r\"block_sparse_moe\\.experts\\.\\d+\\.w\", 
f\"block_sparse_moe.experts.{experts_indexies.index(exp_index)}.w\", layer_key)\n", "\n", " tensor_weights[new_layer_key] = layer_weights\n", "\n", " # add weight size\n", " total_size += get_weight_byte_size(tensor_weights[new_layer_key])\n", "\n", " new_weight_map[\"weight_map\"][new_layer_key] = f\"model-{file_count_str}.safetensors\"\n", " print(\"new experts\", new_layer_key, tensor_weights[new_layer_key].shape, \"from\", layer_key)\n", "\n", " elif 'gate' in layer_key:\n", " print(\"slice gate \", experts_indexies, layer_weights.shape, f\"-> ({len(experts_indexies)}, 4096)\", layer_key)\n", "\n", " # slice gate\n", " tensor_weights[layer_key] = layer_weights[experts_indexies]\n", "\n", " # add weight size\n", " total_size += get_weight_byte_size(tensor_weights[layer_key])\n", "\n", " new_weight_map[\"weight_map\"][layer_key] = f\"model-{file_count_str}.safetensors\"\n", " print(layer_key, tensor_weights[layer_key].shape)\n", "\n", " else:\n", " tensor_weights[layer_key] = layer_weights\n", "\n", " # add weight size\n", " total_size += get_weight_byte_size(tensor_weights[layer_key])\n", "\n", " new_weight_map[\"weight_map\"][layer_key] = f\"model-{file_count_str}.safetensors\"\n", " print(layer_key, tensor_weights[layer_key].shape)\n", "\n", " # save tensor\n", " save_file(tensor_weights, f\"{save_dir}/model-{file_count_str}.safetensors\", metadata={\"format\":\"pt\"})\n", " print(\"Save Tensors \", f\"{save_dir}/model-{file_count_str}.safetensors\")\n", " file_count += 1\n", "\n", "# save new_weight_map\n", "new_weight_map[\"metadata\"][\"total_size\"] = total_size\n", "with open(f\"{save_dir}/model.safetensors.index.json\", \"w\") as f:\n", " json.dump(new_weight_map, f, indent=2)\n", "\n", "# download tokenizer.model\n", "download_tokenizer_model(save_dir)\n", "\n", "print(\"Done.\")\n" ] }, { "cell_type": "code", "source": [ "from transformers import AutoTokenizer, AutoModelForCausalLM, MixtralForCausalLM\n", "import torch\n", "\n", "model_name_or_path = save_dir\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n", "model = MixtralForCausalLM.from_pretrained(model_name_or_path, load_in_8bit=True)\n", "\n", "text = \"[INST] What was John Holt's vision on education? [/INST] \"\n", "# text = \"[INST] What is the best anime? [/INST] \"\n", "inputs = tokenizer(\" \" + text, return_tensors=\"pt\")\n", "\n", "outputs = model.generate(**inputs, max_new_tokens=128)\n", "print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 138, "referenced_widgets": [ "e9cfddc47787435993179dcbfb1fb89c", "19f57b50941e4f929d12aadc944ce01b", "b27b7fcb570c4f95bf4adeefba1aa146", "05de7d1f41054d28b69479146f8cd557", "41b69b2ea8204f69b87d7ecc8c83977b", "3436fe87cab3486a829cd80f0805a099", "d7325fddbaea46639294e0831aa4df18", "420b8967a123405bb6349b2bee24a4d3", "c480548c4a3f42628cde8b011ababff0", "a28af63811954d6ebfbbb1d4ef437627", "f07257d9b18a4f69b2efea9550d12014" ] }, "id": "3dFbRvPe8yyK", "outputId": "5b7f3a08-68f7-4e16-ecc4-cf36d9756f66" }, "execution_count": 4, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00