import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Load the configuration and initialize the model with random weights
config_path = "config.json"  # Adjust path as necessary
config = AutoConfig.from_pretrained(config_path)
model = AutoModelForCausalLM.from_config(config)


# Reinitialize weights with a standard deviation of 0.02 for a more controlled initialization
def reinitialize_weights(module):
    # Skip normalization layers (nn.LayerNorm as well as RMSNorm variants such as
    # MistralRMSNorm), whose scale weights should keep their default value of 1.0
    if isinstance(module, torch.nn.LayerNorm) or "norm" in type(module).__name__.lower():
        return
    if getattr(module, "weight", None) is not None:
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if getattr(module, "bias", None) is not None:
        torch.nn.init.constant_(module.bias, 0.0)


model.apply(reinitialize_weights)

# Cast the model's parameters to bf16; .to() converts all floating-point parameters
model = model.to(dtype=torch.bfloat16)

# Save the model in the SafeTensors format
model.save_pretrained("./micro_mistral", safe_serialization=True)
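
# Optional sanity check, a minimal sketch assuming the checkpoint above was
# written to "./micro_mistral": reload it and confirm the weights came back
# in bfloat16 with the expected parameter count.
reloaded = AutoModelForCausalLM.from_pretrained("./micro_mistral", torch_dtype=torch.bfloat16)
print(f"Parameters: {reloaded.num_parameters():,}")   # total parameter count
print(f"Dtype: {next(reloaded.parameters()).dtype}")  # expected: torch.bfloat16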