---
language:
- en
pipeline_tag: text-generation
tags:
- distillation
- model_hub_mixin
- pytorch_model_hub_mixin
- simple-stories
datasets:
- lennart-finke/SimpleStories
---

To load this model from within [https://github.com/danbraunai/simple_stories_train](https://github.com/danbraunai/simple_stories_train), run:

```python
from typing import Any

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from simple_stories_train.models.llama import Llama, LlamaConfig
from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT


class LlamaTransformer(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/danbraunai/simple_stories_train",
    language=["en"],
    pipeline_tag="text-generation",
):
    # Thin wrapper that gives the Llama model the Hub mixin's
    # save_pretrained / from_pretrained / push_to_hub methods.
    def __init__(self, **config: Any):
        super().__init__()
        self.llama = Llama(LlamaConfig(**config))

    def forward(self, x: torch.Tensor):
        return self.llama(x)


config = MODEL_CONFIGS_DICT["d12"]
model = LlamaTransformer(**config)

HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"
# from_pretrained is a classmethod: it returns a new instance with the Hub weights loaded.
model = model.from_pretrained(HUB_REPO_NAME)
```

- Library: https://github.com/danbraunai/simple_stories_train