NotShrirang committed on
Commit f4e648b · 1 Parent(s): 1510533

feat: add application file

.gitignore ADDED
@@ -0,0 +1,162 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ weights/
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
README.md CHANGED
@@ -1,14 +1,191 @@
1
- ---
2
- title: QuillGPT
3
- emoji: 📉
4
- colorFrom: yellow
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.40.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Implementation of the GPT decoder block in PyTorch
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ![QuillGPT-cropped-removebg-preview](https://github.com/NotShrirang/QuillGPT/assets/85283622/2e63d8ce-24f8-4bf0-835a-0c621f1d7400)
2
+
3
+ # QuillGPT
4
+
5
+ ![GitHub stars](https://img.shields.io/github/stars/NotShrirang/GPT-From-Scratch?style=social)
6
+ ![GitHub forks](https://img.shields.io/github/forks/NotShrirang/GPT-From-Scratch?style=social)
7
+ ![GitHub commits](https://img.shields.io/github/commit-activity/t/NotShrirang/QuillGPT)
8
+ ![GitHub issues](https://img.shields.io/github/issues/NotShrirang/GPT-From-Scratch)
9
+ ![GitHub pull requests](https://img.shields.io/github/issues-pr/NotShrirang/GPT-From-Scratch)
10
+ ![GitHub](https://img.shields.io/github/license/NotShrirang/GPT-From-Scratch)
11
+ ![GitHub last commit](https://img.shields.io/github/last-commit/NotShrirang/GPT-From-Scratch)
12
+ ![GitHub repo size](https://img.shields.io/github/repo-size/NotShrirang/GPT-From-Scratch)
13
+ ![Streamlit Playground](https://img.shields.io/badge/Streamlit%20App-red?style=flat-rounded-square&logo=streamlit&labelColor=white)
14
+ ![Docker Container](https://img.shields.io/badge/docker-blue?style=flat-rounded-square&logo=docker&labelColor=white)
15
+
16
+ QuillGPT is a PyTorch implementation of the GPT decoder block based on the architecture from the [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper by Vaswani et al. The repository also contains two pre-trained models (Shakespearean GPT and Harpoon GPT) along with their trained weights. For ease of experimentation and deployment, it provides a Streamlit Playground for interactive exploration of these models and a Dockerized FastAPI microservice for scalable deployment. You'll also find Python scripts for training new GPT models and running inference on them, along with notebooks showcasing the trained models. A simple tokenizer handles text encoding and decoding. Explore QuillGPT to put these tools to work in your natural language processing projects!
17
+
18
+ ## Table of Contents
19
+
20
+ - [Models](#models)
21
+ - [Getting Started](#getting-started)
22
+ - [Installation](#installation)
23
+ - [Streamlit Playground](#streamlit-playground)
24
+ - [FastAPI Microservice](#for-running-fastapi-microservice)
25
+ - [Running Docker Container](#for-using-containerized-version)
26
+ - [Usage](#usage)
27
+ - [Training the GPT Model](#training-the-gpt-model)
28
+ - [Using the Trained Model for Inference](#for-inference)
29
+ - [Explanation](#explanation)
30
+ - [Decoder Block](#the-decoder-block)
31
+ - [Input Embeddings](#input-embeddings)
32
+ - [Positional Embeddings](#positional-embeddings)
33
+ - [Self-Attention](#self-attention)
34
+ - [License](#license)
35
+ - [Contributing](#contributing)
36
+ - [Support](#support)
37
+
38
+ ## <div align="center">Models</div>
39
+
40
+ This repository includes two pre-trained models along with their trained weights.
41
+
42
+ | Feature | Shakespearean GPT | Harpoon GPT |
43
+ | ------------------------------ | --------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
44
+ | **Parameters** | 10.7 M | 226 M |
45
+ | **Weights** | [Weights](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/weights/GPT_model_char.pt) | [Weights](https://www.dropbox.com/scl/fi/vi5z3s17otn0jf7sr40po/Harpoon_Corpus_GPT_model.pt?rlkey=r7oppeslusv736fzmi908le95&st=wak0uf2t&dl=0) |
46
+ | **Model Config** | [Config](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/config/shakespearean_config.json) | [Config](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/config/harpoon_config.json) |
47
+ | **Training Data** | Text from Shakespearean plays ([input.txt](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/data/input.txt)) | Random text from books ([corpus.txt](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/data/corpus.txt)) |
48
+ | **Embedding Type** | Character embeddings | Character embeddings |
49
+ | **Training Notebook** | [Notebook](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/notebooks/GPT_From_Scratch_CharEmbeddings.ipynb) | [Notebook](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/notebooks/GPT_From_Scratch_with_1024_char_embd.ipynb) |
50
+ | **Hardware** | NVIDIA T4 | NVIDIA A100 |
51
+ | **Training & Validation Loss** | ![loss](https://github.com/user-attachments/assets/df89c1f6-d89a-4a3a-8340-edcf7416878c) | ![loss](https://github.com/user-attachments/assets/76c5e0d1-a53c-4d0d-ac8f-5529ec3a5008) |
52
+
53
+ ## Getting Started:
54
+
55
+ ### Installation:
56
+
57
+ To run the training and inference scripts, follow these steps:
58
+
59
+ 1. Clone the repository:
60
+
61
+ ```sh
62
+ git clone https://github.com/NotShrirang/GPT-From-Scratch.git
63
+ cd GPT-From-Scratch
64
+ ```
65
+
66
+ 2. Install the required packages:
67
+
68
+ ```sh
69
+ pip install -r requirements.txt
70
+ ```
71
+
72
+ Make sure you download the weights for Harpoon GPT from [here](https://www.dropbox.com/scl/fi/vi5z3s17otn0jf7sr40po/Harpoon_Corpus_GPT_model.pt?rlkey=r7oppeslusv736fzmi908le95&st=wak0uf2t&dl=0) before proceeding!
73
+
74
+ ### Streamlit Playground:
75
+
76
+ The playground is hosted on Streamlit Community Cloud. You can visit it [here](https://quillgpt.streamlit.app/).
77
+
78
+ [![Streamlit Demo](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/fa888670-2c44-4f97-a07d-c58473d847d0)](https://quillgpt.streamlit.app/)
79
+
+ Alternatively, run the playground locally:
+
80
+ ```sh
81
+ streamlit run app.py
82
+ ```
83
+
84
+ ### For running FastAPI Microservice:
85
+
+ To start the FastAPI microservice, run:
+
86
+ ```sh
87
+ python main.py
88
+ ```
89
+
90
+ ### For using Containerized Version:
91
+
92
+ #### Build and run the Docker container:
93
+
94
+ ```sh
95
+ ./run.sh start-dev
96
+ ```
97
+
98
+ #### To stop the Docker Container, run the following command:
99
+
100
+ ```sh
101
+ ./run.sh stop-dev
102
+ ```
103
+
104
+ ## Usage
105
+
106
+ ### Training the GPT Model:
107
+
108
+ To train the GPT model, follow these steps:
109
+
110
+ 1. Prepare the data. Put all of the training text into a single `.txt` file and save it.
+ 2. Write the configuration for the transformer and save it as a JSON file.
+ <br>For example:
+ ```json
+ {
+     "data_path": "data/corpus.txt",
+     "vocab_size": 135,
+     "batch_size": 32,
+     "block_size": 256,
+     "max_iters": 3000,
+     "eval_interval": 300,
+     "learning_rate": 3e-5,
+     "eval_iters": 50,
+     "n_embd": 1024,
+     "n_head": 12,
+     "n_layer": 18,
+     "dropout": 0.3
+ }
+ ```
129
+
130
+ 3. Train the model using the script `scripts/train_gpt.py`:
131
+
132
+ ```bash
133
+ python scripts/train_gpt.py \
134
+ --config_path config/config.json \
135
+ --data_path data/corpus.txt \
136
+ --output_dir trained_models
137
+ ```
138
+
139
+ (You can change `config_path`, `data_path`, and `output_dir` as needed.)
140
+
141
+ 4. The trained model will be saved in the `output_dir` specified in the command.
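+
+ Under the hood, training builds on the helpers in `core/utils/gptutils.py`. Below is a condensed sketch of such a loop; it is an illustration based on those helpers, not necessarily identical to `scripts/train_gpt.py`:
+
+ ```python
+ import torch
+ from core.models.gpt import GPTLanguageModel
+ from core.utils.gptutils import hyperparameters, load_data, get_batch, estimate_loss
+
+ config_path = 'config/config.json'
+ (batch_size, block_size, max_iters, eval_interval, learning_rate,
+  device, eval_iters, n_embd, n_head, n_layer, dropout) = hyperparameters(config_path=config_path)
+ train_data, val_data, vocab_size, encode, decode = load_data('data/corpus.txt')
+
+ model = GPTLanguageModel(vocab_size, n_embd, block_size, n_head,
+                          n_layer, dropout, device).to(device)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+ for step in range(max_iters):
+     if step % eval_interval == 0:
+         losses = estimate_loss(model, get_batch, eval_iters, train_data,
+                                val_data, device, block_size, batch_size)
+         print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+     xb, yb = get_batch('train', train_data, val_data, device, block_size, batch_size)
+     logits, loss = model(xb, yb)          # forward pass returns (logits, loss)
+     optimizer.zero_grad(set_to_none=True)
+     loss.backward()
+     optimizer.step()
+
+ torch.save(model.state_dict(), 'trained_models/GPT_model.pt')
+ ```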
142
+
143
+ ### For Inference:
144
+
145
+ After training, you can use the trained GPT model for text generation. For example:
146
+
147
+ ```bash
148
+ python scripts/inference_gpt.py \
149
+ --config_path config/shakespearean_config.json \
150
+ --weights_path weights/GPT_model_char.pt \
151
+ --max_length 500 \
152
+ --prompt "Once upon a time"
153
+ ```
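+
+ For programmatic inference, here is a minimal sketch mirroring what `app.py` does. It assumes the repository layout above and that the Shakespearean weights sit at `weights/GPT_model_char.pt`:
+
+ ```python
+ import torch
+ from core.models.gpt import GPTLanguageModel
+ from core.tokenizers.tokenizer import Tokenizer
+ from core.utils.gptutils import hyperparameters
+
+ config_path = 'config/shakespearean_config.json'
+ tokenizer = Tokenizer().from_pretrained(config_path)
+ (batch_size, block_size, max_iters, eval_interval, learning_rate,
+  device, eval_iters, n_embd, n_head, n_layer, dropout) = hyperparameters(config_path=config_path)
+
+ model = GPTLanguageModel(tokenizer.vocab_size, n_embd, block_size,
+                          n_head, n_layer, dropout, device).to(device)
+ model.load_state_dict(torch.load('weights/GPT_model_char.pt', map_location=device))
+ model.eval()
+
+ idx = torch.tensor([tokenizer.encode('ROMEO:')], dtype=torch.long, device=device)
+ out = idx
+ for out in model.generate(idx, max_new_tokens=100, temperature=0.7):
+     pass                                  # generate() yields the growing sequence
+ print(tokenizer.decode(out[0].tolist()))
+ ```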
154
+
155
+ ## <div align="center">Explanation</div>
156
+
157
+ ### The Decoder Block:
158
+
159
+ <div align="center"><img src="https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/397049a3-10cc-49b5-8696-f19806b2668e" width=350 alt="Decoder Architecture"/></div>
160
+
161
+ The decoder block is a crucial component of the GPT (Generative Pre-trained Transformer) model; it is where GPT actually generates text. It leverages the self-attention mechanism to process input sequences and produce coherent outputs. Each decoder block consists of multiple layers, including self-attention layers, feed-forward neural networks, and layer normalization. The self-attention layers allow the model to weigh the importance of different words in a sequence, capturing context and dependencies regardless of their positions. This enables the GPT model to generate contextually relevant text.
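+
+ In this repository the decoder block is implemented as `Block` in `core/layers/layers.py`. A toy forward pass through one block (dimensions taken from the Shakespearean config):
+
+ ```python
+ import torch
+ from core.layers import Block
+
+ block = Block(n_embd=384, n_head=6, block_size=256, dropout=0.3)
+ x = torch.randn(4, 256, 384)   # (batch, time, channels)
+ y = block(x)                   # residual self-attention + feed-forward
+ print(y.shape)                 # torch.Size([4, 256, 384]); the block preserves shape
+ ```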
162
+
163
+ ### Input Embeddings:
164
+
165
+ <div align="center">![vector embeddings](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/29b4c375-c9f0-47b9-9d34-2a21dfdf0be8)</div>
166
+
167
+ Input embeddings play a crucial role in transformer-based models like GPT by transforming input tokens into meaningful numerical representations. These embeddings serve as the initial input for the model, capturing semantic information about the words in the sequence. The process involves mapping each token in the input sequence to a high-dimensional vector space, where similar tokens are positioned closer together. This enables the model to understand the relationships between different words and effectively learn from the input data. The input embeddings are then fed into the subsequent layers of the model for further processing.
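+
+ As a small illustration of this lookup (mirroring the token-embedding table in `core/models/gpt.py`; the ids spell "Hello" under the Shakespearean character vocabulary):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ vocab_size, n_embd = 65, 384                  # values from the Shakespearean config
+ token_embedding_table = nn.Embedding(vocab_size, n_embd)
+ idx = torch.tensor([[20, 43, 50, 50, 53]])    # character ids for "Hello"
+ tok_emb = token_embedding_table(idx)          # (1, 5, 384): one learned vector per token
+ ```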
168
+
169
+ ### Positional Embeddings:
170
+
171
+ ![positional_encoding](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/90293fb0-8f20-4dc0-adba-8c31a54ef4f4)
172
+
173
+ In addition to input embeddings, positional embeddings are another vital component of transformer architectures such as GPT. Since transformers lack inherent information about the order of tokens in a sequence, positional embeddings are introduced to provide the model with positional information. These embeddings encode the position of each token within the sequence, allowing the model to distinguish between tokens based on their positions. By incorporating positional embeddings, transformers like GPT can effectively capture the sequential nature of data and generate coherent outputs that maintain the correct order of words in the generated text.
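+
+ Note that the GPT implementation in this repository uses a learned positional-embedding table rather than fixed sinusoids (see `core/models/gpt.py`). A sketch of how it is combined with the token embeddings:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ block_size, n_embd = 256, 384
+ position_embedding_table = nn.Embedding(block_size, n_embd)
+
+ T = 5                                               # current sequence length
+ tok_emb = torch.randn(1, T, n_embd)                 # stand-in for token embeddings
+ pos_emb = position_embedding_table(torch.arange(T)) # (T, n_embd): one vector per position
+ x = tok_emb + pos_emb                               # broadcast add -> (1, T, n_embd)
+ ```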
174
+
175
+ ### Self-Attention:
176
+
177
+ ![self attention](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/a6d785e4-ab00-4da0-a072-791f680d2bb8)
178
+
179
+ Self-attention, a fundamental mechanism in transformer-based models like GPT, operates by assigning importance scores to different words in a sequence. This process involves three key steps: calculating attention scores, applying softmax to obtain attention weights, and finally combining these weights with the input embeddings to generate contextually informed representations. At its core, self-attention allows the model to focus more on relevant words while de-emphasizing less important ones, facilitating effective learning of contextual dependencies within the input data. This mechanism is pivotal in capturing long-range dependencies and contextual nuances, enabling transformer models to generate long sequences of text.
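+
+ A minimal single-head sketch of those three steps, matching `Head.forward` in `core/layers/layers.py` (sizes are toy values):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ B, T, head_size = 1, 8, 16
+ q = torch.randn(B, T, head_size)                   # queries
+ k = torch.randn(B, T, head_size)                   # keys
+ v = torch.randn(B, T, head_size)                   # values
+
+ wei = q @ k.transpose(-2, -1) * head_size ** -0.5  # 1) scaled attention scores
+ tril = torch.tril(torch.ones(T, T))
+ wei = wei.masked_fill(tril == 0, float('-inf'))    # causal mask: no attending ahead
+ wei = F.softmax(wei, dim=-1)                       # 2) attention weights
+ out = wei @ v                                      # 3) weighted combination of values
+ ```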
180
+
181
+ ## License
182
+
183
+ MIT © [Shrirang Mahajan](https://github.com/NotShrirang)
184
+
185
+ ## Contributing
186
+
187
+ Feel free to submit pull requests, create issues, or spread the word!
188
+
189
+ ## Support
190
+
191
+ Support me by simply starring this repository! ⭐
app.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import streamlit as st
3
+ from colorama import Fore
4
+ from core.models.gpt import GPTLanguageModel
5
+ from core.tokenizers.tokenizer import Tokenizer
6
+ from core.utils.gptutils import hyperparameters, load_data
7
+
8
+ st.set_page_config(layout='wide',
9
+ page_title='QuillGPT',
10
+ page_icon='🪶',
11
+ initial_sidebar_state='expanded'
12
+ )
13
+
14
+ def decode_text(input, model: GPTLanguageModel, max_tokens, temperature):
15
+ for idx in model.generate(idx=input, max_new_tokens=max_tokens, max_seq_length=50, temperature=temperature):
16
+ text = tokenizer.decode(idx[0].tolist())[-1]
17
+ yield text
18
+
19
+ models = {
20
+ "Shakespearean GPT": './weights/GPT_model_char.pt',
21
+ "GPT": './weights/Harpoon_Corpus_GPT_model_word2.pt',
22
+ }
23
+
24
+ st.sidebar.header('QuillGPT')
25
+
26
+ st.sidebar.write("This app generates text using a GPT model trained on either the Harpoon corpus or Shakespearean plays.")
27
+
28
+ # Select one of the two models
29
+ model_name = st.sidebar.selectbox('Select a model:', list(models.keys()))
30
+ if model_name == "GPT":
31
+ st.title('GPT From Scratch')
32
+ st.write("This model was trained on the Harpoon corpus.")
33
+ else:
34
+ st.title('Shakespearean GPT')
35
+ st.write("This model was trained on Shakespearean plays.")
36
+
37
+ path = models[model_name]
38
+
39
+ if model_name == "GPT":
40
+ config_path = './config/harpoon_config.json'
41
+ data_path = './data/corpus.txt'
42
+ name = "Harpoon GPT"
43
+ tokenizer: Tokenizer = Tokenizer()
44
+ tokenizer.from_pretrained(config_path)
45
+ vocab_size = tokenizer.vocab_size
46
+ (batch_size, block_size, max_iters, eval_interval, learning_rate, device,
47
+ eval_iters, n_embd, n_head, n_layer, dropout) = hyperparameters(config_path=config_path)
48
+
49
+ elif model_name == "Shakespearean GPT":
50
+ config_path = './config/shakespearean_config.json'
51
+ data_path = './data/input.txt'
52
+ name = "Shakespearean GPT"
53
+ tokenizer: Tokenizer = Tokenizer()
54
+ tokenizer.from_pretrained(config_path)
55
+ vocab_size = tokenizer.vocab_size
56
+ (batch_size, block_size, max_iters, eval_interval, learning_rate, device,
57
+ eval_iters, n_embd, n_head, n_layer, dropout) = hyperparameters(config_path=config_path)
58
+
59
+
60
+ if model_name == "GPT":
61
+ input_text = st.text_area(
62
+ 'Enter a prompt:', 'And then Ted said, "'
63
+ )
64
+ else:
65
+ input_text = st.text_area(
66
+ 'Enter a prompt:', 'Write a scene about ROMEO arguing with JULIET. \nROMEO:'
67
+ )
68
+
69
+ temperature = st.sidebar.slider('Temperature:', 0.1, 1.0, 0.5, 0.1)
70
+ max_tokens = st.sidebar.slider('Max Tokens:', 250, 1000, 500, 50)
71
+
72
+ @st.cache_resource
73
+ def load_model(path):
74
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
75
+
76
+ try:
77
+ model = GPTLanguageModel(
78
+ vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name=name
79
+ ).to(device)
80
+ state_dict = torch.load(
81
+ path, map_location=device)
82
+
83
+ model.load_state_dict(state_dict)
84
+ return model, device
85
+ except FileNotFoundError as e:
86
+ st.error(f"Don't forget to download the model weights from the link in the README.md file.")
87
+ return None, None
88
+
89
+
90
+ model, device = load_model(path)
91
+
92
+
93
+ if model:
94
+ if st.button('Generate Text'):
95
+ prompt = input_text
96
+ st.subheader(model.name)
97
+ input = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
98
+ generated_text = []
99
+ st.write(f":green[{prompt}]")
100
+ st.write_stream(decode_text(input, model, max_tokens, temperature))
config/config.json ADDED
@@ -0,0 +1,287 @@
1
+ {
2
+ "vocab_size": 135,
3
+ "batch_size": 32,
4
+ "block_size": 256,
5
+ "max_iters": 3000,
6
+ "eval_interval": 300,
7
+ "learning_rate": 3e-5,
8
+ "eval_iters": 50,
9
+ "n_embd": 1024,
10
+ "n_head": 8,
11
+ "n_layer": 8,
12
+ "dropout": 0.3,
13
+ "encode": {
14
+ "\n": 0,
15
+ " ": 1,
16
+ "!": 2,
17
+ "\"": 3,
18
+ "#": 4,
19
+ "$": 5,
20
+ "%": 6,
21
+ "&": 7,
22
+ "'": 8,
23
+ "(": 9,
24
+ ")": 10,
25
+ "*": 11,
26
+ "+": 12,
27
+ ",": 13,
28
+ "-": 14,
29
+ ".": 15,
30
+ "/": 16,
31
+ "0": 17,
32
+ "1": 18,
33
+ "2": 19,
34
+ "3": 20,
35
+ "4": 21,
36
+ "5": 22,
37
+ "6": 23,
38
+ "7": 24,
39
+ "8": 25,
40
+ "9": 26,
41
+ ":": 27,
42
+ ";": 28,
43
+ "<": 29,
44
+ "=": 30,
45
+ ">": 31,
46
+ "?": 32,
47
+ "@": 33,
48
+ "A": 34,
49
+ "B": 35,
50
+ "C": 36,
51
+ "D": 37,
52
+ "E": 38,
53
+ "F": 39,
54
+ "G": 40,
55
+ "H": 41,
56
+ "I": 42,
57
+ "J": 43,
58
+ "K": 44,
59
+ "L": 45,
60
+ "M": 46,
61
+ "N": 47,
62
+ "O": 48,
63
+ "P": 49,
64
+ "Q": 50,
65
+ "R": 51,
66
+ "S": 52,
67
+ "T": 53,
68
+ "U": 54,
69
+ "V": 55,
70
+ "W": 56,
71
+ "X": 57,
72
+ "Y": 58,
73
+ "Z": 59,
74
+ "[": 60,
75
+ "\\": 61,
76
+ "]": 62,
77
+ "^": 63,
78
+ "_": 64,
79
+ "`": 65,
80
+ "a": 66,
81
+ "b": 67,
82
+ "c": 68,
83
+ "d": 69,
84
+ "e": 70,
85
+ "f": 71,
86
+ "g": 72,
87
+ "h": 73,
88
+ "i": 74,
89
+ "j": 75,
90
+ "k": 76,
91
+ "l": 77,
92
+ "m": 78,
93
+ "n": 79,
94
+ "o": 80,
95
+ "p": 81,
96
+ "q": 82,
97
+ "r": 83,
98
+ "s": 84,
99
+ "t": 85,
100
+ "u": 86,
101
+ "v": 87,
102
+ "w": 88,
103
+ "x": 89,
104
+ "y": 90,
105
+ "z": 91,
106
+ "{": 92,
107
+ "|": 93,
108
+ "}": 94,
109
+ "\u00a0": 95,
110
+ "\u00a3": 96,
111
+ "\u00b0": 97,
112
+ "\u00b2": 98,
113
+ "\u00b3": 99,
114
+ "\u00bc": 100,
115
+ "\u00bd": 101,
116
+ "\u00be": 102,
117
+ "\u00c6": 103,
118
+ "\u00c7": 104,
119
+ "\u00c8": 105,
120
+ "\u00c9": 106,
121
+ "\u00d7": 107,
122
+ "\u00dc": 108,
123
+ "\u00e0": 109,
124
+ "\u00e1": 110,
125
+ "\u00e2": 111,
126
+ "\u00e6": 112,
127
+ "\u00e7": 113,
128
+ "\u00e8": 114,
129
+ "\u00e9": 115,
130
+ "\u00ea": 116,
131
+ "\u00eb": 117,
132
+ "\u00ee": 118,
133
+ "\u00ef": 119,
134
+ "\u00f1": 120,
135
+ "\u00f2": 121,
136
+ "\u00f4": 122,
137
+ "\u00f6": 123,
138
+ "\u00f7": 124,
139
+ "\u00f9": 125,
140
+ "\u00fb": 126,
141
+ "\u00fc": 127,
142
+ "\u2013": 128,
143
+ "\u2014": 129,
144
+ "\u2018": 130,
145
+ "\u2019": 131,
146
+ "\u201c": 132,
147
+ "\u201d": 133,
148
+ "\ufeff": 134
149
+ },
150
+ "decode": {
151
+ "0": "\n",
152
+ "1": " ",
153
+ "2": "!",
154
+ "3": "\"",
155
+ "4": "#",
156
+ "5": "$",
157
+ "6": "%",
158
+ "7": "&",
159
+ "8": "'",
160
+ "9": "(",
161
+ "10": ")",
162
+ "11": "*",
163
+ "12": "+",
164
+ "13": ",",
165
+ "14": "-",
166
+ "15": ".",
167
+ "16": "/",
168
+ "17": "0",
169
+ "18": "1",
170
+ "19": "2",
171
+ "20": "3",
172
+ "21": "4",
173
+ "22": "5",
174
+ "23": "6",
175
+ "24": "7",
176
+ "25": "8",
177
+ "26": "9",
178
+ "27": ":",
179
+ "28": ";",
180
+ "29": "<",
181
+ "30": "=",
182
+ "31": ">",
183
+ "32": "?",
184
+ "33": "@",
185
+ "34": "A",
186
+ "35": "B",
187
+ "36": "C",
188
+ "37": "D",
189
+ "38": "E",
190
+ "39": "F",
191
+ "40": "G",
192
+ "41": "H",
193
+ "42": "I",
194
+ "43": "J",
195
+ "44": "K",
196
+ "45": "L",
197
+ "46": "M",
198
+ "47": "N",
199
+ "48": "O",
200
+ "49": "P",
201
+ "50": "Q",
202
+ "51": "R",
203
+ "52": "S",
204
+ "53": "T",
205
+ "54": "U",
206
+ "55": "V",
207
+ "56": "W",
208
+ "57": "X",
209
+ "58": "Y",
210
+ "59": "Z",
211
+ "60": "[",
212
+ "61": "\\",
213
+ "62": "]",
214
+ "63": "^",
215
+ "64": "_",
216
+ "65": "`",
217
+ "66": "a",
218
+ "67": "b",
219
+ "68": "c",
220
+ "69": "d",
221
+ "70": "e",
222
+ "71": "f",
223
+ "72": "g",
224
+ "73": "h",
225
+ "74": "i",
226
+ "75": "j",
227
+ "76": "k",
228
+ "77": "l",
229
+ "78": "m",
230
+ "79": "n",
231
+ "80": "o",
232
+ "81": "p",
233
+ "82": "q",
234
+ "83": "r",
235
+ "84": "s",
236
+ "85": "t",
237
+ "86": "u",
238
+ "87": "v",
239
+ "88": "w",
240
+ "89": "x",
241
+ "90": "y",
242
+ "91": "z",
243
+ "92": "{",
244
+ "93": "|",
245
+ "94": "}",
246
+ "95": "\u00a0",
247
+ "96": "\u00a3",
248
+ "97": "\u00b0",
249
+ "98": "\u00b2",
250
+ "99": "\u00b3",
251
+ "100": "\u00bc",
252
+ "101": "\u00bd",
253
+ "102": "\u00be",
254
+ "103": "\u00c6",
255
+ "104": "\u00c7",
256
+ "105": "\u00c8",
257
+ "106": "\u00c9",
258
+ "107": "\u00d7",
259
+ "108": "\u00dc",
260
+ "109": "\u00e0",
261
+ "110": "\u00e1",
262
+ "111": "\u00e2",
263
+ "112": "\u00e6",
264
+ "113": "\u00e7",
265
+ "114": "\u00e8",
266
+ "115": "\u00e9",
267
+ "116": "\u00ea",
268
+ "117": "\u00eb",
269
+ "118": "\u00ee",
270
+ "119": "\u00ef",
271
+ "120": "\u00f1",
272
+ "121": "\u00f2",
273
+ "122": "\u00f4",
274
+ "123": "\u00f6",
275
+ "124": "\u00f7",
276
+ "125": "\u00f9",
277
+ "126": "\u00fb",
278
+ "127": "\u00fc",
279
+ "128": "\u2013",
280
+ "129": "\u2014",
281
+ "130": "\u2018",
282
+ "131": "\u2019",
283
+ "132": "\u201c",
284
+ "133": "\u201d",
285
+ "134": "\ufeff"
286
+ }
287
+ }
config/example-config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "batch_size": 32,
3
+ "block_size": 256,
4
+ "max_iters": 3000,
5
+ "eval_interval": 300,
6
+ "learning_rate": 3e-5,
7
+ "eval_iters": 50,
8
+ "n_embd": 1024,
9
+ "n_head": 12,
10
+ "n_layer": 18,
11
+ "dropout": 0.3
12
+ }
config/harpoon_config.json ADDED
@@ -0,0 +1,287 @@
1
+ {
2
+ "vocab_size": 135,
3
+ "batch_size": 32,
4
+ "block_size": 256,
5
+ "max_iters": 3000,
6
+ "eval_interval": 300,
7
+ "learning_rate": 3e-5,
8
+ "eval_iters": 50,
9
+ "n_embd": 1024,
10
+ "n_head": 12,
11
+ "n_layer": 18,
12
+ "dropout": 0.3,
13
+ "encode": {
14
+ "\n": 0,
15
+ " ": 1,
16
+ "!": 2,
17
+ "\"": 3,
18
+ "#": 4,
19
+ "$": 5,
20
+ "%": 6,
21
+ "&": 7,
22
+ "'": 8,
23
+ "(": 9,
24
+ ")": 10,
25
+ "*": 11,
26
+ "+": 12,
27
+ ",": 13,
28
+ "-": 14,
29
+ ".": 15,
30
+ "/": 16,
31
+ "0": 17,
32
+ "1": 18,
33
+ "2": 19,
34
+ "3": 20,
35
+ "4": 21,
36
+ "5": 22,
37
+ "6": 23,
38
+ "7": 24,
39
+ "8": 25,
40
+ "9": 26,
41
+ ":": 27,
42
+ ";": 28,
43
+ "<": 29,
44
+ "=": 30,
45
+ ">": 31,
46
+ "?": 32,
47
+ "@": 33,
48
+ "A": 34,
49
+ "B": 35,
50
+ "C": 36,
51
+ "D": 37,
52
+ "E": 38,
53
+ "F": 39,
54
+ "G": 40,
55
+ "H": 41,
56
+ "I": 42,
57
+ "J": 43,
58
+ "K": 44,
59
+ "L": 45,
60
+ "M": 46,
61
+ "N": 47,
62
+ "O": 48,
63
+ "P": 49,
64
+ "Q": 50,
65
+ "R": 51,
66
+ "S": 52,
67
+ "T": 53,
68
+ "U": 54,
69
+ "V": 55,
70
+ "W": 56,
71
+ "X": 57,
72
+ "Y": 58,
73
+ "Z": 59,
74
+ "[": 60,
75
+ "\\": 61,
76
+ "]": 62,
77
+ "^": 63,
78
+ "_": 64,
79
+ "`": 65,
80
+ "a": 66,
81
+ "b": 67,
82
+ "c": 68,
83
+ "d": 69,
84
+ "e": 70,
85
+ "f": 71,
86
+ "g": 72,
87
+ "h": 73,
88
+ "i": 74,
89
+ "j": 75,
90
+ "k": 76,
91
+ "l": 77,
92
+ "m": 78,
93
+ "n": 79,
94
+ "o": 80,
95
+ "p": 81,
96
+ "q": 82,
97
+ "r": 83,
98
+ "s": 84,
99
+ "t": 85,
100
+ "u": 86,
101
+ "v": 87,
102
+ "w": 88,
103
+ "x": 89,
104
+ "y": 90,
105
+ "z": 91,
106
+ "{": 92,
107
+ "|": 93,
108
+ "}": 94,
109
+ "\u00a0": 95,
110
+ "\u00a3": 96,
111
+ "\u00b0": 97,
112
+ "\u00b2": 98,
113
+ "\u00b3": 99,
114
+ "\u00bc": 100,
115
+ "\u00bd": 101,
116
+ "\u00be": 102,
117
+ "\u00c6": 103,
118
+ "\u00c7": 104,
119
+ "\u00c8": 105,
120
+ "\u00c9": 106,
121
+ "\u00d7": 107,
122
+ "\u00dc": 108,
123
+ "\u00e0": 109,
124
+ "\u00e1": 110,
125
+ "\u00e2": 111,
126
+ "\u00e6": 112,
127
+ "\u00e7": 113,
128
+ "\u00e8": 114,
129
+ "\u00e9": 115,
130
+ "\u00ea": 116,
131
+ "\u00eb": 117,
132
+ "\u00ee": 118,
133
+ "\u00ef": 119,
134
+ "\u00f1": 120,
135
+ "\u00f2": 121,
136
+ "\u00f4": 122,
137
+ "\u00f6": 123,
138
+ "\u00f7": 124,
139
+ "\u00f9": 125,
140
+ "\u00fb": 126,
141
+ "\u00fc": 127,
142
+ "\u2013": 128,
143
+ "\u2014": 129,
144
+ "\u2018": 130,
145
+ "\u2019": 131,
146
+ "\u201c": 132,
147
+ "\u201d": 133,
148
+ "\ufeff": 134
149
+ },
150
+ "decode": {
151
+ "0": "\n",
152
+ "1": " ",
153
+ "2": "!",
154
+ "3": "\"",
155
+ "4": "#",
156
+ "5": "$",
157
+ "6": "%",
158
+ "7": "&",
159
+ "8": "'",
160
+ "9": "(",
161
+ "10": ")",
162
+ "11": "*",
163
+ "12": "+",
164
+ "13": ",",
165
+ "14": "-",
166
+ "15": ".",
167
+ "16": "/",
168
+ "17": "0",
169
+ "18": "1",
170
+ "19": "2",
171
+ "20": "3",
172
+ "21": "4",
173
+ "22": "5",
174
+ "23": "6",
175
+ "24": "7",
176
+ "25": "8",
177
+ "26": "9",
178
+ "27": ":",
179
+ "28": ";",
180
+ "29": "<",
181
+ "30": "=",
182
+ "31": ">",
183
+ "32": "?",
184
+ "33": "@",
185
+ "34": "A",
186
+ "35": "B",
187
+ "36": "C",
188
+ "37": "D",
189
+ "38": "E",
190
+ "39": "F",
191
+ "40": "G",
192
+ "41": "H",
193
+ "42": "I",
194
+ "43": "J",
195
+ "44": "K",
196
+ "45": "L",
197
+ "46": "M",
198
+ "47": "N",
199
+ "48": "O",
200
+ "49": "P",
201
+ "50": "Q",
202
+ "51": "R",
203
+ "52": "S",
204
+ "53": "T",
205
+ "54": "U",
206
+ "55": "V",
207
+ "56": "W",
208
+ "57": "X",
209
+ "58": "Y",
210
+ "59": "Z",
211
+ "60": "[",
212
+ "61": "\\",
213
+ "62": "]",
214
+ "63": "^",
215
+ "64": "_",
216
+ "65": "`",
217
+ "66": "a",
218
+ "67": "b",
219
+ "68": "c",
220
+ "69": "d",
221
+ "70": "e",
222
+ "71": "f",
223
+ "72": "g",
224
+ "73": "h",
225
+ "74": "i",
226
+ "75": "j",
227
+ "76": "k",
228
+ "77": "l",
229
+ "78": "m",
230
+ "79": "n",
231
+ "80": "o",
232
+ "81": "p",
233
+ "82": "q",
234
+ "83": "r",
235
+ "84": "s",
236
+ "85": "t",
237
+ "86": "u",
238
+ "87": "v",
239
+ "88": "w",
240
+ "89": "x",
241
+ "90": "y",
242
+ "91": "z",
243
+ "92": "{",
244
+ "93": "|",
245
+ "94": "}",
246
+ "95": "\u00a0",
247
+ "96": "\u00a3",
248
+ "97": "\u00b0",
249
+ "98": "\u00b2",
250
+ "99": "\u00b3",
251
+ "100": "\u00bc",
252
+ "101": "\u00bd",
253
+ "102": "\u00be",
254
+ "103": "\u00c6",
255
+ "104": "\u00c7",
256
+ "105": "\u00c8",
257
+ "106": "\u00c9",
258
+ "107": "\u00d7",
259
+ "108": "\u00dc",
260
+ "109": "\u00e0",
261
+ "110": "\u00e1",
262
+ "111": "\u00e2",
263
+ "112": "\u00e6",
264
+ "113": "\u00e7",
265
+ "114": "\u00e8",
266
+ "115": "\u00e9",
267
+ "116": "\u00ea",
268
+ "117": "\u00eb",
269
+ "118": "\u00ee",
270
+ "119": "\u00ef",
271
+ "120": "\u00f1",
272
+ "121": "\u00f2",
273
+ "122": "\u00f4",
274
+ "123": "\u00f6",
275
+ "124": "\u00f7",
276
+ "125": "\u00f9",
277
+ "126": "\u00fb",
278
+ "127": "\u00fc",
279
+ "128": "\u2013",
280
+ "129": "\u2014",
281
+ "130": "\u2018",
282
+ "131": "\u2019",
283
+ "132": "\u201c",
284
+ "133": "\u201d",
285
+ "134": "\ufeff"
286
+ }
287
+ }
config/shakespearean_config.json ADDED
@@ -0,0 +1,147 @@
1
+ {
2
+ "vocab_size": 65,
3
+ "batch_size": 32,
4
+ "block_size": 256,
5
+ "max_iters": 3000,
6
+ "eval_interval": 300,
7
+ "learning_rate": 3e-5,
8
+ "eval_iters": 50,
9
+ "n_embd": 384,
10
+ "n_head": 6,
11
+ "n_layer": 6,
12
+ "dropout": 0.3,
13
+ "encode": {
14
+ "\n": 0,
15
+ " ": 1,
16
+ "!": 2,
17
+ "$": 3,
18
+ "&": 4,
19
+ "'": 5,
20
+ ",": 6,
21
+ "-": 7,
22
+ ".": 8,
23
+ "3": 9,
24
+ ":": 10,
25
+ ";": 11,
26
+ "?": 12,
27
+ "A": 13,
28
+ "B": 14,
29
+ "C": 15,
30
+ "D": 16,
31
+ "E": 17,
32
+ "F": 18,
33
+ "G": 19,
34
+ "H": 20,
35
+ "I": 21,
36
+ "J": 22,
37
+ "K": 23,
38
+ "L": 24,
39
+ "M": 25,
40
+ "N": 26,
41
+ "O": 27,
42
+ "P": 28,
43
+ "Q": 29,
44
+ "R": 30,
45
+ "S": 31,
46
+ "T": 32,
47
+ "U": 33,
48
+ "V": 34,
49
+ "W": 35,
50
+ "X": 36,
51
+ "Y": 37,
52
+ "Z": 38,
53
+ "a": 39,
54
+ "b": 40,
55
+ "c": 41,
56
+ "d": 42,
57
+ "e": 43,
58
+ "f": 44,
59
+ "g": 45,
60
+ "h": 46,
61
+ "i": 47,
62
+ "j": 48,
63
+ "k": 49,
64
+ "l": 50,
65
+ "m": 51,
66
+ "n": 52,
67
+ "o": 53,
68
+ "p": 54,
69
+ "q": 55,
70
+ "r": 56,
71
+ "s": 57,
72
+ "t": 58,
73
+ "u": 59,
74
+ "v": 60,
75
+ "w": 61,
76
+ "x": 62,
77
+ "y": 63,
78
+ "z": 64
79
+ },
80
+ "decode": {
81
+ "0": "\n",
82
+ "1": " ",
83
+ "2": "!",
84
+ "3": "$",
85
+ "4": "&",
86
+ "5": "'",
87
+ "6": ",",
88
+ "7": "-",
89
+ "8": ".",
90
+ "9": "3",
91
+ "10": ":",
92
+ "11": ";",
93
+ "12": "?",
94
+ "13": "A",
95
+ "14": "B",
96
+ "15": "C",
97
+ "16": "D",
98
+ "17": "E",
99
+ "18": "F",
100
+ "19": "G",
101
+ "20": "H",
102
+ "21": "I",
103
+ "22": "J",
104
+ "23": "K",
105
+ "24": "L",
106
+ "25": "M",
107
+ "26": "N",
108
+ "27": "O",
109
+ "28": "P",
110
+ "29": "Q",
111
+ "30": "R",
112
+ "31": "S",
113
+ "32": "T",
114
+ "33": "U",
115
+ "34": "V",
116
+ "35": "W",
117
+ "36": "X",
118
+ "37": "Y",
119
+ "38": "Z",
120
+ "39": "a",
121
+ "40": "b",
122
+ "41": "c",
123
+ "42": "d",
124
+ "43": "e",
125
+ "44": "f",
126
+ "45": "g",
127
+ "46": "h",
128
+ "47": "i",
129
+ "48": "j",
130
+ "49": "k",
131
+ "50": "l",
132
+ "51": "m",
133
+ "52": "n",
134
+ "53": "o",
135
+ "54": "p",
136
+ "55": "q",
137
+ "56": "r",
138
+ "57": "s",
139
+ "58": "t",
140
+ "59": "u",
141
+ "60": "v",
142
+ "61": "w",
143
+ "62": "x",
144
+ "63": "y",
145
+ "64": "z"
146
+ }
147
+ }
core/__init__.py ADDED
File without changes
core/layers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .layers import Block, FeedForward, MultiHeadAttention, Head, RoPE, LlamaBlock, RMSNorm
core/layers/layers.py ADDED
@@ -0,0 +1,207 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ import tqdm
5
+
6
+
7
+ class Head(nn.Module):
8
+ """One head of self-attention."""
9
+
10
+ def __init__(self, n_embd, head_size, block_size, dropout):
11
+ super().__init__()
12
+ self.key = nn.Linear(n_embd, head_size, bias=False)
13
+ self.query = nn.Linear(n_embd, head_size, bias=False)
14
+ self.value = nn.Linear(n_embd, head_size, bias=False)
15
+ self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
16
+ self.dropout = nn.Dropout(dropout)
17
+
18
+ def forward(self, x):
19
+ B, T, C = x.shape
20
+ k = self.key(x)
21
+ q = self.query(x)
22
+ wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
23
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
24
+ wei = F.softmax(wei, dim=-1)
25
+ wei = self.dropout(wei)
26
+ v = self.value(x)
27
+ out = wei @ v
28
+ return out
29
+
30
+ class MultiHeadAttention(nn.Module):
31
+ """Multiple heads of self-attention in parallel."""
32
+
33
+ def __init__(self, n_embd, n_head, block_size, dropout):
34
+ super().__init__()
35
+ assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
36
+
37
+ self.n_embd = n_embd
38
+ self.n_head = n_head
39
+ self.head_size = n_embd // n_head
40
+
41
+ self.heads = nn.ModuleList([Head(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)])
42
+ self.proj = nn.Linear(n_embd, n_embd)
43
+ self.dropout = nn.Dropout(dropout)
44
+
45
+ def forward(self, x):
46
+ out = torch.cat([h(x) for h in self.heads], dim=-1)
47
+ out = self.dropout(self.proj(out))
48
+ return out
49
+
50
+ class FeedForward(nn.Module):
51
+ """A simple linear layer followed by a non-linearity."""
52
+
53
+ def __init__(self, n_embd, dropout):
54
+ super().__init__()
55
+ self.net = nn.Sequential(
56
+ nn.Linear(n_embd, 4 * n_embd),
57
+ nn.ReLU(),
58
+ nn.Linear(4 * n_embd, n_embd),
59
+ nn.Dropout(dropout),
60
+ )
61
+
62
+ def forward(self, x):
63
+ return self.net(x)
64
+
65
+ class Block(nn.Module):
66
+ """Transformer block: communication followed by computation."""
67
+
68
+ def __init__(self, n_embd, n_head, block_size, dropout):
69
+ super().__init__()
70
+ self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
71
+ self.ffwd = FeedForward(n_embd, dropout)
72
+ self.ln1 = nn.LayerNorm(n_embd)
73
+ self.ln2 = nn.LayerNorm(n_embd)
74
+
75
+ def forward(self, x):
76
+ x = x + self.sa(self.ln1(x))
77
+ x = x + self.ffwd(self.ln2(x))
78
+ return x
79
+
80
+
81
+ class RoPE(nn.Module):
82
+ """Rotary Positional Encoding (RoPE) layer."""
83
+
84
+ def __init__(self, embd_dim, max_freq=10):
85
+ super().__init__()
86
+ self.embd_dim = embd_dim
87
+ self.max_freq = max_freq
88
+ self.freqs = 2 ** torch.linspace(0, max_freq - 1, embd_dim // 2) * torch.pi
89
+ self.inv_freqs = 1. / self.freqs
90
+
91
+ def forward(self, x):
+     # x: (B, T, dim) with an even dim. Standard rotary encoding: rotate
+     # pairs of channels by angles that grow with the token position.
+     B, T, D = x.shape
+     half = D // 2
+     pos = torch.arange(T, device=x.device, dtype=x.dtype)
+     angles = pos[:, None] * self.inv_freqs[:half].to(x.device)  # (T, half)
+     sin, cos = angles.sin(), angles.cos()
+     x1, x2 = x[..., :half], x[..., half:]
+     return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
95
+
96
+
97
+ class RMSNorm(nn.Module):
98
+ """Root Mean Square Layer Normalization (RMSNorm)."""
99
+
100
+ def __init__(self, embd_dim, epsilon=1e-8):
101
+ super().__init__()
102
+ self.embd_dim = embd_dim
103
+ self.epsilon = epsilon
104
+ self.gamma = nn.Parameter(torch.ones(embd_dim))
105
+ self.beta = nn.Parameter(torch.zeros(embd_dim))
106
+
107
+ def forward(self, x: torch.Tensor):
108
+ mean = x.mean(-1, keepdim=True)
109
+ variance = x.var(-1, keepdim=True)
110
+ x = x - mean
111
+ x = x / torch.sqrt(variance + self.epsilon)
112
+ x = x * self.gamma + self.beta
113
+ return x
114
+
115
+
116
+ class LlamaFFN(nn.Module):
117
+ """Feed-forward network of the LLAMA model with SwiGLU activation."""
118
+
119
+ def __init__(self, n_embd, dropout):
120
+ super().__init__()
121
+ self.linear1 = nn.Linear(n_embd, 4 * n_embd)
122
+ self.linear2 = nn.Linear(4 * n_embd, n_embd)
123
+ self.dropout = nn.Dropout(dropout)
124
+
125
+ def swiglu(self, x):
126
+ """Applies SwiGLU activation."""
127
+ x1, x2 = torch.chunk(x, 2, dim=-1)
128
+ return x1 * F.silu(x2)
129
+
130
+ def forward(self, x):
131
+ x = self.linear1(x)
132
+ x = self.swiglu(x)
133
+ x = self.dropout(x)
134
+ x = self.linear2(x)
135
+ return x
136
+
137
+
138
+ class AttentionHeadWithKVCacheAndRoPE(nn.Module):
139
+ """One head of self-attention with key and value cache and RoPE."""
140
+
141
+ def __init__(self, n_embd, head_size, block_size, dropout):
142
+ super().__init__()
143
+ self.key = nn.Linear(n_embd, head_size, bias=False)
144
+ self.query = nn.Linear(n_embd, head_size, bias=False)
145
+ self.value = nn.Linear(n_embd, head_size, bias=False)
146
+ self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
147
+ self.dropout = nn.Dropout(dropout)
148
+ self.pe = RoPE(head_size)
149
+ self.ln = RMSNorm(n_embd)
150
+
151
+ def forward(self, x, kv_cache):
152
+ B, T, C = x.shape
153
+ k = self.key(x)
154
+ q = self.query(x)
155
+ v = self.value(x)
156
+ if kv_cache is not None:
157
+ k = torch.cat([kv_cache['k'], k], dim=1)
158
+ v = torch.cat([kv_cache['v'], v], dim=1)
159
+ wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
160
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
161
+ wei = F.softmax(wei, dim=-1)
162
+ wei = self.dropout(wei)
163
+ out = wei @ v
164
+ if kv_cache is None:
165
+ kv_cache = {'k': k, 'q': q, 'v': v}
166
+ else:
167
+ kv_cache['k'] = k
168
+ kv_cache['q'] = q
169
+ kv_cache['v'] = v
170
+ # The residual connection around this sub-layer is added by LlamaBlock,
+ # so return only the rotary-encoded head output here.
+ return self.pe(out)
171
+
172
+
173
+ class MultiHeadAttentionWithKVCacheAndRoPE(nn.Module):
174
+ """Multiple heads of self-attention in parallel."""
175
+
176
+ def __init__(self, n_embd, n_head, block_size, dropout):
177
+ super().__init__()
178
+ assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
179
+
180
+ self.n_embd = n_embd
181
+ self.n_head = n_head
182
+ self.head_size = n_embd // n_head
183
+
184
+ self.heads = nn.ModuleList([AttentionHeadWithKVCacheAndRoPE(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)])
185
+ self.proj = nn.Linear(n_embd, n_embd)
186
+ self.dropout = nn.Dropout(dropout)
187
+
188
+ def forward(self, x, kv_cache):
189
+ out = torch.cat([h(x, kv_cache) for h in self.heads], dim=-1)
190
+ out = self.dropout(self.proj(out))
191
+ return out
192
+
193
+
194
+ class LlamaBlock(nn.Module):
195
+ """LLAMA block: communication followed by computation."""
196
+
197
+ def __init__(self, n_embd, n_head, block_size, dropout):
198
+ super().__init__()
199
+ self.ln1 = RMSNorm(n_embd)
200
+ self.sa = MultiHeadAttentionWithKVCacheAndRoPE(n_embd, n_head, block_size, dropout)
201
+ self.ln2 = RMSNorm(n_embd)
202
+ self.ffwd = LlamaFFN(n_embd, dropout)
203
+
204
+ def forward(self, x, kv_cache):
205
+ x = x + self.sa(self.ln1(x), kv_cache)
206
+ x = x + self.ffwd(self.ln2(x))
207
+ return x
core/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import gpt
core/models/gpt.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import tqdm
4
+ from torch.nn import functional as F
5
+ from core.layers import Block
6
+
7
+ class GPTLanguageModel(nn.Module):
8
+
9
+ def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name = "GPT"):
10
+ super().__init__()
11
+ self.name = name
12
+ self.block_size = block_size
13
+ self.device = device
14
+ self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
15
+ self.position_embedding_table = nn.Embedding(block_size, n_embd)
16
+ self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
17
+ self.ln_f = nn.LayerNorm(n_embd)
18
+ self.lm_head = nn.Linear(n_embd, vocab_size)
19
+ self.apply(self._init_weights)
20
+ self.history = {}
21
+ self.vocab_size = vocab_size
22
+
23
+ def _init_weights(self, module):
24
+ if isinstance(module, nn.Linear):
25
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
26
+ if module.bias is not None:
27
+ nn.init.zeros_(module.bias)
28
+ elif isinstance(module, nn.Embedding):
29
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
30
+
31
+ def forward(self, idx, targets=None):
32
+ B, T = idx.shape
33
+
34
+ assert torch.all(idx < self.vocab_size), f"Input indices must be less than vocab_size ({self.vocab_size})"
35
+ assert T <= self.block_size, f"Input sequence length ({T}) must be <= block_size ({self.block_size})"
36
+
37
+ tok_emb = self.token_embedding_table(idx)
38
+ pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
39
+ x = tok_emb + pos_emb
40
+ x = self.blocks(x)
41
+ x = self.ln_f(x)
42
+ logits = self.lm_head(x)
43
+
44
+ if targets is None:
45
+ loss = None
46
+ else:
47
+ B, T, C = logits.shape
48
+ logits = logits.view(B * T, C)
49
+ targets = targets.view(B * T)
50
+ loss = F.cross_entropy(logits, targets)
51
+
52
+ return logits, loss
53
+
54
+ def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
55
+ for _ in range(max_new_tokens):
56
+ if idx.size(1) > max_seq_length:
57
+ idx = idx[:, -max_seq_length:]
58
+ idx_cond = idx[:, -self.block_size:]
59
+ logits, _ = self(idx_cond)
60
+ logits = logits[:, -1, :] / temperature
61
+ probs = F.softmax(logits, dim=-1)
62
+ idx_next = torch.multinomial(probs, num_samples=1)
63
+ idx = torch.cat((idx, idx_next), dim=1)
64
+ yield idx
core/models/llama.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import tqdm
4
+ from torch.nn import functional as F
5
+ from core.layers import LlamaBlock, RMSNorm
6
+
7
+ class LlamaLanguageModel(nn.Module):
8
+
9
+ def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name = "llama"):
10
+ super().__init__()
11
+ self.name = name
12
+ self.block_size = block_size
13
+ self.device = device
14
+ self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
15
+ self.blocks = nn.Sequential(*[LlamaBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
16
+ self.ln_f = RMSNorm(n_embd)
17
+ self.lm_head = nn.Linear(n_embd, vocab_size)
18
+ self.apply(self._init_weights)
19
+ self.history = {}
20
+ self.vocab_size = vocab_size
21
+
22
+ def _init_weights(self, module):
23
+ if isinstance(module, nn.Linear):
24
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
25
+ if module.bias is not None:
26
+ nn.init.zeros_(module.bias)
27
+ elif isinstance(module, nn.Embedding):
28
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
29
+
30
+ def forward(self, idx):
31
+ B, T = idx.shape
32
+ kv_cache = None
33
+ token_embeddings = self.token_embedding_table(idx)
34
+ for block in self.blocks:
35
+ token_embeddings = block(token_embeddings, kv_cache)
36
+ token_embeddings = self.ln_f(token_embeddings)
37
+ logits = self.lm_head(token_embeddings)
38
+ return logits, token_embeddings
39
+
40
+
41
+ def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
42
+ for _ in range(max_new_tokens):
43
+ if idx.size(1) > max_seq_length:
44
+ idx = idx[:, -max_seq_length:]
45
+ idx_cond = idx[:, -self.block_size:]
46
+ logits, _ = self(idx_cond)
47
+ logits = logits[:, -1, :] / temperature
48
+ probs = F.softmax(logits, dim=-1)
49
+ idx_next = torch.multinomial(probs, num_samples=1)
50
+ idx = torch.cat((idx, idx_next), dim=1)
51
+ yield idx
core/tokenizers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import tokenizer
core/tokenizers/tokenizer.py ADDED
@@ -0,0 +1,115 @@
1
+ import json
2
+ import os
3
+ from typing import Iterable
4
+ import torch
5
+
6
+ class Tokenizer:
7
+ def __init__(self, data_path: str = None):
8
+ self.config = None
9
+ self.stoi = None
10
+ self.itos = None
11
+ self.vocab_size = None
12
+ if data_path:
13
+ self.data = self.load_data(data_path)
14
+ else:
15
+ self.data = None
16
+
17
+ def from_pretrained(self, config_path: str):
18
+ with open(config_path) as f:
19
+ config = json.load(f)
20
+ self.config = config
21
+ if 'encode' not in config:
22
+ raise ValueError("Config file must contain an 'encode' key.")
23
+ if 'decode' not in config:
24
+ raise ValueError("Config file must contain a 'decode' key.")
25
+ if 'vocab_size' not in config:
26
+ raise ValueError("Config file must contain a 'vocab_size' key.")
27
+ stoi = config['encode']
28
+ self.stoi = {k: int(v) for k, v in stoi.items()}
29
+ itos = config['decode']
30
+ self.itos = {int(k): v for k, v in itos.items()}
31
+ self.vocab_size = config['vocab_size']
32
+ return self
33
+
34
+ def load_data(self, path: str) -> str:
35
+ if not os.path.exists(path):
36
+ raise FileNotFoundError("File not found.")
37
+ if not path.endswith('.txt'):
38
+ raise ValueError("File must be a text file.")
39
+ with open(path, 'r', encoding='utf-8') as f:
40
+ text = f.read()
41
+ chars = sorted(list(set(text)))
42
+ vocab_size = len(chars)
43
+ stoi = {ch: i for i, ch in enumerate(chars)}
44
+ itos = {i: ch for i, ch in enumerate(chars)}
45
+ self.config = {"vocab_size": vocab_size, "encode": stoi, "decode": itos}
46
+ self.stoi = stoi
47
+ self.itos = itos
48
+ data = torch.tensor(self(text), dtype=torch.long)
49
+ n = int(0.9*len(data))
50
+ train_data = data[:n]
51
+ val_data = data[n:]
52
+ self.train_data = train_data
53
+ self.val_data = val_data
54
+ self.vocab_size = vocab_size
55
+ return text
56
+
57
+ def __repr__(self) -> str:
58
+ if self.config:
59
+ return f"Tokenizer(config={self.config})"
60
+ else:
61
+ return f"Tokenizer()"
62
+
63
+ def __str__(self) -> str:
64
+ if self.config:
65
+ return f"Tokenizer(config_path={self.config})"
66
+ else:
67
+ return f"Tokenizer()"
68
+
69
+ def __len__(self) -> int:
70
+ return len(self.stoi)
71
+
72
+ def __getitem__(self, key: str) -> int:
73
+ return self.stoi[key]
74
+
75
+ def __contains__(self, key: str) -> bool:
76
+ return key in self.stoi
77
+
78
+ def __iter__(self):
79
+ return iter(self.stoi)
80
+
81
+ def __reversed__(self):
82
+ return reversed(self.stoi)
83
+
84
+ def keys(self):
85
+ return self.stoi.keys()
86
+
87
+ def values(self):
88
+ return self.stoi.values()
89
+
90
+ def items(self):
91
+ return self.stoi.items()
92
+
93
+ def __call__(self, *args, **kwds) -> list[int]:
94
+ return self.encode(*args, **kwds)
95
+
96
+ def encode(self, s: str | list[str]) -> list[int]:
97
+ if isinstance(s, str):
98
+ return [self.stoi[c] for c in s]
99
+ elif isinstance(s, list):
100
+ return [[self.stoi[i] for i in c] for c in s]
101
+ else:
102
+ raise ValueError("Input must be a string or a list of strings.")
103
+
104
+ def decode(self, l: list[int]) -> str:
105
+ if isinstance(l[0], int):
106
+ return ''.join([self.itos[i] for i in l])
107
+ elif isinstance(l[0], Iterable):
108
+ return [''.join([self.itos[i] for i in c]) for c in l]
109
+ else:
110
+ raise ValueError("Input must be a list of integers or a list of list of integers.")
111
+
112
+ def save_pretrained(self, path: str) -> str:
113
+ with open(os.path.join(path, 'vocab.json'), 'w') as f:
114
+ json.dump(self.config, f)
115
+ return "Tokenizer saved at {}.".format(path)
core/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import gptutils, preprocessing
core/utils/gptutils.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import json
3
+
4
+ # ------------ Hyperparameters ------------
5
+ def hyperparameters(config_path: str):
6
+ with open(config_path) as f:
7
+ config = json.load(f)
8
+
9
+ batch_size = config['batch_size']
10
+ block_size = config['block_size']
11
+ max_iters = config['max_iters']
12
+ eval_interval = config['eval_interval']
13
+ learning_rate = config['learning_rate']
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ eval_iters = config['eval_iters']
16
+ n_embd = config['n_embd']
17
+ n_head = config['n_head']
18
+ n_layer = config['n_layer']
19
+ dropout = config['dropout']
20
+ return (batch_size, block_size, max_iters, eval_interval, learning_rate,
21
+ device, eval_iters, n_embd, n_head, n_layer, dropout)
22
+ # ----------------------------------------
23
+
24
+ def load_data(path) -> tuple[torch.Tensor, torch.Tensor, int, callable, callable]:
25
+ with open(path, 'r', encoding='utf-8') as f:
26
+ text = f.read()
27
+
28
+ # words = text.split()
29
+ # vocab_size = len(words)
30
+ # stoi = {word: i for i, word in enumerate(words)}
31
+ # itos = {i: word for i, word in enumerate(words)}
32
+ # def encode(s): return [stoi[w] for w in s.split()]
33
+ # def decode(ids): return ' '.join([itos[i] for i in ids])
34
+
35
+ chars = sorted(list(set(text)))
36
+ vocab_size = len(chars)
37
+ stoi = {ch: i for i, ch in enumerate(chars)}
38
+ itos = {i: ch for i, ch in enumerate(chars)}
39
+ def encode(s): return [stoi[c] for c in s]
40
+ def decode(l): return ''.join([itos[i] for i in l])
41
+ data = torch.tensor(encode(text), dtype=torch.long)
42
+ n = int(0.9*len(data))
43
+ train_data = data[:n]
44
+ val_data = data[n:]
45
+
46
+
47
+ return train_data, val_data, vocab_size, encode, decode
48
+
49
+
50
+ def get_batch(split, train_data, val_data, device, block_size, batch_size):
51
+ data = train_data if split == 'train' else val_data
52
+ ix = torch.randint(len(data) - block_size, (batch_size,))
53
+ x = torch.stack([data[i:i+block_size] for i in ix])
54
+ y = torch.stack([data[i+1:i+block_size+1] for i in ix])
55
+ x, y = x.to(device), y.to(device)
56
+ return x, y
57
+
58
+
59
+ @torch.no_grad()
60
+ def estimate_loss(model, get_batch, eval_iters, train_data, val_data, device, block_size, batch_size):
61
+ out = {}
62
+ model.eval()
63
+ for split in ['train', 'val']:
64
+ losses = torch.zeros(eval_iters)
65
+ for k in range(eval_iters):
66
+ X, Y = get_batch(split, train_data, val_data, device, block_size, batch_size)
67
+ logits, loss = model(X, Y)
68
+ losses[k] = loss.item()
69
+ out[split] = losses.mean()
70
+ model.train()
71
+ return out
core/utils/preprocessing.py ADDED
@@ -0,0 +1,75 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ class DataLoader:
6
+ def __init__(self, data_path):
7
+ self.data_path = data_path
8
+ self.batch_size = None
9
+ self.block_size = None
10
+ self.data = None
11
+ self.train_data = None
12
+ self.val_data = None
13
+
14
+ def load_data(self, block_size=128, split=0.8, batch_size=64, device='cpu'):
+     with open(self.data_path, 'r') as f:
+         data = f.read()
+     self.block_size = block_size
+     self.batch_size = batch_size
+     self.device = device
+     self.data = data
+     # Encode the text at character level and split it into train / val
+     # tensors so that get_batch can sample (x, y) blocks from them.
+     encoder = Encoder(data, type='char')
+     encoded = encoder.encode(data)
+     n = int(split * len(encoded))
+     self.train_data = encoded[:n]
+     self.val_data = encoded[n:]
21
+
22
+ def __len__(self):
23
+ return int(np.ceil(len(self.data) / self.batch_size))
24
+
25
+ def __getitem__(self, index):
+     # Return the index-th batch of raw characters from the text, in order.
+     start = index * self.batch_size
+     end = (index + 1) * self.batch_size
+     return np.array(list(self.data[start:end]))
31
+
32
+ def get_batch(self, split, device='cpu'):
33
+ if self.data is None:
34
+ raise ValueError('Data not loaded')
35
+ data = self.train_data if split == 'train' else self.val_data
36
+ ix = torch.randint(len(data) - self.block_size, (self.batch_size,))
37
+ x = torch.stack([data[i:i+self.block_size] for i in ix])
38
+ y = torch.stack([data[i+1:i+self.block_size+1] for i in ix])
39
+ x, y = x.to(device), y.to(device)
40
+ return x, y
41
+
42
+
43
+ class Encoder:
44
+ def __init__(self, data, type='char'):
45
+ self.data = data
46
+ self.type = type
47
+ self.vocab_size = None
48
+ if type == 'char':
49
+ self.chars = sorted(list(set(data)))
50
+ self.stoi = {ch: i for i, ch in enumerate(self.chars)}
51
+ self.itos = {i: ch for i, ch in enumerate(self.chars)}
52
+ self.vocab_size = len(self.chars)
53
+ elif type == 'word':
54
+ self.words = data.split()
55
+ self.stoi = {word: i for i, word in enumerate(self.words)}
56
+ self.itos = {i: word for i, word in enumerate(self.words)}
57
+ self.vocab_size = len(self.words)
58
+ else:
59
+ raise ValueError('Type must be either "char" or "word"')
60
+
61
+ def encode(self, string: str):
62
+ if self.type == 'char':
63
+ return torch.tensor([self.stoi[c] for c in string])
64
+ elif self.type == 'word':
65
+ return torch.tensor([self.stoi[w] for w in string.split()])
66
+ else:
67
+ raise ValueError('Type must be either "char" or "word"')
68
+
69
+ def decode(self, ids: list):
70
+ if self.type == 'char':
71
+ return ''.join([self.itos[i] for i in ids])
72
+ elif self.type == 'word':
73
+ return ' '.join([self.itos[i] for i in ids])
74
+ else:
75
+ raise ValueError('Type must be either "char" or "word"')
requirements.txt ADDED
@@ -0,0 +1,96 @@
1
+ altair==5.2.0
2
+ annotated-types==0.6.0
3
+ anyio==4.3.0
4
+ asttokens==2.4.1
5
+ attrs==23.2.0
6
+ blinker==1.7.0
7
+ cachetools==5.3.3
8
+ certifi==2024.2.2
9
+ charset-normalizer==3.3.2
10
+ click==8.1.7
11
+ colorama==0.4.6
12
+ comm==0.2.2
13
+ debugpy==1.8.1
14
+ decorator==5.1.1
15
+ dnspython==2.6.1
16
+ email_validator==2.1.1
17
+ executing==2.0.1
18
+ fastapi==0.110.3
19
+ filelock==3.13.1
20
+ fsspec==2024.2.0
21
+ gitdb==4.0.11
22
+ GitPython==3.1.42
23
+ h11==0.14.0
24
+ httpcore==1.0.5
25
+ httptools==0.6.1
26
+ httpx==0.27.0
27
+ idna==3.6
28
+ ipykernel==6.29.4
29
+ ipython==8.24.0
30
+ itsdangerous==2.2.0
31
+ jedi==0.19.1
32
+ Jinja2==3.1.3
33
+ joblib==1.4.0
34
+ jsonschema==4.21.1
35
+ jsonschema-specifications==2023.12.1
36
+ jupyter_client==8.6.1
37
+ jupyter_core==5.7.2
38
+ markdown-it-py==3.0.0
39
+ MarkupSafe==2.1.5
40
+ matplotlib-inline==0.1.7
41
+ mdurl==0.1.2
42
+ mpmath==1.3.0
43
+ nest-asyncio==1.6.0
44
+ networkx==3.2.1
45
+ numpy==1.26.4
46
+ orjson==3.10.2
47
+ packaging==23.2
48
+ pandas==2.2.1
49
+ parso==0.8.4
50
+ pillow==10.2.0
51
+ platformdirs==4.2.1
52
+ prompt-toolkit==3.0.43
53
+ protobuf==4.25.3
54
+ psutil==5.9.8
55
+ pure-eval==0.2.2
56
+ pyarrow==15.0.2
57
+ pydantic==2.7.1
58
+ pydantic-extra-types==2.7.0
59
+ pydantic-settings==2.2.1
60
+ pydantic_core==2.18.2
61
+ pydeck==0.8.1b0
62
+ Pygments==2.17.2
63
+ python-dateutil==2.9.0.post0
64
+ python-dotenv==1.0.1
65
+ python-multipart==0.0.9
66
+ pytz==2024.1
67
+ PyYAML==6.0.1
68
+ pyzmq==26.0.2
69
+ referencing==0.34.0
70
+ regex==2024.4.28
71
+ requests==2.31.0
72
+ rich==13.7.1
73
+ rpds-py==0.18.0
74
+ six==1.16.0
75
+ smmap==5.0.1
76
+ sniffio==1.3.1
77
+ stack-data==0.6.3
78
+ starlette==0.37.2
79
+ streamlit==1.32.2
80
+ sympy==1.12
81
+ tenacity==8.2.3
82
+ toml==0.10.2
83
+ toolz==0.12.1
84
+ torch==2.2.1
85
+ tornado==6.4
86
+ tqdm==4.66.2
87
+ traitlets==5.14.3
88
+ typing_extensions==4.10.0
89
+ tzdata==2024.1
90
+ ujson==5.9.0
91
+ urllib3==2.2.1
92
+ uvicorn==0.29.0
93
+ watchdog==4.0.0
94
+ watchfiles==0.21.0
95
+ wcwidth==0.2.13
96
+ websockets==12.0