Tjdharani commited on
Commit
14f4c62
·
1 Parent(s): c4bceef

Upload minGPT.ipynb

Browse files
Files changed (1) hide show
  1. minGPT.ipynb +1505 -0
minGPT.ipynb ADDED
@@ -0,0 +1,1505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "#Building GPT"
23
+ ],
24
+ "metadata": {
25
+ "id": "8FHnXpkTv_5f"
26
+ }
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "source": [
31
+ "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
32
+ "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
33
+ ],
34
+ "metadata": {
35
+ "colab": {
36
+ "base_uri": "https://localhost:8080/"
37
+ },
38
+ "id": "YTPlvPQn-Zef",
39
+ "outputId": "45f9c50f-d2c6-4629-cabe-d1378e2882a7"
40
+ },
41
+ "execution_count": 1,
42
+ "outputs": [
43
+ {
44
+ "output_type": "stream",
45
+ "name": "stdout",
46
+ "text": [
47
+ "--2023-06-13 07:55:40-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
48
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
49
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
50
+ "HTTP request sent, awaiting response... 200 OK\n",
51
+ "Length: 1115394 (1.1M) [text/plain]\n",
52
+ "Saving to: ‘input.txt’\n",
53
+ "\n",
54
+ "input.txt 100%[===================>] 1.06M --.-KB/s in 0.005s \n",
55
+ "\n",
56
+ "2023-06-13 07:55:40 (199 MB/s) - ‘input.txt’ saved [1115394/1115394]\n",
57
+ "\n"
58
+ ]
59
+ }
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "source": [
65
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
66
+ " text = f.read()"
67
+ ],
68
+ "metadata": {
69
+ "id": "mfIiqOSm-euI"
70
+ },
71
+ "execution_count": 2,
72
+ "outputs": []
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "source": [
77
+ "print(\"length of dataset in characters:\", len(text))\n"
78
+ ],
79
+ "metadata": {
80
+ "colab": {
81
+ "base_uri": "https://localhost:8080/"
82
+ },
83
+ "id": "4Qgkvnr0_N66",
84
+ "outputId": "6063f096-78b7-40c1-c830-531594a0bb1a"
85
+ },
86
+ "execution_count": 3,
87
+ "outputs": [
88
+ {
89
+ "output_type": "stream",
90
+ "name": "stdout",
91
+ "text": [
92
+ "length of dataset in characters: 1115394\n"
93
+ ]
94
+ }
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "source": [
100
+ "# let's look at the first 1000 characters\n",
101
+ "print(text[:1000])"
102
+ ],
103
+ "metadata": {
104
+ "colab": {
105
+ "base_uri": "https://localhost:8080/"
106
+ },
107
+ "id": "Qn9QIHwf_c-_",
108
+ "outputId": "4f4f837a-7b53-43fd-807e-42d16b0519c6"
109
+ },
110
+ "execution_count": 4,
111
+ "outputs": [
112
+ {
113
+ "output_type": "stream",
114
+ "name": "stdout",
115
+ "text": [
116
+ "First Citizen:\n",
117
+ "Before we proceed any further, hear me speak.\n",
118
+ "\n",
119
+ "All:\n",
120
+ "Speak, speak.\n",
121
+ "\n",
122
+ "First Citizen:\n",
123
+ "You are all resolved rather to die than to famish?\n",
124
+ "\n",
125
+ "All:\n",
126
+ "Resolved. resolved.\n",
127
+ "\n",
128
+ "First Citizen:\n",
129
+ "First, you know Caius Marcius is chief enemy to the people.\n",
130
+ "\n",
131
+ "All:\n",
132
+ "We know't, we know't.\n",
133
+ "\n",
134
+ "First Citizen:\n",
135
+ "Let us kill him, and we'll have corn at our own price.\n",
136
+ "Is't a verdict?\n",
137
+ "\n",
138
+ "All:\n",
139
+ "No more talking on't; let it be done: away, away!\n",
140
+ "\n",
141
+ "Second Citizen:\n",
142
+ "One word, good citizens.\n",
143
+ "\n",
144
+ "First Citizen:\n",
145
+ "We are accounted poor citizens, the patricians good.\n",
146
+ "What authority surfeits on would relieve us: if they\n",
147
+ "would yield us but the superfluity, while it were\n",
148
+ "wholesome, we might guess they relieved us humanely;\n",
149
+ "but they think we are too dear: the leanness that\n",
150
+ "afflicts us, the object of our misery, is as an\n",
151
+ "inventory to particularise their abundance; our\n",
152
+ "sufferance is a gain to them Let us revenge this with\n",
153
+ "our pikes, ere we become rakes: for the gods know I\n",
154
+ "speak this in hunger for bread, not in thirst for revenge.\n",
155
+ "\n",
156
+ "\n"
157
+ ]
158
+ }
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "source": [
164
+ "# here are all the unique characters that occur in this text\n",
165
+ "chars = sorted(list(set(text)))\n",
166
+ "vocab_size = len(chars)\n",
167
+ "print(''.join(chars))\n",
168
+ "print(vocab_size)"
169
+ ],
170
+ "metadata": {
171
+ "colab": {
172
+ "base_uri": "https://localhost:8080/"
173
+ },
174
+ "id": "JN8_xJFY_zvq",
175
+ "outputId": "d0ab20bb-c366-41af-9378-15ced2913126"
176
+ },
177
+ "execution_count": 5,
178
+ "outputs": [
179
+ {
180
+ "output_type": "stream",
181
+ "name": "stdout",
182
+ "text": [
183
+ "\n",
184
+ " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
185
+ "65\n"
186
+ ]
187
+ }
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "source": [
193
+ "# create a mapping from characters to integers \n",
194
+ "stoi = { ch:i for i, ch in enumerate(chars)}\n",
195
+ "itos = { i:ch for i, ch in enumerate(chars)}\n",
196
+ "encode = lambda s: [stoi[c] for c in s] # sting to integer\n",
197
+ "decode = lambda l: ''.join([itos[i] for i in l]) # integer to string\n",
198
+ "\n",
199
+ "print(encode(\"hii there\"))\n",
200
+ "print(decode(encode(\"hii there\")))"
201
+ ],
202
+ "metadata": {
203
+ "colab": {
204
+ "base_uri": "https://localhost:8080/"
205
+ },
206
+ "id": "X1lJF7-IAjz_",
207
+ "outputId": "18702fc0-b1c0-4675-b78a-e047a06f4887"
208
+ },
209
+ "execution_count": 6,
210
+ "outputs": [
211
+ {
212
+ "output_type": "stream",
213
+ "name": "stdout",
214
+ "text": [
215
+ "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n",
216
+ "hii there\n"
217
+ ]
218
+ }
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "source": [
224
+ "# let's now encode the entire text dataset and store it into torch.Tensor\n",
225
+ "import torch # PyTorch\n",
226
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
227
+ "print(data.shape, data.dtype)\n",
228
+ "print(data[:1000])"
229
+ ],
230
+ "metadata": {
231
+ "colab": {
232
+ "base_uri": "https://localhost:8080/"
233
+ },
234
+ "id": "ML1pjHfLCJ_M",
235
+ "outputId": "3f21fc94-ed1f-4bb5-b9db-0a1ad2e5b227"
236
+ },
237
+ "execution_count": 7,
238
+ "outputs": [
239
+ {
240
+ "output_type": "stream",
241
+ "name": "stdout",
242
+ "text": [
243
+ "torch.Size([1115394]) torch.int64\n",
244
+ "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
245
+ " 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
246
+ " 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
247
+ " 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
248
+ " 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
249
+ " 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50,\n",
250
+ " 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58,\n",
251
+ " 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47,\n",
252
+ " 57, 46, 12, 0, 0, 13, 50, 50, 10, 0, 30, 43, 57, 53, 50, 60, 43, 42,\n",
253
+ " 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58,\n",
254
+ " 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63,\n",
255
+ " 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41,\n",
256
+ " 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63,\n",
257
+ " 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0, 0, 13,\n",
258
+ " 50, 50, 10, 0, 35, 43, 1, 49, 52, 53, 61, 5, 58, 6, 1, 61, 43, 1,\n",
259
+ " 49, 52, 53, 61, 5, 58, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58,\n",
260
+ " 47, 64, 43, 52, 10, 0, 24, 43, 58, 1, 59, 57, 1, 49, 47, 50, 50, 1,\n",
261
+ " 46, 47, 51, 6, 1, 39, 52, 42, 1, 61, 43, 5, 50, 50, 1, 46, 39, 60,\n",
262
+ " 43, 1, 41, 53, 56, 52, 1, 39, 58, 1, 53, 59, 56, 1, 53, 61, 52, 1,\n",
263
+ " 54, 56, 47, 41, 43, 8, 0, 21, 57, 5, 58, 1, 39, 1, 60, 43, 56, 42,\n",
264
+ " 47, 41, 58, 12, 0, 0, 13, 50, 50, 10, 0, 26, 53, 1, 51, 53, 56, 43,\n",
265
+ " 1, 58, 39, 50, 49, 47, 52, 45, 1, 53, 52, 5, 58, 11, 1, 50, 43, 58,\n",
266
+ " 1, 47, 58, 1, 40, 43, 1, 42, 53, 52, 43, 10, 1, 39, 61, 39, 63, 6,\n",
267
+ " 1, 39, 61, 39, 63, 2, 0, 0, 31, 43, 41, 53, 52, 42, 1, 15, 47, 58,\n",
268
+ " 47, 64, 43, 52, 10, 0, 27, 52, 43, 1, 61, 53, 56, 42, 6, 1, 45, 53,\n",
269
+ " 53, 42, 1, 41, 47, 58, 47, 64, 43, 52, 57, 8, 0, 0, 18, 47, 56, 57,\n",
270
+ " 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 35, 43, 1, 39, 56, 43, 1,\n",
271
+ " 39, 41, 41, 53, 59, 52, 58, 43, 42, 1, 54, 53, 53, 56, 1, 41, 47, 58,\n",
272
+ " 47, 64, 43, 52, 57, 6, 1, 58, 46, 43, 1, 54, 39, 58, 56, 47, 41, 47,\n",
273
+ " 39, 52, 57, 1, 45, 53, 53, 42, 8, 0, 35, 46, 39, 58, 1, 39, 59, 58,\n",
274
+ " 46, 53, 56, 47, 58, 63, 1, 57, 59, 56, 44, 43, 47, 58, 57, 1, 53, 52,\n",
275
+ " 1, 61, 53, 59, 50, 42, 1, 56, 43, 50, 47, 43, 60, 43, 1, 59, 57, 10,\n",
276
+ " 1, 47, 44, 1, 58, 46, 43, 63, 0, 61, 53, 59, 50, 42, 1, 63, 47, 43,\n",
277
+ " 50, 42, 1, 59, 57, 1, 40, 59, 58, 1, 58, 46, 43, 1, 57, 59, 54, 43,\n",
278
+ " 56, 44, 50, 59, 47, 58, 63, 6, 1, 61, 46, 47, 50, 43, 1, 47, 58, 1,\n",
279
+ " 61, 43, 56, 43, 0, 61, 46, 53, 50, 43, 57, 53, 51, 43, 6, 1, 61, 43,\n",
280
+ " 1, 51, 47, 45, 46, 58, 1, 45, 59, 43, 57, 57, 1, 58, 46, 43, 63, 1,\n",
281
+ " 56, 43, 50, 47, 43, 60, 43, 42, 1, 59, 57, 1, 46, 59, 51, 39, 52, 43,\n",
282
+ " 50, 63, 11, 0, 40, 59, 58, 1, 58, 46, 43, 63, 1, 58, 46, 47, 52, 49,\n",
283
+ " 1, 61, 43, 1, 39, 56, 43, 1, 58, 53, 53, 1, 42, 43, 39, 56, 10, 1,\n",
284
+ " 58, 46, 43, 1, 50, 43, 39, 52, 52, 43, 57, 57, 1, 58, 46, 39, 58, 0,\n",
285
+ " 39, 44, 44, 50, 47, 41, 58, 57, 1, 59, 57, 6, 1, 58, 46, 43, 1, 53,\n",
286
+ " 40, 48, 43, 41, 58, 1, 53, 44, 1, 53, 59, 56, 1, 51, 47, 57, 43, 56,\n",
287
+ " 63, 6, 1, 47, 57, 1, 39, 57, 1, 39, 52, 0, 47, 52, 60, 43, 52, 58,\n",
288
+ " 53, 56, 63, 1, 58, 53, 1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,\n",
289
+ " 57, 43, 1, 58, 46, 43, 47, 56, 1, 39, 40, 59, 52, 42, 39, 52, 41, 43,\n",
290
+ " 11, 1, 53, 59, 56, 0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43, 1, 47,\n",
291
+ " 57, 1, 39, 1, 45, 39, 47, 52, 1, 58, 53, 1, 58, 46, 43, 51, 1, 24,\n",
292
+ " 43, 58, 1, 59, 57, 1, 56, 43, 60, 43, 52, 45, 43, 1, 58, 46, 47, 57,\n",
293
+ " 1, 61, 47, 58, 46, 0, 53, 59, 56, 1, 54, 47, 49, 43, 57, 6, 1, 43,\n",
294
+ " 56, 43, 1, 61, 43, 1, 40, 43, 41, 53, 51, 43, 1, 56, 39, 49, 43, 57,\n",
295
+ " 10, 1, 44, 53, 56, 1, 58, 46, 43, 1, 45, 53, 42, 57, 1, 49, 52, 53,\n",
296
+ " 61, 1, 21, 0, 57, 54, 43, 39, 49, 1, 58, 46, 47, 57, 1, 47, 52, 1,\n",
297
+ " 46, 59, 52, 45, 43, 56, 1, 44, 53, 56, 1, 40, 56, 43, 39, 42, 6, 1,\n",
298
+ " 52, 53, 58, 1, 47, 52, 1, 58, 46, 47, 56, 57, 58, 1, 44, 53, 56, 1,\n",
299
+ " 56, 43, 60, 43, 52, 45, 43, 8, 0, 0])\n"
300
+ ]
301
+ }
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "source": [
307
+ "# split the data into train and validation set\n",
308
+ "n = int(0.9*len(data)) #train 90% data\n",
309
+ "train_data = data[:n]\n",
310
+ "val_data = data[n:]"
311
+ ],
312
+ "metadata": {
313
+ "id": "F-6DyilNE7KM"
314
+ },
315
+ "execution_count": 8,
316
+ "outputs": []
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "source": [
321
+ "block_size = 8\n",
322
+ "train_data[:block_size+1]"
323
+ ],
324
+ "metadata": {
325
+ "colab": {
326
+ "base_uri": "https://localhost:8080/"
327
+ },
328
+ "id": "z79mbyx-GJC-",
329
+ "outputId": "b4b90aae-90f9-4f07-bbc0-2f726b0ff4d3"
330
+ },
331
+ "execution_count": 9,
332
+ "outputs": [
333
+ {
334
+ "output_type": "execute_result",
335
+ "data": {
336
+ "text/plain": [
337
+ "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])"
338
+ ]
339
+ },
340
+ "metadata": {},
341
+ "execution_count": 9
342
+ }
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "source": [
348
+ "x = train_data[:block_size]\n",
349
+ "y = train_data[1:block_size+1]\n",
350
+ "for t in range(block_size):\n",
351
+ " context = x[:t+1]\n",
352
+ " target = y[t]\n",
353
+ " print(f\"when input is {context} the target: {target}\")"
354
+ ],
355
+ "metadata": {
356
+ "colab": {
357
+ "base_uri": "https://localhost:8080/"
358
+ },
359
+ "id": "5SQI_jZXGb7_",
360
+ "outputId": "52404a4a-91dd-4757-9c7e-c30a8a2eb2a3"
361
+ },
362
+ "execution_count": 10,
363
+ "outputs": [
364
+ {
365
+ "output_type": "stream",
366
+ "name": "stdout",
367
+ "text": [
368
+ "when input is tensor([18]) the target: 47\n",
369
+ "when input is tensor([18, 47]) the target: 56\n",
370
+ "when input is tensor([18, 47, 56]) the target: 57\n",
371
+ "when input is tensor([18, 47, 56, 57]) the target: 58\n",
372
+ "when input is tensor([18, 47, 56, 57, 58]) the target: 1\n",
373
+ "when input is tensor([18, 47, 56, 57, 58, 1]) the target: 15\n",
374
+ "when input is tensor([18, 47, 56, 57, 58, 1, 15]) the target: 47\n",
375
+ "when input is tensor([18, 47, 56, 57, 58, 1, 15, 47]) the target: 58\n"
376
+ ]
377
+ }
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "source": [
383
+ "torch.manual_seed(1337)\n",
384
+ "batch_size = 4\n",
385
+ "block_size = 8\n",
386
+ "\n",
387
+ "def get_batch(split):\n",
388
+ " # generate a small batch of data of inputs x and targets y\n",
389
+ " data = train_data if split == 'train' else val_data\n",
390
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
391
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
392
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
393
+ " return x, y\n",
394
+ "\n",
395
+ "xb, yb = get_batch('train')\n",
396
+ "print('inputs:')\n",
397
+ "print(xb.shape)\n",
398
+ "print(xb)\n",
399
+ "print('targets:')\n",
400
+ "print(yb.shape)\n",
401
+ "print(yb)\n",
402
+ "\n",
403
+ "print('----')\n",
404
+ "\n",
405
+ "for b in range(batch_size): # batch dimension\n",
406
+ " for t in range(block_size): # time dimension\n",
407
+ " context = xb[b, :t+1]\n",
408
+ " target = yb[b,t]\n",
409
+ " print(f\"when input is {context.tolist()} the target: {target}\")"
410
+ ],
411
+ "metadata": {
412
+ "colab": {
413
+ "base_uri": "https://localhost:8080/"
414
+ },
415
+ "id": "IAjhF0PTI1HF",
416
+ "outputId": "245c0f68-9502-4633-d365-e411176a5a14"
417
+ },
418
+ "execution_count": 11,
419
+ "outputs": [
420
+ {
421
+ "output_type": "stream",
422
+ "name": "stdout",
423
+ "text": [
424
+ "inputs:\n",
425
+ "torch.Size([4, 8])\n",
426
+ "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
427
+ " [44, 53, 56, 1, 58, 46, 39, 58],\n",
428
+ " [52, 58, 1, 58, 46, 39, 58, 1],\n",
429
+ " [25, 17, 27, 10, 0, 21, 1, 54]])\n",
430
+ "targets:\n",
431
+ "torch.Size([4, 8])\n",
432
+ "tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
433
+ " [53, 56, 1, 58, 46, 39, 58, 1],\n",
434
+ " [58, 1, 58, 46, 39, 58, 1, 46],\n",
435
+ " [17, 27, 10, 0, 21, 1, 54, 39]])\n",
436
+ "----\n",
437
+ "when input is [24] the target: 43\n",
438
+ "when input is [24, 43] the target: 58\n",
439
+ "when input is [24, 43, 58] the target: 5\n",
440
+ "when input is [24, 43, 58, 5] the target: 57\n",
441
+ "when input is [24, 43, 58, 5, 57] the target: 1\n",
442
+ "when input is [24, 43, 58, 5, 57, 1] the target: 46\n",
443
+ "when input is [24, 43, 58, 5, 57, 1, 46] the target: 43\n",
444
+ "when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39\n",
445
+ "when input is [44] the target: 53\n",
446
+ "when input is [44, 53] the target: 56\n",
447
+ "when input is [44, 53, 56] the target: 1\n",
448
+ "when input is [44, 53, 56, 1] the target: 58\n",
449
+ "when input is [44, 53, 56, 1, 58] the target: 46\n",
450
+ "when input is [44, 53, 56, 1, 58, 46] the target: 39\n",
451
+ "when input is [44, 53, 56, 1, 58, 46, 39] the target: 58\n",
452
+ "when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1\n",
453
+ "when input is [52] the target: 58\n",
454
+ "when input is [52, 58] the target: 1\n",
455
+ "when input is [52, 58, 1] the target: 58\n",
456
+ "when input is [52, 58, 1, 58] the target: 46\n",
457
+ "when input is [52, 58, 1, 58, 46] the target: 39\n",
458
+ "when input is [52, 58, 1, 58, 46, 39] the target: 58\n",
459
+ "when input is [52, 58, 1, 58, 46, 39, 58] the target: 1\n",
460
+ "when input is [52, 58, 1, 58, 46, 39, 58, 1] the target: 46\n",
461
+ "when input is [25] the target: 17\n",
462
+ "when input is [25, 17] the target: 27\n",
463
+ "when input is [25, 17, 27] the target: 10\n",
464
+ "when input is [25, 17, 27, 10] the target: 0\n",
465
+ "when input is [25, 17, 27, 10, 0] the target: 21\n",
466
+ "when input is [25, 17, 27, 10, 0, 21] the target: 1\n",
467
+ "when input is [25, 17, 27, 10, 0, 21, 1] the target: 54\n",
468
+ "when input is [25, 17, 27, 10, 0, 21, 1, 54] the target: 39\n"
469
+ ]
470
+ }
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "source": [
476
+ "print (xb) # our input to the transformer"
477
+ ],
478
+ "metadata": {
479
+ "colab": {
480
+ "base_uri": "https://localhost:8080/"
481
+ },
482
+ "id": "Sy2A0cbXM1Bd",
483
+ "outputId": "ba015f11-ee15-435e-b88a-2ad4164d7abe"
484
+ },
485
+ "execution_count": 12,
486
+ "outputs": [
487
+ {
488
+ "output_type": "stream",
489
+ "name": "stdout",
490
+ "text": [
491
+ "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
492
+ " [44, 53, 56, 1, 58, 46, 39, 58],\n",
493
+ " [52, 58, 1, 58, 46, 39, 58, 1],\n",
494
+ " [25, 17, 27, 10, 0, 21, 1, 54]])\n"
495
+ ]
496
+ }
497
+ ]
498
+ },
499
+ {
500
+ "cell_type": "code",
501
+ "source": [
502
+ "import torch\n",
503
+ "import torch.nn as nn\n",
504
+ "from torch.nn import functional as F\n",
505
+ "torch.manual_seed(1337)\n",
506
+ "\n",
507
+ "class BigramLanguageModel(nn.Module):\n",
508
+ "\n",
509
+ " def __init__(self, vocab_size):\n",
510
+ " super().__init__()\n",
511
+ " # each token directly reads off the logits for the next token from a lookup table\n",
512
+ " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
513
+ "\n",
514
+ " def forward(self, idx, targets=None):\n",
515
+ "\n",
516
+ " # idx and targets are both (B,T) tensor of integers\n",
517
+ " logits = self.token_embedding_table(idx) # (B,T,C)\n",
518
+ " \n",
519
+ " if targets is None:\n",
520
+ " loss = None\n",
521
+ " else:\n",
522
+ " B, T, C = logits.shape\n",
523
+ " logits = logits.view(B*T, C)\n",
524
+ " targets = targets.view(B*T)\n",
525
+ " loss = F.cross_entropy(logits, targets)\n",
526
+ "\n",
527
+ " return logits, loss\n",
528
+ " \n",
529
+ " def generate(self, idx, max_new_tokens):\n",
530
+ " # idx is (B, T) array of indices in the current context\n",
531
+ " for _ in range(max_new_tokens):\n",
532
+ " # get the predictions\n",
533
+ " logits, loss = self(idx)\n",
534
+ " # focus only on the last time step\n",
535
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
536
+ " # apply softmax to get probabilities\n",
537
+ " probs = F.softmax(logits, dim=-1) # (B, C)\n",
538
+ " # sample from the distribution\n",
539
+ " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
540
+ " # append sampled index to the running sequence\n",
541
+ " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
542
+ " return idx\n",
543
+ "\n",
544
+ "m = BigramLanguageModel(vocab_size)\n",
545
+ "logits, loss = m(xb, yb)\n",
546
+ "print(logits.shape)\n",
547
+ "print(loss)\n",
548
+ "\n",
549
+ "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))\n"
550
+ ],
551
+ "metadata": {
552
+ "colab": {
553
+ "base_uri": "https://localhost:8080/"
554
+ },
555
+ "id": "JadlSYPFfV5i",
556
+ "outputId": "48885ec2-7337-4d9b-8931-9db5b06ff04a"
557
+ },
558
+ "execution_count": 13,
559
+ "outputs": [
560
+ {
561
+ "output_type": "stream",
562
+ "name": "stdout",
563
+ "text": [
564
+ "torch.Size([32, 65])\n",
565
+ "tensor(4.8786, grad_fn=<NllLossBackward0>)\n",
566
+ "\n",
567
+ "Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3\n"
568
+ ]
569
+ }
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "source": [
575
+ "# create a PyTorch optimizer\n",
576
+ "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)"
577
+ ],
578
+ "metadata": {
579
+ "id": "kC6Sf0DkfZEs"
580
+ },
581
+ "execution_count": 14,
582
+ "outputs": []
583
+ },
584
+ {
585
+ "cell_type": "code",
586
+ "source": [
587
+ "batch_size = 32\n",
588
+ "for steps in range(100):\n",
589
+ "\n",
590
+ " xb, yb = get_batch('train')\n",
591
+ "\n",
592
+ " logits, loss = m(xb, yb)\n",
593
+ " optimizer.zero_grad(set_to_none=True)\n",
594
+ " loss.backward()\n",
595
+ " optimizer.step()\n",
596
+ "\n",
597
+ "print(loss.item())"
598
+ ],
599
+ "metadata": {
600
+ "colab": {
601
+ "base_uri": "https://localhost:8080/"
602
+ },
603
+ "id": "eAdiWhq8mq0v",
604
+ "outputId": "2210d81b-5438-4e35-9336-5f30567de53d"
605
+ },
606
+ "execution_count": 15,
607
+ "outputs": [
608
+ {
609
+ "output_type": "stream",
610
+ "name": "stdout",
611
+ "text": [
612
+ "4.587916374206543\n"
613
+ ]
614
+ }
615
+ ]
616
+ },
617
+ {
618
+ "cell_type": "code",
619
+ "source": [
620
+ "print(decode(m.generate(idx = torch.zeros((1, 1), dtype = torch.long), max_new_tokens=500)[0].tolist()))"
621
+ ],
622
+ "metadata": {
623
+ "colab": {
624
+ "base_uri": "https://localhost:8080/"
625
+ },
626
+ "id": "9I0z9v9NnVcW",
627
+ "outputId": "07133374-3061-41e3-9e0e-77ba644c3c94"
628
+ },
629
+ "execution_count": 16,
630
+ "outputs": [
631
+ {
632
+ "output_type": "stream",
633
+ "name": "stdout",
634
+ "text": [
635
+ "\n",
636
+ "xiKi-RJ:CgqVuUa!U?qMH.uk!sCuMXvv!CJFfx;LgRyJknOEti.?I&-gPlLyulId?XlaInQ'q,lT$\n",
637
+ "3Q&sGlvHQ?mqSq-eON\n",
638
+ "x?SP fUAfCAuCX:bOlgiRQWN:Mphaw\n",
639
+ "tRLKuYXEaAXxrcq-gCUzeh3w!AcyaylgYWjmJM?Uzw:inaY,:C&OECW:vmGGJAn3onAuMgia!ms$Vb q-gCOcPcUhOnxJGUGSPJWT:.?ujmJFoiNL&A'DxY,prZ?qdT;hoo'dHooXXlxf'WkHK&u3Q?rqUi.kz;?Yx?C&u3Qbfzxlyh'Vl:zyxjKXgC?\n",
640
+ "lv'QKFiBeviNxO'm!Upm$srm&TqViqiBD3HBP!juEOpmZJyF$Fwfy!PlvWPFC\n",
641
+ "&WDdP!Ko,px\n",
642
+ "x\n",
643
+ "tREOE;AJ.BeXkylOVD3KHp$e?nD,.SFbWWI'ubcL!q-tU;aXmJ&uGXHxJXI&Z!gHRpajj;l.\n",
644
+ "pTErIBjx;JKIgoCnLGXrJSP!AU-AcbczR?\n"
645
+ ]
646
+ }
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "markdown",
651
+ "source": [
652
+ "#Mathematical Trick in self-attention"
653
+ ],
654
+ "metadata": {
655
+ "id": "JPRFdk7pn7Xz"
656
+ }
657
+ },
658
+ {
659
+ "cell_type": "code",
660
+ "source": [
661
+ "# toy example for M Mul for weighted Aggregation\n",
662
+ "torch.manual_seed(42)\n",
663
+ "a = torch.tril(torch.ones(3, 3))\n",
664
+ "a = a / torch.sum(a, 1, keepdim=True)\n",
665
+ "b = torch.randint(0,10,(3,2)).float()\n",
666
+ "c = a @ b\n",
667
+ "print('a=')\n",
668
+ "print(a)\n",
669
+ "print('--')\n",
670
+ "print('b=')\n",
671
+ "print(b)\n",
672
+ "print('--')\n",
673
+ "print('c=')\n",
674
+ "print(c)\n"
675
+ ],
676
+ "metadata": {
677
+ "colab": {
678
+ "base_uri": "https://localhost:8080/"
679
+ },
680
+ "id": "z-XvQJi_u0HL",
681
+ "outputId": "486bcbac-c42e-494c-e9a0-341779370076"
682
+ },
683
+ "execution_count": 17,
684
+ "outputs": [
685
+ {
686
+ "output_type": "stream",
687
+ "name": "stdout",
688
+ "text": [
689
+ "a=\n",
690
+ "tensor([[1.0000, 0.0000, 0.0000],\n",
691
+ " [0.5000, 0.5000, 0.0000],\n",
692
+ " [0.3333, 0.3333, 0.3333]])\n",
693
+ "--\n",
694
+ "b=\n",
695
+ "tensor([[2., 7.],\n",
696
+ " [6., 4.],\n",
697
+ " [6., 5.]])\n",
698
+ "--\n",
699
+ "c=\n",
700
+ "tensor([[2.0000, 7.0000],\n",
701
+ " [4.0000, 5.5000],\n",
702
+ " [4.6667, 5.3333]])\n"
703
+ ]
704
+ }
705
+ ]
706
+ },
707
+ {
708
+ "cell_type": "code",
709
+ "source": [
710
+ "torch.manual_seed(1337)\n",
711
+ "B,T,C = 4,8,2 # BATCH, TIME, CHANNELS\n",
712
+ "x = torch.randn(B,T,C)\n",
713
+ "x.shape"
714
+ ],
715
+ "metadata": {
716
+ "colab": {
717
+ "base_uri": "https://localhost:8080/"
718
+ },
719
+ "id": "8zInghO3v5yg",
720
+ "outputId": "4f7a38e9-05a2-494b-eda1-2d8ca136fe03"
721
+ },
722
+ "execution_count": 18,
723
+ "outputs": [
724
+ {
725
+ "output_type": "execute_result",
726
+ "data": {
727
+ "text/plain": [
728
+ "torch.Size([4, 8, 2])"
729
+ ]
730
+ },
731
+ "metadata": {},
732
+ "execution_count": 18
733
+ }
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "code",
738
+ "source": [
739
+ "xbow = torch.zeros((B,T,C))\n",
740
+ "for b in range(B):\n",
741
+ " for t in range(T):\n",
742
+ " xprev = x[b, :t+1]\n",
743
+ " xbow[b,t] = torch.mean(xprev, 0)"
744
+ ],
745
+ "metadata": {
746
+ "id": "kM4Az6f3xXwz"
747
+ },
748
+ "execution_count": 19,
749
+ "outputs": []
750
+ },
751
+ {
752
+ "cell_type": "code",
753
+ "source": [
754
+ "wei = torch.tril(torch.ones(T, T))\n",
755
+ "wei = wei / wei.sum(1, keepdim=True)\n",
756
+ "xbow2 = wei @ x\n",
757
+ "torch.allclose(xbow, xbow2)"
758
+ ],
759
+ "metadata": {
760
+ "colab": {
761
+ "base_uri": "https://localhost:8080/"
762
+ },
763
+ "id": "j6mzu409x9qt",
764
+ "outputId": "16c8abd7-5e22-4c7e-b2e4-fc53041411d2"
765
+ },
766
+ "execution_count": 20,
767
+ "outputs": [
768
+ {
769
+ "output_type": "execute_result",
770
+ "data": {
771
+ "text/plain": [
772
+ "True"
773
+ ]
774
+ },
775
+ "metadata": {},
776
+ "execution_count": 20
777
+ }
778
+ ]
779
+ },
780
+ {
781
+ "cell_type": "code",
782
+ "source": [
783
+ "tril = torch.tril(torch.ones(T, T))\n",
784
+ "wei = torch.zeros((T, T))\n",
785
+ "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
786
+ "wei = F.softmax(wei, dim=-1)\n",
787
+ "xbow3 = wei @ x\n",
788
+ "torch.allclose(xbow, xbow3)"
789
+ ],
790
+ "metadata": {
791
+ "colab": {
792
+ "base_uri": "https://localhost:8080/"
793
+ },
794
+ "id": "Ez5cxjXjyeyA",
795
+ "outputId": "8cf70b82-93bb-4b9a-c29c-50342c99ca0b"
796
+ },
797
+ "execution_count": 22,
798
+ "outputs": [
799
+ {
800
+ "output_type": "execute_result",
801
+ "data": {
802
+ "text/plain": [
803
+ "True"
804
+ ]
805
+ },
806
+ "metadata": {},
807
+ "execution_count": 22
808
+ }
809
+ ]
810
+ },
811
+ {
812
+ "cell_type": "code",
813
+ "source": [
814
+ "# Self-attention !\n",
815
+ "torch.manual_seed(1337)\n",
816
+ "B,T,C = 4,8,32\n",
817
+ "x = torch.randn(B,T,C)\n",
818
+ "\n",
819
+ "# Single head perform self-attention\n",
820
+ "head_size = 16\n",
821
+ "key = nn.Linear(C, head_size, bias=False)\n",
822
+ "query = nn.Linear(C, head_size, bias=False)\n",
823
+ "value = nn.Linear(C, head_size, bias=False)\n",
824
+ "k = key(x)\n",
825
+ "q = query(x)\n",
826
+ "wei = q @ k.transpose(-2, -1)\n",
827
+ "\n",
828
+ "tril = torch.tril(torch.ones(T, T))\n",
829
+ "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
830
+ "wei = F.softmax(wei, dim=-1)\n",
831
+ "\n",
832
+ "v = value(x)\n",
833
+ "out = wei @ v\n",
834
+ "\n",
835
+ "out.shape"
836
+ ],
837
+ "metadata": {
838
+ "colab": {
839
+ "base_uri": "https://localhost:8080/"
840
+ },
841
+ "id": "d4fbZKO_zJlE",
842
+ "outputId": "61bfb573-3b08-4e83-aed1-cdb4be76ead8"
843
+ },
844
+ "execution_count": 23,
845
+ "outputs": [
846
+ {
847
+ "output_type": "execute_result",
848
+ "data": {
849
+ "text/plain": [
850
+ "torch.Size([4, 8, 16])"
851
+ ]
852
+ },
853
+ "metadata": {},
854
+ "execution_count": 23
855
+ }
856
+ ]
857
+ },
858
+ {
859
+ "cell_type": "code",
860
+ "source": [
861
+ "wei[0]"
862
+ ],
863
+ "metadata": {
864
+ "colab": {
865
+ "base_uri": "https://localhost:8080/"
866
+ },
867
+ "id": "5mUg8q-D1xJ3",
868
+ "outputId": "24f9aa45-1d20-4bc6-8efb-af5f5fb9899c"
869
+ },
870
+ "execution_count": 24,
871
+ "outputs": [
872
+ {
873
+ "output_type": "execute_result",
874
+ "data": {
875
+ "text/plain": [
876
+ "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
877
+ " [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
878
+ " [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
879
+ " [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],\n",
880
+ " [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],\n",
881
+ " [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],\n",
882
+ " [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],\n",
883
+ " [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],\n",
884
+ " grad_fn=<SelectBackward0>)"
885
+ ]
886
+ },
887
+ "metadata": {},
888
+ "execution_count": 24
889
+ }
890
+ ]
891
+ },
892
+ {
893
+ "cell_type": "code",
894
+ "source": [
895
+ "k = torch.randn(B,T,head_size)\n",
896
+ "q = torch.randn(B,T,head_size)\n",
897
+ "wei = q @ k.transpose(-2, -1) * head_size**-0.5"
898
+ ],
899
+ "metadata": {
900
+ "id": "L6Hz65jN11C5"
901
+ },
902
+ "execution_count": 25,
903
+ "outputs": []
904
+ },
905
+ {
906
+ "cell_type": "code",
907
+ "source": [
908
+ "k.var()"
909
+ ],
910
+ "metadata": {
911
+ "colab": {
912
+ "base_uri": "https://localhost:8080/"
913
+ },
914
+ "id": "opow74Yg82UN",
915
+ "outputId": "7937ca44-b52d-4373-ae58-d0c1ed450fa7"
916
+ },
917
+ "execution_count": 26,
918
+ "outputs": [
919
+ {
920
+ "output_type": "execute_result",
921
+ "data": {
922
+ "text/plain": [
923
+ "tensor(1.0449)"
924
+ ]
925
+ },
926
+ "metadata": {},
927
+ "execution_count": 26
928
+ }
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "code",
933
+ "source": [
934
+ "q.var()"
935
+ ],
936
+ "metadata": {
937
+ "colab": {
938
+ "base_uri": "https://localhost:8080/"
939
+ },
940
+ "id": "jEGJMlZh86lD",
941
+ "outputId": "c093ea15-9db4-408b-8898-0192748f8ab2"
942
+ },
943
+ "execution_count": 27,
944
+ "outputs": [
945
+ {
946
+ "output_type": "execute_result",
947
+ "data": {
948
+ "text/plain": [
949
+ "tensor(1.0700)"
950
+ ]
951
+ },
952
+ "metadata": {},
953
+ "execution_count": 27
954
+ }
955
+ ]
956
+ },
957
+ {
958
+ "cell_type": "code",
959
+ "source": [
960
+ "wei.var()"
961
+ ],
962
+ "metadata": {
963
+ "colab": {
964
+ "base_uri": "https://localhost:8080/"
965
+ },
966
+ "id": "37djNLHJ88Gh",
967
+ "outputId": "a3ba1d4b-bca5-41a2-afa5-f135056b80ba"
968
+ },
969
+ "execution_count": 28,
970
+ "outputs": [
971
+ {
972
+ "output_type": "execute_result",
973
+ "data": {
974
+ "text/plain": [
975
+ "tensor(1.0918)"
976
+ ]
977
+ },
978
+ "metadata": {},
979
+ "execution_count": 28
980
+ }
981
+ ]
982
+ },
983
+ {
984
+ "cell_type": "code",
985
+ "source": [
986
+ "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)"
987
+ ],
988
+ "metadata": {
989
+ "colab": {
990
+ "base_uri": "https://localhost:8080/"
991
+ },
992
+ "id": "3NK1li0w89wx",
993
+ "outputId": "4205b108-d666-4add-dd3e-48da20a6e351"
994
+ },
995
+ "execution_count": 29,
996
+ "outputs": [
997
+ {
998
+ "output_type": "execute_result",
999
+ "data": {
1000
+ "text/plain": [
1001
+ "tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])"
1002
+ ]
1003
+ },
1004
+ "metadata": {},
1005
+ "execution_count": 29
1006
+ }
1007
+ ]
1008
+ },
1009
+ {
1010
+ "cell_type": "code",
1011
+ "source": [
1012
+ "torch.softmax(torch.tensor([0.1, -0.2, 0.3,-0.2,0.5])*8, dim=-1)"
1013
+ ],
1014
+ "metadata": {
1015
+ "colab": {
1016
+ "base_uri": "https://localhost:8080/"
1017
+ },
1018
+ "id": "-3UqDMG79QLI",
1019
+ "outputId": "61674514-3887-43a4-93aa-055dfcd61b76"
1020
+ },
1021
+ "execution_count": 30,
1022
+ "outputs": [
1023
+ {
1024
+ "output_type": "execute_result",
1025
+ "data": {
1026
+ "text/plain": [
1027
+ "tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])"
1028
+ ]
1029
+ },
1030
+ "metadata": {},
1031
+ "execution_count": 30
1032
+ }
1033
+ ]
1034
+ },
1035
+ {
1036
+ "cell_type": "code",
1037
+ "source": [
1038
+ "class LayerNorm1d: # (used to be BatchNorm1d)\n",
1039
+ " \n",
1040
+ " def __init__(self, dim, eps=1e-5, momentum=0.1):\n",
1041
+ " self.eps = eps\n",
1042
+ " self.gamma = torch.ones(dim)\n",
1043
+ " self.beta = torch.zeros(dim)\n",
1044
+ " \n",
1045
+ " def __call__(self, x):\n",
1046
+ " # calculate the forward pass\n",
1047
+ " xmean = x.mean(1, keepdim=True) # batch mean\n",
1048
+ " xvar = x.var(1, keepdim=True) # batch variance\n",
1049
+ " xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n",
1050
+ " self.out = self.gamma * xhat + self.beta\n",
1051
+ " return self.out\n",
1052
+ " \n",
1053
+ " def parameters(self):\n",
1054
+ " return [self.gamma, self.beta]\n",
1055
+ "\n",
1056
+ "torch.manual_seed(1337)\n",
1057
+ "module = LayerNorm1d(100)\n",
1058
+ "x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors\n",
1059
+ "x = module(x)\n",
1060
+ "x.shape"
1061
+ ],
1062
+ "metadata": {
1063
+ "colab": {
1064
+ "base_uri": "https://localhost:8080/"
1065
+ },
1066
+ "id": "a_572UNcChia",
1067
+ "outputId": "87012d0d-81cd-4841-a4e8-48bf9c0e2e61"
1068
+ },
1069
+ "execution_count": 32,
1070
+ "outputs": [
1071
+ {
1072
+ "output_type": "execute_result",
1073
+ "data": {
1074
+ "text/plain": [
1075
+ "torch.Size([32, 100])"
1076
+ ]
1077
+ },
1078
+ "metadata": {},
1079
+ "execution_count": 32
1080
+ }
1081
+ ]
1082
+ },
1083
+ {
1084
+ "cell_type": "code",
1085
+ "source": [
1086
+ "x[:, 0].mean(), x[:,0].std()"
1087
+ ],
1088
+ "metadata": {
1089
+ "colab": {
1090
+ "base_uri": "https://localhost:8080/"
1091
+ },
1092
+ "id": "LHfhDFW1Coel",
1093
+ "outputId": "7eff9314-f287-4566-aa4d-7d9082bff11b"
1094
+ },
1095
+ "execution_count": 33,
1096
+ "outputs": [
1097
+ {
1098
+ "output_type": "execute_result",
1099
+ "data": {
1100
+ "text/plain": [
1101
+ "(tensor(0.1469), tensor(0.8803))"
1102
+ ]
1103
+ },
1104
+ "metadata": {},
1105
+ "execution_count": 33
1106
+ }
1107
+ ]
1108
+ },
1109
+ {
1110
+ "cell_type": "code",
1111
+ "source": [
1112
+ "x[0,:].mean(), x[0,:].std()"
1113
+ ],
1114
+ "metadata": {
1115
+ "colab": {
1116
+ "base_uri": "https://localhost:8080/"
1117
+ },
1118
+ "id": "bt7xbja2FOu-",
1119
+ "outputId": "8f1cbfe0-7862-4ba0-bd54-7149a78b7153"
1120
+ },
1121
+ "execution_count": 34,
1122
+ "outputs": [
1123
+ {
1124
+ "output_type": "execute_result",
1125
+ "data": {
1126
+ "text/plain": [
1127
+ "(tensor(-9.5367e-09), tensor(1.0000))"
1128
+ ]
1129
+ },
1130
+ "metadata": {},
1131
+ "execution_count": 34
1132
+ }
1133
+ ]
1134
+ },
1135
+ {
1136
+ "cell_type": "code",
1137
+ "source": [
1138
+ "import torch\n",
1139
+ "import torch.nn as nn\n",
1140
+ "from torch.nn import functional as F\n",
1141
+ "\n",
1142
+ "# hyperparameters\n",
1143
+ "batch_size = 16 # how many independent sequences will we process in parallel?\n",
1144
+ "block_size = 32 # what is the maximum context length for predictions?\n",
1145
+ "max_iters = 5000\n",
1146
+ "eval_interval = 100\n",
1147
+ "learning_rate = 1e-3\n",
1148
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
1149
+ "eval_iters = 200\n",
1150
+ "n_embd = 64\n",
1151
+ "n_head = 4\n",
1152
+ "n_layer = 4\n",
1153
+ "dropout = 0.0\n",
1154
+ "\n",
1155
+ "torch.manual_seed(1337)\n",
1156
+ "\n",
1157
+ "# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
1158
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
1159
+ " text = f.read()\n",
1160
+ "\n",
1161
+ "# here are all the unique characters that occur in this text\n",
1162
+ "chars = sorted(list(set(text)))\n",
1163
+ "vocab_size = len(chars)\n",
1164
+ "# create a mapping from characters to integers\n",
1165
+ "stoi = { ch:i for i,ch in enumerate(chars) }\n",
1166
+ "itos = { i:ch for i,ch in enumerate(chars) }\n",
1167
+ "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
1168
+ "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
1169
+ "\n",
1170
+ "# Train and test splits\n",
1171
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
1172
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
1173
+ "train_data = data[:n]\n",
1174
+ "val_data = data[n:]\n",
1175
+ "\n",
1176
+ "# data loading\n",
1177
+ "def get_batch(split):\n",
1178
+ " # generate a small batch of data of inputs x and targets y\n",
1179
+ " data = train_data if split == 'train' else val_data\n",
1180
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
1181
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
1182
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
1183
+ " x, y = x.to(device), y.to(device)\n",
1184
+ " return x, y\n",
1185
+ "\n",
1186
+ "@torch.no_grad()\n",
1187
+ "def estimate_loss():\n",
1188
+ " out = {}\n",
1189
+ " model.eval()\n",
1190
+ " for split in ['train', 'val']:\n",
1191
+ " losses = torch.zeros(eval_iters)\n",
1192
+ " for k in range(eval_iters):\n",
1193
+ " X, Y = get_batch(split)\n",
1194
+ " logits, loss = model(X, Y)\n",
1195
+ " losses[k] = loss.item()\n",
1196
+ " out[split] = losses.mean()\n",
1197
+ " model.train()\n",
1198
+ " return out\n",
1199
+ "\n",
1200
+ "class Head(nn.Module):\n",
1201
+ " \"\"\" one head of self-attention \"\"\"\n",
1202
+ "\n",
1203
+ " def __init__(self, head_size):\n",
1204
+ " super().__init__()\n",
1205
+ " self.key = nn.Linear(n_embd, head_size, bias=False)\n",
1206
+ " self.query = nn.Linear(n_embd, head_size, bias=False)\n",
1207
+ " self.value = nn.Linear(n_embd, head_size, bias=False)\n",
1208
+ " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
1209
+ "\n",
1210
+ " self.dropout = nn.Dropout(dropout)\n",
1211
+ "\n",
1212
+ " def forward(self, x):\n",
1213
+ " B,T,C = x.shape\n",
1214
+ " k = self.key(x) # (B,T,C)\n",
1215
+ " q = self.query(x) # (B,T,C)\n",
1216
+ " # compute attention scores (\"affinities\")\n",
1217
+ " wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
1218
+ " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
1219
+ " wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
1220
+ " wei = self.dropout(wei)\n",
1221
+ " # perform the weighted aggregation of the values\n",
1222
+ " v = self.value(x) # (B,T,C)\n",
1223
+ " out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
1224
+ " return out\n",
1225
+ "\n",
1226
+ "class MultiHeadAttention(nn.Module):\n",
1227
+ " \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
1228
+ "\n",
1229
+ " def __init__(self, num_heads, head_size):\n",
1230
+ " super().__init__()\n",
1231
+ " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
1232
+ " self.proj = nn.Linear(n_embd, n_embd)\n",
1233
+ " self.dropout = nn.Dropout(dropout)\n",
1234
+ "\n",
1235
+ " def forward(self, x):\n",
1236
+ " out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
1237
+ " out = self.dropout(self.proj(out))\n",
1238
+ " return out\n",
1239
+ "\n",
1240
+ "class FeedFoward(nn.Module):\n",
1241
+ " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
1242
+ "\n",
1243
+ " def __init__(self, n_embd):\n",
1244
+ " super().__init__()\n",
1245
+ " self.net = nn.Sequential(\n",
1246
+ " nn.Linear(n_embd, 4 * n_embd),\n",
1247
+ " nn.ReLU(),\n",
1248
+ " nn.Linear(4 * n_embd, n_embd),\n",
1249
+ " nn.Dropout(dropout),\n",
1250
+ " )\n",
1251
+ "\n",
1252
+ " def forward(self, x):\n",
1253
+ " return self.net(x)\n",
1254
+ "\n",
1255
+ "class Block(nn.Module):\n",
1256
+ " \"\"\" Transformer block: communication followed by computation \"\"\"\n",
1257
+ "\n",
1258
+ " def __init__(self, n_embd, n_head):\n",
1259
+ " # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
1260
+ " super().__init__()\n",
1261
+ " head_size = n_embd // n_head\n",
1262
+ " self.sa = MultiHeadAttention(n_head, head_size)\n",
1263
+ " self.ffwd = FeedFoward(n_embd)\n",
1264
+ " self.ln1 = nn.LayerNorm(n_embd)\n",
1265
+ " self.ln2 = nn.LayerNorm(n_embd)\n",
1266
+ "\n",
1267
+ " def forward(self, x):\n",
1268
+ " x = x + self.sa(self.ln1(x))\n",
1269
+ " x = x + self.ffwd(self.ln2(x))\n",
1270
+ " return x\n",
1271
+ "\n",
1272
+ "# super simple bigram model\n",
1273
+ "class BigramLanguageModel(nn.Module):\n",
1274
+ "\n",
1275
+ " def __init__(self):\n",
1276
+ " super().__init__()\n",
1277
+ " # each token directly reads off the logits for the next token from a lookup table\n",
1278
+ " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
1279
+ " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
1280
+ " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
1281
+ " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
1282
+ " self.lm_head = nn.Linear(n_embd, vocab_size)\n",
1283
+ "\n",
1284
+ " def forward(self, idx, targets=None):\n",
1285
+ " B, T = idx.shape\n",
1286
+ "\n",
1287
+ " # idx and targets are both (B,T) tensor of integers\n",
1288
+ " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
1289
+ " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
1290
+ " x = tok_emb + pos_emb # (B,T,C)\n",
1291
+ " x = self.blocks(x) # (B,T,C)\n",
1292
+ " x = self.ln_f(x) # (B,T,C)\n",
1293
+ " logits = self.lm_head(x) # (B,T,vocab_size)\n",
1294
+ "\n",
1295
+ " if targets is None:\n",
1296
+ " loss = None\n",
1297
+ " else:\n",
1298
+ " B, T, C = logits.shape\n",
1299
+ " logits = logits.view(B*T, C)\n",
1300
+ " targets = targets.view(B*T)\n",
1301
+ " loss = F.cross_entropy(logits, targets)\n",
1302
+ "\n",
1303
+ " return logits, loss\n",
1304
+ "\n",
1305
+ " def generate(self, idx, max_new_tokens):\n",
1306
+ " # idx is (B, T) array of indices in the current context\n",
1307
+ " for _ in range(max_new_tokens):\n",
1308
+ " # crop idx to the last block_size tokens\n",
1309
+ " idx_cond = idx[:, -block_size:]\n",
1310
+ " # get the predictions\n",
1311
+ " logits, loss = self(idx_cond)\n",
1312
+ " # focus only on the last time step\n",
1313
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
1314
+ " # apply softmax to get probabilities\n",
1315
+ " probs = F.softmax(logits, dim=-1) # (B, C)\n",
1316
+ " # sample from the distribution\n",
1317
+ " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
1318
+ " # append sampled index to the running sequence\n",
1319
+ " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
1320
+ " return idx\n",
1321
+ "\n",
1322
+ "model = BigramLanguageModel()\n",
1323
+ "m = model.to(device)\n",
1324
+ "# print the number of parameters in the model\n",
1325
+ "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
1326
+ "\n",
1327
+ "# create a PyTorch optimizer\n",
1328
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
1329
+ "\n",
1330
+ "for iter in range(max_iters):\n",
1331
+ "\n",
1332
+ " # every once in a while evaluate the loss on train and val sets\n",
1333
+ " if iter % eval_interval == 0 or iter == max_iters - 1:\n",
1334
+ " losses = estimate_loss()\n",
1335
+ " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
1336
+ "\n",
1337
+ " # sample a batch of data\n",
1338
+ " xb, yb = get_batch('train')\n",
1339
+ "\n",
1340
+ " # evaluate the loss\n",
1341
+ " logits, loss = model(xb, yb)\n",
1342
+ " optimizer.zero_grad(set_to_none=True)\n",
1343
+ " loss.backward()\n",
1344
+ " optimizer.step()\n",
1345
+ "\n",
1346
+ "# generate from the model\n",
1347
+ "context = torch.zeros((1, 1), dtype=torch.long, device=device )\n",
1348
+ "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"
1349
+ ],
1350
+ "metadata": {
1351
+ "colab": {
1352
+ "base_uri": "https://localhost:8080/"
1353
+ },
1354
+ "id": "WYnRTqPbFXHy",
1355
+ "outputId": "d625a959-7490-4a84-e692-600da91e0ef9"
1356
+ },
1357
+ "execution_count": 35,
1358
+ "outputs": [
1359
+ {
1360
+ "output_type": "stream",
1361
+ "name": "stdout",
1362
+ "text": [
1363
+ "0.209729 M parameters\n",
1364
+ "step 0: train loss 4.4116, val loss 4.4022\n",
1365
+ "step 100: train loss 2.6568, val loss 2.6670\n",
1366
+ "step 200: train loss 2.5090, val loss 2.5059\n",
1367
+ "step 300: train loss 2.4196, val loss 2.4338\n",
1368
+ "step 400: train loss 2.3503, val loss 2.3565\n",
1369
+ "step 500: train loss 2.2966, val loss 2.3129\n",
1370
+ "step 600: train loss 2.2410, val loss 2.2500\n",
1371
+ "step 700: train loss 2.2051, val loss 2.2191\n",
1372
+ "step 800: train loss 2.1640, val loss 2.1874\n",
1373
+ "step 900: train loss 2.1251, val loss 2.1515\n",
1374
+ "step 1000: train loss 2.1023, val loss 2.1291\n",
1375
+ "step 1100: train loss 2.0699, val loss 2.1192\n",
1376
+ "step 1200: train loss 2.0375, val loss 2.0797\n",
1377
+ "step 1300: train loss 2.0259, val loss 2.0647\n",
1378
+ "step 1400: train loss 1.9924, val loss 2.0362\n",
1379
+ "step 1500: train loss 1.9700, val loss 2.0304\n",
1380
+ "step 1600: train loss 1.9631, val loss 2.0476\n",
1381
+ "step 1700: train loss 1.9412, val loss 2.0131\n",
1382
+ "step 1800: train loss 1.9097, val loss 1.9960\n",
1383
+ "step 1900: train loss 1.9101, val loss 1.9882\n",
1384
+ "step 2000: train loss 1.8867, val loss 1.9976\n",
1385
+ "step 2100: train loss 1.8720, val loss 1.9754\n",
1386
+ "step 2200: train loss 1.8588, val loss 1.9606\n",
1387
+ "step 2300: train loss 1.8542, val loss 1.9525\n",
1388
+ "step 2400: train loss 1.8424, val loss 1.9464\n",
1389
+ "step 2500: train loss 1.8173, val loss 1.9455\n",
1390
+ "step 2600: train loss 1.8256, val loss 1.9388\n",
1391
+ "step 2700: train loss 1.8116, val loss 1.9350\n",
1392
+ "step 2800: train loss 1.8056, val loss 1.9214\n",
1393
+ "step 2900: train loss 1.8040, val loss 1.9300\n",
1394
+ "step 3000: train loss 1.7974, val loss 1.9205\n",
1395
+ "step 3100: train loss 1.7694, val loss 1.9157\n",
1396
+ "step 3200: train loss 1.7539, val loss 1.9115\n",
1397
+ "step 3300: train loss 1.7571, val loss 1.9071\n",
1398
+ "step 3400: train loss 1.7531, val loss 1.8954\n",
1399
+ "step 3500: train loss 1.7368, val loss 1.8918\n",
1400
+ "step 3600: train loss 1.7274, val loss 1.8884\n",
1401
+ "step 3700: train loss 1.7301, val loss 1.8819\n",
1402
+ "step 3800: train loss 1.7210, val loss 1.8938\n",
1403
+ "step 3900: train loss 1.7260, val loss 1.8750\n",
1404
+ "step 4000: train loss 1.7122, val loss 1.8554\n",
1405
+ "step 4100: train loss 1.7129, val loss 1.8717\n",
1406
+ "step 4200: train loss 1.7041, val loss 1.8634\n",
1407
+ "step 4300: train loss 1.6986, val loss 1.8434\n",
1408
+ "step 4400: train loss 1.7052, val loss 1.8605\n",
1409
+ "step 4500: train loss 1.6881, val loss 1.8467\n",
1410
+ "step 4600: train loss 1.6849, val loss 1.8318\n",
1411
+ "step 4700: train loss 1.6833, val loss 1.8449\n",
1412
+ "step 4800: train loss 1.6686, val loss 1.8472\n",
1413
+ "step 4900: train loss 1.6719, val loss 1.8425\n",
1414
+ "step 4999: train loss 1.6619, val loss 1.8215\n",
1415
+ "\n",
1416
+ "And they bride will to lay be madie;\n",
1417
+ "Thou but take O-dam the change:\n",
1418
+ "Warth full him tother dilth ane away, my fears,\n",
1419
+ "You have was them of is heart mile,\n",
1420
+ "You, and if ensmy contlatist, drov the does me now that\n",
1421
+ "just, lesing that.\n",
1422
+ "His my now, you up; and the tyby love.\n",
1423
+ "In Bodiet, and whom\n",
1424
+ "that demperakenous, so what evily well my\n",
1425
+ "Murtus censurence of him the reshep and thrust for to imper my monte in Mont,\n",
1426
+ "To fight? gry of thy hourb! stiddy as\n",
1427
+ "ards bearing her broint must are no Runnts\n",
1428
+ "Infortuce will me not be arm.\n",
1429
+ "You contrantymes have myse.-\n",
1430
+ "And fortwerle madam them may in son, live body.\n",
1431
+ "\n",
1432
+ "Think you:\n",
1433
+ "It stay might. \n",
1434
+ "CLAMENCE:\n",
1435
+ "My whilesse everew in movet, if Cassce of's counted;\n",
1436
+ "How what make you fear tals: the gold my sun?\n",
1437
+ "What, loudy forgor man our him.\n",
1438
+ "I will were but with some. Povinly Ford the welcont.\n",
1439
+ "\n",
1440
+ "QUEEN FIDILIZ:\n",
1441
+ "No?\n",
1442
+ "Their him the not.\n",
1443
+ "\n",
1444
+ "POLIXENENE:\n",
1445
+ "But to me, God no now the summe wip.\n",
1446
+ "\n",
1447
+ "GROMPEO:\n",
1448
+ "Conguit, bruke this belike, on so han the bodiet.\n",
1449
+ "\n",
1450
+ "CORIOLANUS:\n",
1451
+ "Till the;\n",
1452
+ "you wellseers I am with you,\n",
1453
+ "For I hust no where Mustconce, do wind that I am nobly.\n",
1454
+ "\n",
1455
+ "BRUSTHORD:\n",
1456
+ "O, wenterings so me worting.\n",
1457
+ "\n",
1458
+ "GRUMIO:\n",
1459
+ "O thus favour now,\n",
1460
+ "An bear was all beenIn\n",
1461
+ "Before and to the sever--and.\n",
1462
+ "In to dot me, to liberfeleing breamn'd my have\n",
1463
+ "epince, if that jutcey's leve,\n",
1464
+ "That Tumselfly there's little ofjess the vown;\n",
1465
+ "Maughter armied maste love in stide belothy dong'd the not.\n",
1466
+ "\n",
1467
+ "BENVOLIO:\n",
1468
+ "Well cavonzy to I have must aboe;\n",
1469
+ "I now, I thinke numt om Three teny, delelige,\n",
1470
+ "And yet our son one old, we\n",
1471
+ "ell sment on you; and plock, say, as If have to kavidess corby?\n",
1472
+ "Then eteep; upose worth\n",
1473
+ "But arm one wall preven him there.\n",
1474
+ "\n",
1475
+ "BUCKINGHARD\n",
1476
+ "\n",
1477
+ "IVIRHAMIUS:\n",
1478
+ "Why, unere to-marrow thy sathe court his in on\n",
1479
+ "some no, God the have blay not, these wife it:\n",
1480
+ "The that hear I, thou with art, lives?\n",
1481
+ "\n",
1482
+ "LARY:\n",
1483
+ "Our while with you\n",
1484
+ "That I horrtw'd will theirs is.\n",
1485
+ "Why, I would I drue, and was father,--\n",
1486
+ "'Tensis, thy promb, many and sentry talbatt.\n",
1487
+ "\n",
1488
+ "PORDINCE:\n",
1489
+ "Why Riparding:\n",
1490
+ "In is shown's fortunds, but whom the brike our all\n"
1491
+ ]
1492
+ }
1493
+ ]
1494
+ },
1495
+ {
1496
+ "cell_type": "code",
1497
+ "source": [],
1498
+ "metadata": {
1499
+ "id": "i8lCFzYGMkBk"
1500
+ },
1501
+ "execution_count": null,
1502
+ "outputs": []
1503
+ }
1504
+ ]
1505
+ }