Reaper200 commited on
Commit
9147ab9
1 Parent(s): 4097fd9

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -0
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ class RNN_model(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ self.rnn= nn.RNN(input_size=1080, hidden_size=240,num_layers=1, nonlinearity= 'relu', bias= True)
9
+ self.output= nn.Linear(in_features=240, out_features=24)
10
+
11
+ def forward(self, x):
12
+ y, hidden= self.rnn(x)
13
+ #print(y.shape)
14
+ #print(hidden.shape)
15
+ x= self.output(y)
16
+
17
+ return(x)