In the original Llama 3 8B (base) model, the weights of several special tokens are zero, which may cause NaN gradients during fine-tuning. This version re-initializes the weights of the following special tokens to alleviate the problem:

- `<|eot_id|>`
- `<|start_header_id|>`
- `<|end_header_id|>`
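
The zero rows in the base checkpoint can be confirmed with a quick check like the one below (a minimal sketch; it assumes the `meta-llama/Meta-Llama-3-8B` weights are available):

```python
import torch
import transformers

# Load the original base model (assumed repo ID: meta-llama/Meta-Llama-3-8B).
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
)

for token in ["<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"]:
    token_id = tokenizer.convert_tokens_to_ids(token)
    row = model.model.embed_tokens.weight[token_id]
    # Each of these rows is expected to print 0.0, i.e. the token was never trained.
    print(token, token_id, "max |embedding| =", row.abs().max().item())
```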

We set the weights of these tokens in both `embed_tokens` and `lm_head` to the mean of the weights of all other (regular) tokens.

Code for making this model:

```python
import argparse

import transformers
import torch


def init_eot_embedding_llama3(
    model_path,
    output_dir,
    special_tokens=("<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"),
    mean_cutoff=128000,
    dtype=torch.bfloat16,
):
    """Re-initialize the rows for `special_tokens` in `embed_tokens` and `lm_head`
    to the mean of the first `mean_cutoff` (regular) token rows, then save the result."""
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=dtype)

    # Rows below `mean_cutoff` are the regular BPE tokens; IDs at or above it are special tokens.
    assert model.model.embed_tokens.weight.shape[0] >= mean_cutoff
    assert model.lm_head.weight.shape[0] >= mean_cutoff

    with torch.no_grad():
        for token in special_tokens:
            token_id = tokenizer.convert_tokens_to_ids(token)

            print(f"Token {token} ID {token_id}")

            # Overwrite both the input embedding row and the lm_head row with the mean of
            # all regular-token rows, computed in float32 and cast back to the model dtype.
            model.model.embed_tokens.weight[token_id] = torch.mean(model.model.embed_tokens.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)
            model.lm_head.weight[token_id]            = torch.mean(model.lm_head.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)

    # Save
    tokenizer.save_pretrained(output_dir)
    model.save_pretrained(output_dir)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        required=True,
        help="Local path to the base model, or Hugging Face Hub repo ID",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory to write the resulting model and tokenizer",
    )

    init_eot_embedding_llama3(**vars(parser.parse_args()))


if __name__ == "__main__":
    main()
```
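
As a usage example (the script filename and output path below are placeholders): `python fix_special_embedding.py --model-path meta-llama/Meta-Llama-3-8B --output-dir ./llama-3-8b-fixed-special-embedding`. The output directory then contains the patched weights and tokenizer, which load with `AutoModelForCausalLM.from_pretrained` like any other checkpoint.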