ArxAlfa committed on
Commit 66f8548
1 Parent(s): d077299

Refactor DNN model to support variable number of hidden layers

Files changed (2)
  1. app.py +22 -24
  2. model_weights.pth +0 -0
app.py CHANGED
@@ -9,30 +9,30 @@ from sklearn.model_selection import train_test_split
 import csv
 import io
 
-# from joblib import load, dump
-
 
 # Define the DNN model
 class DNN(nn.Module):
-    def __init__(self, input_size, hidden_size, output_size):
+    def __init__(self, input_size, hidden_size, output_size, num_hidden_layers):
         super(DNN, self).__init__()
         self.fc1 = nn.Linear(input_size, hidden_size)
         self.relu1 = nn.ReLU()
-        self.fc2 = nn.Linear(hidden_size, hidden_size)
-        self.relu2 = nn.ReLU()
+        self.hidden_layers = nn.ModuleList()
+        for _ in range(num_hidden_layers):
+            self.hidden_layers.append(nn.Linear(hidden_size, hidden_size))
+            self.hidden_layers.append(nn.ReLU())
         self.fc3 = nn.Linear(hidden_size, output_size)
 
     def forward(self, x):
         x = self.fc1(x)
         x = self.relu1(x)
-        x = self.fc2(x)
-        x = self.relu2(x)
+        for layer in self.hidden_layers:
+            x = layer(x)
         x = self.fc3(x)
         return x
 
 
 # Load the model
-model = DNN(input_size=6, hidden_size=256, output_size=1)
+model = DNN(input_size=6, hidden_size=64, output_size=1, num_hidden_layers=32)
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
@@ -71,23 +71,19 @@ def generate(
     prediction = model(input_data)
     return {"prediction": prediction.item()}
 
+@app.post("/train")
+async def train(
+    trainDatafile: UploadFile = File(...),
+    testDatafile: UploadFile = File(...),
+    epochs: int = 100,
+):
+    global model
 
-@app.post("/train")
-async def train(
-    trainDatafile: UploadFile = File(...),
-    testDatafile: UploadFile = File(...),
-    epochs: int = 100,
-):
-    global model
-
-    contents1 = await trainDatafile.read()
-    train_data = pd.read_csv(io.StringIO(contents1.decode("utf-8")))
-
-    contents2 = await testDatafile.read()
-    test_data = pd.read_csv(io.StringIO(contents2.decode("utf-8")))
+    contents1 = await trainDatafile.read()
+    train_data = pd.read_csv(io.StringIO(contents1.decode("utf-8")))
 
-    # Load the training and testing data
-    # test_data = pd.read_csv("dataset/agricultural_yield_test.csv")
+    contents2 = await testDatafile.read()
+    test_data = pd.read_csv(io.StringIO(contents2.decode("utf-8")))
 
     # Convert data to numpy arrays
     X_train = train_data.drop("Yield_kg_per_hectare", axis=1).values
@@ -128,7 +124,9 @@ async def train(
                 y_test.cpu().detach().numpy(), predictions.cpu().detach().numpy()
             )
         )
-        print(f"Epoch: {epoch+1}, RMSE: {float(rmse)}")
+        print(
+            f"Epoch: {epoch+1}, RMSE: {float(rmse)}, Loss: {float(np.sqrt(loss.cpu().detach().numpy()))}"
+        )
         rmseList.append(float(rmse))
 
    torch.save(model.state_dict(), "model_weights.pth")
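
The refactor swaps the two hard-coded hidden layers (fc2/relu2) for an nn.ModuleList whose length is set by the new num_hidden_layers argument. Below is a minimal standalone sketch of that pattern with a dummy forward pass; the class body mirrors the diff, while the batch size and random input are illustrative only:

import torch
import torch.nn as nn


class DNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_hidden_layers):
        super(DNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        # One Linear + ReLU pair appended per requested hidden layer
        self.hidden_layers = nn.ModuleList()
        for _ in range(num_hidden_layers):
            self.hidden_layers.append(nn.Linear(hidden_size, hidden_size))
            self.hidden_layers.append(nn.ReLU())
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.fc3(x)
        return x


# Smoke test: 4 samples with the app's 6 input features -> one prediction each
model = DNN(input_size=6, hidden_size=64, output_size=1, num_hidden_layers=32)
out = model(torch.randn(4, 6))
print(out.shape)  # torch.Size([4, 1])

Using nn.ModuleList (rather than a plain Python list) matters here: it registers every added Linear layer as a submodule, so the extra layers are moved by model.to(device) and included in torch.save(model.state_dict(), ...).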
model_weights.pth CHANGED
Binary files a/model_weights.pth and b/model_weights.pth differ
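
As a usage note for the /train endpoint in app.py: FastAPI exposes trainDatafile and testDatafile as multipart file fields and epochs as a query parameter. A client-side sketch, assuming the app is served locally on port 8000 and that train.csv / test.csv are placeholder CSV files containing a Yield_kg_per_hectare column (none of these names come from the commit itself):

# Hypothetical client call; host, port, and file names are assumptions for illustration.
import requests

with open("train.csv", "rb") as train_f, open("test.csv", "rb") as test_f:
    response = requests.post(
        "http://localhost:8000/train",
        params={"epochs": 100},  # scalar parameter -> query string in FastAPI
        files={
            "trainDatafile": ("train.csv", train_f, "text/csv"),
            "testDatafile": ("test.csv", test_f, "text/csv"),
        },
    )

# The hunks above don't show what train() returns, so just inspect the HTTP status.
print(response.status_code)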