Jensen-holm commited on
Commit
84bbd7d
1 Parent(s): 38e3b7b

init weights and biases, and getting through epochs

Browse files
.gitignore CHANGED
@@ -24,4 +24,165 @@ go.work
24
  .vscode
25
  .idea
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  *.swp
 
24
  .vscode
25
  .idea
26
 
27
+ # Byte-compiled / optimized / DLL files
28
+ __pycache__/
29
+ *.py[cod]
30
+ *$py.class
31
+
32
+ # C extensions
33
+ *.so
34
+
35
+ # Distribution / packaging
36
+ .Python
37
+ build/
38
+ develop-eggs/
39
+ dist/
40
+ downloads/
41
+ eggs/
42
+ .eggs/
43
+ lib/
44
+ lib64/
45
+ parts/
46
+ sdist/
47
+ var/
48
+ wheels/
49
+ share/python-wheels/
50
+ *.egg-info/
51
+ .installed.cfg
52
+ *.egg
53
+ MANIFEST
54
+
55
+ # PyInstaller
56
+ # Usually these files are written by a python script from a template
57
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
58
+ *.manifest
59
+ *.spec
60
+
61
+ # Installer logs
62
+ pip-log.txt
63
+ pip-delete-this-directory.txt
64
+
65
+ # Unit test / coverage reports
66
+ htmlcov/
67
+ .tox/
68
+ .nox/
69
+ .coverage
70
+ .coverage.*
71
+ .cache
72
+ nosetests.xml
73
+ coverage.xml
74
+ *.cover
75
+ *.py,cover
76
+ .hypothesis/
77
+ .pytest_cache/
78
+ cover/
79
+
80
+ # Translations
81
+ *.mo
82
+ *.pot
83
+
84
+ # Django stuff:
85
+ *.log
86
+ local_settings.py
87
+ db.sqlite3
88
+ db.sqlite3-journal
89
+
90
+ # Flask stuff:
91
+ instance/
92
+ .webassets-cache
93
+
94
+ # Scrapy stuff:
95
+ .scrapy
96
+
97
+ # Sphinx documentation
98
+ docs/_build/
99
+
100
+ # PyBuilder
101
+ .pybuilder/
102
+ target/
103
+
104
+ # Jupyter Notebook
105
+ .ipynb_checkpoints
106
+
107
+ # IPython
108
+ profile_default/
109
+ ipython_config.py
110
+
111
+ # pyenv
112
+ # For a library or package, you might want to ignore these files since the code is
113
+ # intended to run in multiple environments; otherwise, check them in:
114
+ # .python-version
115
+
116
+ # pipenv
117
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
118
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
119
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
120
+ # install all needed dependencies.
121
+ #Pipfile.lock
122
+
123
+ # poetry
124
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
125
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
126
+ # commonly ignored for libraries.
127
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
128
+ #poetry.lock
129
+
130
+ # pdm
131
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
132
+ #pdm.lock
133
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
134
+ # in version control.
135
+ # https://pdm.fming.dev/#use-with-ide
136
+ .pdm.toml
137
+
138
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
139
+ __pypackages__/
140
+
141
+ # Celery stuff
142
+ celerybeat-schedule
143
+ celerybeat.pid
144
+
145
+ # SageMath parsed files
146
+ *.sage.py
147
+
148
+ # Environments
149
+ .env
150
+ .venv
151
+ env/
152
+ venv/
153
+ ENV/
154
+ env.bak/
155
+ venv.bak/
156
+
157
+ # Spyder project settings
158
+ .spyderproject
159
+ .spyproject
160
+
161
+ # Rope project settings
162
+ .ropeproject
163
+
164
+ # mkdocs documentation
165
+ /site
166
+
167
+ # mypy
168
+ .mypy_cache/
169
+ .dmypy.json
170
+ dmypy.json
171
+
172
+ # Pyre type checker
173
+ .pyre/
174
+
175
+ # pytype static type analyzer
176
+ .pytype/
177
+
178
+ # Cython debug symbols
179
+ cython_debug/
180
+
181
+ # PyCharm
182
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
183
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
184
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
185
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
186
+ #.idea/
187
+
188
  *.swp
nn/__pycache__/activation.cpython-310.pyc DELETED
Binary file (1.3 kB)
 
nn/__pycache__/nn.cpython-310.pyc DELETED
Binary file (2.34 kB)
 
nn/__pycache__/train.cpython-310.pyc DELETED
Binary file (1.01 kB)
 
nn/nn.py CHANGED
@@ -24,8 +24,6 @@ class NN:
24
  self.target = target
25
  self.data = data
26
 
27
- self.input_size = len(features)
28
-
29
  self.wh: np.array = None
30
  self.wo: np.array = None
31
  self.bh: np.array = None
@@ -39,8 +37,12 @@ class NN:
39
  def set_df(self, df: pd.DataFrame) -> None:
40
  assert isinstance(df, pd.DataFrame)
41
  self.df = df
42
- self.X = df[self.features]
43
- self.y = df[self.target]
 
 
 
 
44
 
45
  def set_func(self, f: Callable) -> None:
46
  assert isinstance(f, Callable)
 
24
  self.target = target
25
  self.data = data
26
 
 
 
27
  self.wh: np.array = None
28
  self.wo: np.array = None
29
  self.bh: np.array = None
 
37
  def set_df(self, df: pd.DataFrame) -> None:
38
  assert isinstance(df, pd.DataFrame)
39
  self.df = df
40
+ # we can only deal with numbers from here on out
41
+ y = df[self.target]
42
+ x = df[self.features]
43
+ self.y = pd.get_dummies(y, columns=self.target)
44
+ self.X = pd.get_dummies(x, columns=self.features)
45
+ self.input_size = len(self.X.columns)
46
 
47
  def set_func(self, f: Callable) -> None:
48
  assert isinstance(f, Callable)
nn/train.py CHANGED
@@ -1,15 +1,16 @@
1
  from sklearn.model_selection import train_test_split
 
2
  from nn.nn import NN
3
  import pandas as pd
4
  import numpy as np
5
 
6
 
7
  def init_weights_biases(nn: NN) -> None:
8
- np.random.seed(88)
9
- bh = np.zeros((1, 1))
10
  bo = np.zeros((1, 1))
11
- wh = np.random.randn(1, nn.input_size) * np.sqrt(2 / nn.input_size)
12
- wo = np.random.randn(1, nn.hidden_size) * np.sqrt(2 / nn.hidden_size)
 
13
  nn.set_bh(bh)
14
  nn.set_bo(bo)
15
  nn.set_wh(wh)
@@ -22,7 +23,33 @@ def train(nn: NN) -> dict:
22
  nn.X,
23
  nn.y,
24
  test_size=nn.test_size,
25
- random_state=88,
26
  )
27
 
28
- return {"status": "you made it!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from sklearn.model_selection import train_test_split
2
+ from typing import Callable
3
  from nn.nn import NN
4
  import pandas as pd
5
  import numpy as np
6
 
7
 
8
  def init_weights_biases(nn: NN) -> None:
9
+ bh = np.zeros((1, nn.hidden_size))
 
10
  bo = np.zeros((1, 1))
11
+ wh = np.random.randn(nn.input_size, nn.hidden_size) * \
12
+ np.sqrt(2 / nn.input_size)
13
+ wo = np.random.randn(nn.hidden_size, 1) * np.sqrt(2 / nn.hidden_size)
14
  nn.set_bh(bh)
15
  nn.set_bo(bo)
16
  nn.set_wh(wh)
 
23
  nn.X,
24
  nn.y,
25
  test_size=nn.test_size,
 
26
  )
27
 
28
+ for _ in range(nn.epochs):
29
+ # compute hidden output
30
+ hidden_output = compute_node(
31
+ data=X_train.to_numpy(),
32
+ weights=nn.wh,
33
+ biases=nn.bh,
34
+ func=nn.func,
35
+ )
36
+
37
+ # compute output layer
38
+ y_hat = compute_node(
39
+ data=hidden_output,
40
+ weights=nn.wo,
41
+ biases=nn.bo,
42
+ func=nn.func,
43
+ )
44
+
45
+ mse = mean_squared_error(y_train, y_hat)
46
+
47
+ return {"mse": mse}
48
+
49
+
50
+ def compute_node(data: np.array, weights: np.array, biases: np.array, func: Callable) -> np.array:
51
+ return func(np.dot(data, weights) + biases)
52
+
53
+
54
+ def mean_squared_error(y: np.array, y_hat: np.array) -> np.array:
55
+ return np.mean((y - y_hat) ** 2)