# mini-mistral / init_model.py
# (Hugging Face file-page metadata, preserved as comments so the file parses:
#  uploader: Nbardy; commit: "rename" 64a3ce9; raw / history / blame; 946 Bytes)
import torch
from transformers import AutoConfig, AutoModelForCausalLM
# Build an untrained model skeleton from a local config file.
# `from_pretrained` accepts a path to a config JSON directly; `from_config`
# constructs the architecture with randomly initialized weights (no download).
config_path = "config.json"  # Adjust path as necessary
config = AutoConfig.from_pretrained(config_path)
model = AutoModelForCausalLM.from_config(config)
# Reinitialize weights with a standard deviation of 0.02 for a more controlled initialization
def reinitialize_weights(module):
    """Reinitialize a module's parameters in place (intended for `model.apply`).

    Weights are drawn from N(0, 0.02^2) — the GPT-2/Mistral-style init scale —
    except LayerNorm weights, which must stay at their default of 1.0.
    Biases are reset to 0.0.

    Args:
        module: any ``torch.nn.Module``; modules without ``weight``/``bias``
            parameters are left untouched.
    """
    # Guard on `is not None`: modules like GroupNorm(affine=False) expose a
    # `weight` attribute that is None, and normal_(None) would crash.
    if (
        hasattr(module, "weight")
        and module.weight is not None
        and not isinstance(module, torch.nn.LayerNorm)
    ):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if hasattr(module, "bias") and module.bias is not None:
        torch.nn.init.constant_(module.bias, 0.0)
# Re-init every submodule, then cast and persist the model.
model.apply(reinitialize_weights)

# Cast the model's parameters to bf16
model = model.to(
    dtype=torch.bfloat16
)  # Converts all floating point parameters to bfloat16

# Save the model in SafeTensors format. The correct keyword is
# `safe_serialization` — `save_in_safe_tensors_format` is not a
# `save_pretrained` parameter and would be silently ignored (or rejected),
# leaving the intent unenforced.
model.save_pretrained("./micro_mistral", safe_serialization=True)