Jensen-holm commited on
Commit
29cce3f
1 Parent(s): 2c781f8

working on new python implementation with cleaner code

Browse files
Files changed (16) hide show
  1. app.py +27 -0
  2. example/main.py +3 -2
  3. go.mod +0 -24
  4. go.sum +0 -109
  5. nn/activation.go +0 -41
  6. nn/args.go +0 -17
  7. nn/backprop.go +0 -99
  8. nn/backprop.py +0 -0
  9. nn/main.go +0 -84
  10. nn/nn.py +32 -0
  11. nn/split.go +0 -60
  12. nn/subset.go +0 -6
  13. nn/train.go +0 -7
  14. nn/train.py +15 -0
  15. requirements.txt +3 -0
  16. server.go +0 -26
app.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+
3
+ import pandas as pd
4
+ from nn.nn import NN
5
+ from nn import train as train_nn
6
+
7
+ app = Flask(__name__)
8
+
9
+
10
+ @app.route("/neural-network", methods=["POST"])
11
+ def neural_net():
12
+ args = request.json
13
+
14
+ try:
15
+ net = NN.from_dict(args)
16
+ df = pd.read_csv(args.pop("data"))
17
+ except Exception as e:
18
+ return jsonify({
19
+ "bad request": f"could not read csv data: {e}",
20
+ })
21
+
22
+ result = train_nn(nn=net)
23
+ return jsonify(result)
24
+
25
+
26
+ if __name__ == "__main__":
27
+ app.run(debug=True)
example/main.py CHANGED
@@ -7,10 +7,11 @@ ARGS = {
7
  "epochs": 100,
8
  "hidden_size": 12,
9
  "learning_rate": 0.01,
 
10
  "activation": "tanh",
11
  "features": ["sepal width", "sepal length", "petal width", "petal length"],
12
  "target": "species",
13
- "data": iris_data.decode('utf-8'),
14
  }
15
 
16
  r = requests.post(
@@ -19,4 +20,4 @@ r = requests.post(
19
  )
20
 
21
  if __name__ == "__main__":
22
- print(r.json())
 
7
  "epochs": 100,
8
  "hidden_size": 12,
9
  "learning_rate": 0.01,
10
+ "test_size": 0.3,
11
  "activation": "tanh",
12
  "features": ["sepal width", "sepal length", "petal width", "petal length"],
13
  "target": "species",
14
+ "data": iris_data.decode("utf-8"),
15
  }
16
 
17
  r = requests.post(
 
20
  )
21
 
22
  if __name__ == "__main__":
23
+ print(r.json())
go.mod DELETED
@@ -1,24 +0,0 @@
1
- module github.com/Jensen-holm/ml-from-scratch
2
-
3
- go 1.19
4
-
5
- require (
6
- github.com/go-gota/gota v0.12.0
7
- github.com/gofiber/fiber/v2 v2.49.2
8
- )
9
-
10
- require (
11
- github.com/andybalholm/brotli v1.0.5 // indirect
12
- github.com/google/uuid v1.3.1 // indirect
13
- github.com/klauspost/compress v1.16.7 // indirect
14
- github.com/mattn/go-colorable v0.1.13 // indirect
15
- github.com/mattn/go-isatty v0.0.19 // indirect
16
- github.com/mattn/go-runewidth v0.0.15 // indirect
17
- github.com/rivo/uniseg v0.2.0 // indirect
18
- github.com/valyala/bytebufferpool v1.0.0 // indirect
19
- github.com/valyala/fasthttp v1.49.0 // indirect
20
- github.com/valyala/tcplisten v1.0.0 // indirect
21
- golang.org/x/net v0.17.0 // indirect
22
- golang.org/x/sys v0.13.0 // indirect
23
- gonum.org/v1/gonum v0.14.0 // indirect
24
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
go.sum DELETED
@@ -1,109 +0,0 @@
1
- dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
2
- gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
3
- github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
4
- github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
5
- github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
6
- github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
7
- github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
8
- github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9
- github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
10
- github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
11
- github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g=
12
- github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks=
13
- github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY=
14
- github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY=
15
- github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
16
- github.com/go-gota/gota v0.12.0 h1:T5BDg1hTf5fZ/CO+T/N0E+DDqUhvoKBl+UVckgcAAQg=
17
- github.com/go-gota/gota v0.12.0/go.mod h1:UT+NsWpZC/FhaOyWb9Hui0jXg0Iq8e/YugZHTbyW/34=
18
- github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
19
- github.com/gofiber/fiber/v2 v2.49.2 h1:ONEN3/Vc+dUCxxDgZZwpqvhISgHqb+bu+isBiEyKEQs=
20
- github.com/gofiber/fiber/v2 v2.49.2/go.mod h1:gNsKnyrmfEWFpJxQAV0qvW6l70K1dZGno12oLtukcts=
21
- github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
22
- github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
23
- github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
24
- github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
25
- github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
26
- github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I=
27
- github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
28
- github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
29
- github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
30
- github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
31
- github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
32
- github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
33
- github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
34
- github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
35
- github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2dXMnm1mY=
36
- github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
37
- github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
38
- github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
39
- github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
40
- github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
41
- github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
42
- github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
43
- github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
44
- github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
45
- github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
46
- github.com/valyala/fasthttp v1.49.0 h1:9FdvCpmxB74LH4dPb7IJ1cOSsluR07XG3I1txXWwJpE=
47
- github.com/valyala/fasthttp v1.49.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA=
48
- github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
49
- github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
50
- golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
51
- golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
52
- golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
53
- golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
54
- golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
55
- golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
56
- golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
57
- golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
58
- golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
59
- golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
60
- golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
61
- golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
62
- golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
63
- golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
64
- golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
65
- golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
66
- golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
67
- golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
68
- golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
69
- golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
70
- golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
71
- golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
72
- golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
73
- golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
74
- golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
75
- golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
76
- golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
77
- golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
78
- golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
79
- golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
80
- golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
81
- golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
82
- golang.org/x/sys v0.0.0-20210304124612-50617c2ba197/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
83
- golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
84
- golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
85
- golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
86
- golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
87
- golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
88
- golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
89
- golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
90
- golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
91
- golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
92
- golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
93
- golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
94
- golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
95
- golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
96
- golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
97
- golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
98
- golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
99
- gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
100
- gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
101
- gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
102
- gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
103
- gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
104
- gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
105
- gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
106
- gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
107
- gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
108
- gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
109
- rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nn/activation.go DELETED
@@ -1,41 +0,0 @@
1
- package nn
2
-
3
- import "math"
4
-
5
- var ActivationMap = map[string]func(float64) float64{
6
- "sigmoid": Sigmoid,
7
- "tanh": Tanh,
8
- "relu": Relu,
9
- }
10
-
11
- func Sigmoid(x float64) float64 {
12
- return 1.0 / (1.0 + math.Exp(-x))
13
- }
14
-
15
- func SigmoidPrime(x float64) float64 {
16
- s := Sigmoid(x)
17
- return s / (1.0 - s)
18
- }
19
-
20
- func Tanh(x float64) float64 {
21
- return math.Tanh(x)
22
- }
23
-
24
- func TanhPrime(x float64) float64 {
25
- return math.Pow((1.0 / math.Cosh(x)), 2)
26
- }
27
-
28
- func Relu(x float64) float64 {
29
- if x > 0 {
30
- return x
31
- }
32
- return 0
33
- }
34
-
35
- func ReluPrime(x float64) float64 {
36
- // maybe want to look into edge case if x == 0
37
- if x > 0 {
38
- return 1
39
- }
40
- return 0
41
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nn/args.go DELETED
@@ -1,17 +0,0 @@
1
- package nn
2
-
3
- type NNArgs struct {
4
- epochs int
5
- hiddenSize int
6
- learningRate float64
7
- activationFunc func()
8
- }
9
-
10
- func NewArgs(argsMap map[string]interface{}) *NNArgs {
11
- return &NNArgs{
12
- epochs: argsMap["epochs"].(int),
13
- hiddenSize: argsMap["hidden_size"].(int),
14
- learningRate: argsMap["learning_rate"].(float64),
15
- activationFunc: argsMap["activation"].(func()),
16
- }
17
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nn/backprop.go DELETED
@@ -1,99 +0,0 @@
1
- package nn
2
-
3
- import (
4
- "fmt"
5
-
6
- "gonum.org/v1/gonum/mat"
7
- )
8
-
9
- func (nn *NN) Backprop() {
10
- var (
11
- activation = *nn.ActivationFunc
12
- // lossHist []float64
13
- )
14
-
15
- for i := 0; i < nn.Epochs; i++ {
16
- // compute output with current w + b
17
- // then compute loss & backprop
18
- hiddenOutput, err := computeOutput(
19
- nn.XTrain,
20
- nn.Wh,
21
- nn.Bh,
22
- activation,
23
- )
24
- if err != nil {
25
- fmt.Printf("error computing hidden output: %v", err)
26
- }
27
-
28
- yHat, err := computeOutput(
29
- hiddenOutput,
30
- nn.Wo,
31
- nn.Bo,
32
- activation,
33
- )
34
- if err != nil {
35
- fmt.Printf("error computing yHat: %v", err)
36
- }
37
-
38
- mse := meanSquaredError(nn.YTrain, yHat)
39
- fmt.Println(mse)
40
-
41
- }
42
-
43
- }
44
-
45
- func computeOutput(arr, w, b *mat.Dense, activationFunc func(float64) float64) (*mat.Dense, error) {
46
- // Check if any of the input matrices is nil
47
- if arr == nil || w == nil || b == nil {
48
- return nil, fmt.Errorf("Input matrices cannot be nil")
49
- }
50
-
51
- // Check input dimensions
52
- arrRows, arrCols := arr.Dims()
53
- wRows, wCols := w.Dims()
54
- bRows, bCols := b.Dims()
55
-
56
- if arrCols != wRows || bCols != wCols {
57
- return nil, fmt.Errorf("Matrix dimension mismatch: arr[%d, %d], w[%d, %d], b[%d, %d]", arrRows, arrCols, wRows, wCols, bRows, bCols)
58
- }
59
-
60
- // Compute the dot product between the input matrix 'arr' and the weight matrix 'w'
61
- var product mat.Dense
62
- product.Mul(arr, w)
63
-
64
- // Check dimensions of product and bias
65
- productRows, productCols := product.Dims()
66
- if productCols != bCols {
67
- return nil, fmt.Errorf("Matrix dimension mismatch: product[%d, %d], b[%d, %d]", productRows, productCols, bRows, bCols)
68
- }
69
-
70
- // Add the bias matrix 'b' to the product
71
- var result mat.Dense
72
- result.Add(&product, b)
73
-
74
- // Apply the activation function to the result
75
- applyActivation(&result, activationFunc)
76
-
77
- return &result, nil
78
- }
79
-
80
- func applyActivation(m *mat.Dense, f func(float64) float64) {
81
- r, c := m.Dims()
82
- data := m.RawMatrix().Data
83
- for i := 0; i < r*c; i++ {
84
- data[i] = f(data[i])
85
- }
86
- }
87
-
88
- func meanSquaredError(y, yHat *mat.Dense) float64 {
89
- var sum float64
90
- r, c := y.Dims()
91
-
92
- for row := 0; row < r; row++ {
93
- for col := 0; col < c; col++ {
94
- diff := y.At(row, col) - yHat.At(row, col)
95
- sum += (diff * diff)
96
- }
97
- }
98
- return sum / float64((r * c))
99
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nn/backprop.py ADDED
File without changes
nn/main.go DELETED
@@ -1,84 +0,0 @@
1
- package nn
2
-
3
- import (
4
- "fmt"
5
- "math/rand"
6
- "strings"
7
-
8
- "github.com/go-gota/gota/dataframe"
9
- "github.com/gofiber/fiber/v2"
10
- "gonum.org/v1/gonum/mat"
11
- )
12
-
13
- type NN struct {
14
- // attributes set by request
15
- CSVData string `json:"csv_data"`
16
- Features []string `json:"features"`
17
- Target string `json:"target"`
18
- Epochs int `json:"epochs"`
19
- HiddenSize int `json:"hidden_size"`
20
- LearningRate float64 `json:"learning_rate"`
21
- Activation string `json:"activation"`
22
- TestSize float64 `json:"test_size"`
23
-
24
- // attributes set after args above are parsed
25
- ActivationFunc *func(float64) float64
26
- Df *dataframe.DataFrame
27
- XTrain *mat.Dense
28
- YTrain *mat.Dense
29
- XTest *mat.Dense
30
- YTest *mat.Dense
31
- Wh *mat.Dense
32
- Bh *mat.Dense
33
- Wo *mat.Dense
34
- Bo *mat.Dense
35
- }
36
-
37
- func NewNN(c *fiber.Ctx) (*NN, error) {
38
- newNN := new(NN)
39
- err := c.BodyParser(newNN)
40
- if err != nil {
41
- return nil, fmt.Errorf("invalid JSON data: %v", err)
42
- }
43
- df := dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
44
- activation := ActivationMap[newNN.Activation]
45
-
46
- newNN.Df = &df
47
- newNN.ActivationFunc = &activation
48
- return newNN, nil
49
- }
50
-
51
- func (nn *NN) InitWnB() {
52
- // randomly initialize weights and biases to start
53
- inputSize := len(nn.Features)
54
- hiddenSize := nn.HiddenSize
55
- outputSize := 1 // only predicting one thing
56
-
57
- // Initialize input hidden layer weights as a Gonum matrix
58
- wh := mat.NewDense(inputSize, hiddenSize, nil)
59
- wh.Apply(func(i, j int, v float64) float64 {
60
- return rand.Float64() - 0.5
61
- }, wh)
62
-
63
- // Initialize hidden layer bias as a Gonum matrix
64
- bh := mat.NewDense(1, hiddenSize, nil)
65
- bh.Apply(func(i, j int, v float64) float64 {
66
- return rand.Float64() - 0.5
67
- }, bh)
68
-
69
- // Initialize weights and biases for hidden -> output layer as Gonum matrices
70
- wo := mat.NewDense(hiddenSize, outputSize, nil)
71
- wo.Apply(func(i, j int, v float64) float64 {
72
- return rand.Float64() - 0.5
73
- }, wo)
74
-
75
- bo := mat.NewDense(1, outputSize, nil)
76
- bo.Apply(func(i, j int, v float64) float64 {
77
- return rand.Float64() - 0.5
78
- }, bo)
79
-
80
- nn.Wh = wh
81
- nn.Bh = bh
82
- nn.Wo = wo
83
- nn.Bo = bo
84
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nn/nn.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+
4
+ class NN:
5
+ def __init__(
6
+ self,
7
+ epochs: int,
8
+ hidden_size: int,
9
+ learning_rate: float,
10
+ test_size: float,
11
+ activation: str,
12
+ features: list[str],
13
+ target: str,
14
+ data: str,
15
+ ):
16
+ self.epochs = epochs
17
+ self.hidden_size = hidden_size
18
+ self.learning_rate = learning_rate
19
+ self.test_size = test_size
20
+ self.activation = activation
21
+ self.features = features
22
+ self.target = target
23
+ self.data = data
24
+
25
+ self.df: pd.DataFrame = None
26
+
27
+ @classmethod
28
+ def from_dict(cls, dct):
29
+ """ Creates an instance of NN given a dictionary
30
+ we can use this to make sure that the arguments are right
31
+ """
32
+ return cls(**dct)
nn/split.go DELETED
@@ -1,60 +0,0 @@
1
- package nn
2
-
3
- import (
4
- "math"
5
- "math/rand"
6
-
7
- "github.com/go-gota/gota/dataframe"
8
- "gonum.org/v1/gonum/mat"
9
- )
10
-
11
- func (nn *NN) TrainTestSplit() {
12
- // now we split the data into training
13
- // and testing based on user specified
14
- // nn.TestSize.
15
- nRows := nn.Df.Nrow()
16
- testRows := int(math.Floor(float64(nRows) * nn.TestSize))
17
-
18
- // subset the testing data
19
- // randomly select trainRows number of rows
20
- randStrt := rand.Intn(int(math.Floor(float64(nRows) * nn.TestSize)))
21
- test := nn.Df.Subset([]int{randStrt, randStrt + testRows})
22
-
23
- // use what is left for training
24
- allIndices := make([]int, nRows)
25
- for i := range allIndices {
26
- allIndices[i] = i
27
- }
28
-
29
- // Remove the test indices using slice append and variadic parameter
30
- trainIndices := append(allIndices[:randStrt], allIndices[randStrt+testRows:]...)
31
-
32
- // Create the train DataFrame using the trainIndices
33
- train := nn.Df.Subset(trainIndices)
34
-
35
- XTrain := train.Select(nn.Features)
36
- YTrain := train.Select(nn.Target)
37
- XTest := test.Select(nn.Features)
38
- YTest := test.Select(nn.Target)
39
-
40
- // to make linear algebra easier & faster,
41
- // we convert these dataframes that we are
42
- // performing potentially expensive computations
43
- // on into gonum matrices since we no longer need the
44
- // column names.
45
- nn.XTrain = df2mat(&XTrain)
46
- nn.YTrain = df2mat(&YTrain)
47
- nn.XTest = df2mat(&XTest)
48
- nn.YTest = df2mat(&YTest)
49
- }
50
-
51
- // df2mat -> converts gota dataframe into gonum matrix
52
- func df2mat(df *dataframe.DataFrame) *mat.Dense {
53
- m := mat.NewDense(df.Nrow(), df.Ncol(), nil)
54
- for i := 0; i < df.Nrow(); i++ {
55
- for j := 0; j < df.Ncol(); j++ {
56
- m.Set(i, j, df.Elem(i, j).Float())
57
- }
58
- }
59
- return m
60
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nn/subset.go DELETED
@@ -1,6 +0,0 @@
1
- package nn
2
-
3
- // subset the data frame into just the
4
- // features and target that the user specify
5
-
6
- func Subset() {}
 
 
 
 
 
 
 
nn/train.go DELETED
@@ -1,7 +0,0 @@
1
- package nn
2
-
3
- func (nn *NN) Train() {
4
- nn.InitWnB()
5
- nn.TrainTestSplit()
6
- nn.Backprop()
7
- }
 
 
 
 
 
 
 
 
nn/train.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.model_selection import train_test_split
2
+ from nn.nn import NN
3
+ import pandas as pd
4
+ import numpy as np
5
+
6
+
7
+ def train(nn: NN):
8
+ X_train, X_test, y_train, y_test = train_test_split(
9
+ nn.X,
10
+ nn.y,
11
+ test_size=nn.test_size,
12
+ random_state=88,
13
+ )
14
+
15
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Flask==3.0.0
2
+ pandas==2.1.1
3
+ Requests==2.31.0
server.go DELETED
@@ -1,26 +0,0 @@
1
- package main
2
-
3
- import (
4
- "github.com/Jensen-holm/ml-from-scratch/nn"
5
- "github.com/gofiber/fiber/v2"
6
- )
7
-
8
- func main() {
9
- app := fiber.New()
10
-
11
- // eventually we might want to add a key to this endpoint
12
- // that we will be able to validate.
13
- app.Post("/neural-network", func(c *fiber.Ctx) error {
14
- nn, err := nn.NewNN(c)
15
- if err != nil {
16
- return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
17
- "error": err,
18
- })
19
- }
20
-
21
- nn.Train()
22
- return c.SendString("No error")
23
- })
24
-
25
- app.Listen(":3000")
26
- }