ArxAlfa
commited on
Commit
•
66f8548
1
Parent(s):
d077299
Refactor DNN model to support variable number of
Browse files- app.py +22 -24
- model_weights.pth +0 -0
app.py
CHANGED
@@ -9,30 +9,30 @@ from sklearn.model_selection import train_test_split
|
|
9 |
import csv
|
10 |
import io
|
11 |
|
12 |
-
# from joblib import load, dump
|
13 |
-
|
14 |
|
15 |
# Define the DNN model
|
16 |
class DNN(nn.Module):
|
17 |
-
def __init__(self, input_size, hidden_size, output_size):
|
18 |
super(DNN, self).__init__()
|
19 |
self.fc1 = nn.Linear(input_size, hidden_size)
|
20 |
self.relu1 = nn.ReLU()
|
21 |
-
self.
|
22 |
-
|
|
|
|
|
23 |
self.fc3 = nn.Linear(hidden_size, output_size)
|
24 |
|
25 |
def forward(self, x):
|
26 |
x = self.fc1(x)
|
27 |
x = self.relu1(x)
|
28 |
-
|
29 |
-
|
30 |
x = self.fc3(x)
|
31 |
return x
|
32 |
|
33 |
|
34 |
# Load the model
|
35 |
-
model = DNN(input_size=6, hidden_size=
|
36 |
|
37 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
38 |
model = model.to(device)
|
@@ -71,23 +71,19 @@ def generate(
|
|
71 |
prediction = model(input_data)
|
72 |
return {"prediction": prediction.item()}
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
trainDatafile: UploadFile = File(...),
|
78 |
-
testDatafile: UploadFile = File(...),
|
79 |
-
epochs: int = 100,
|
80 |
-
):
|
81 |
-
global model
|
82 |
-
|
83 |
-
contents1 = await trainDatafile.read()
|
84 |
-
train_data = pd.read_csv(io.StringIO(contents1.decode("utf-8")))
|
85 |
-
|
86 |
-
contents2 = await testDatafile.read()
|
87 |
-
test_data = pd.read_csv(io.StringIO(contents2.decode("utf-8")))
|
88 |
|
89 |
-
|
90 |
-
|
91 |
|
92 |
# Convert data to numpy arrays
|
93 |
X_train = train_data.drop("Yield_kg_per_hectare", axis=1).values
|
@@ -128,7 +124,9 @@ async def train(
|
|
128 |
y_test.cpu().detach().numpy(), predictions.cpu().detach().numpy()
|
129 |
)
|
130 |
)
|
131 |
-
print(
|
|
|
|
|
132 |
rmseList.append(float(rmse))
|
133 |
|
134 |
torch.save(model.state_dict(), "model_weights.pth")
|
|
|
9 |
import csv
|
10 |
import io
|
11 |
|
|
|
|
|
12 |
|
13 |
# Define the DNN model
|
14 |
class DNN(nn.Module):
|
15 |
+
def __init__(self, input_size, hidden_size, output_size, num_hidden_layers):
|
16 |
super(DNN, self).__init__()
|
17 |
self.fc1 = nn.Linear(input_size, hidden_size)
|
18 |
self.relu1 = nn.ReLU()
|
19 |
+
self.hidden_layers = nn.ModuleList()
|
20 |
+
for _ in range(num_hidden_layers):
|
21 |
+
self.hidden_layers.append(nn.Linear(hidden_size, hidden_size))
|
22 |
+
self.hidden_layers.append(nn.ReLU())
|
23 |
self.fc3 = nn.Linear(hidden_size, output_size)
|
24 |
|
25 |
def forward(self, x):
|
26 |
x = self.fc1(x)
|
27 |
x = self.relu1(x)
|
28 |
+
for layer in self.hidden_layers:
|
29 |
+
x = layer(x)
|
30 |
x = self.fc3(x)
|
31 |
return x
|
32 |
|
33 |
|
34 |
# Load the model
|
35 |
+
model = DNN(input_size=6, hidden_size=64, output_size=1, num_hidden_layers=32)
|
36 |
|
37 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
38 |
model = model.to(device)
|
|
|
71 |
prediction = model(input_data)
|
72 |
return {"prediction": prediction.item()}
|
73 |
|
74 |
+
@app.post("/train")
|
75 |
+
async def train(
|
76 |
+
trainDatafile: UploadFile = File(...),
|
77 |
+
testDatafile: UploadFile = File(...),
|
78 |
+
epochs: int = 100,
|
79 |
+
):
|
80 |
+
global model
|
81 |
|
82 |
+
contents1 = await trainDatafile.read()
|
83 |
+
train_data = pd.read_csv(io.StringIO(contents1.decode("utf-8")))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
contents2 = await testDatafile.read()
|
86 |
+
test_data = pd.read_csv(io.StringIO(contents2.decode("utf-8")))
|
87 |
|
88 |
# Convert data to numpy arrays
|
89 |
X_train = train_data.drop("Yield_kg_per_hectare", axis=1).values
|
|
|
124 |
y_test.cpu().detach().numpy(), predictions.cpu().detach().numpy()
|
125 |
)
|
126 |
)
|
127 |
+
print(
|
128 |
+
f"Epoch: {epoch+1}, RMSE: {float(rmse)}, Loss: {float(np.sqrt(loss.cpu().detach().numpy()))}"
|
129 |
+
)
|
130 |
rmseList.append(float(rmse))
|
131 |
|
132 |
torch.save(model.state_dict(), "model_weights.pth")
|
model_weights.pth
CHANGED
Binary files a/model_weights.pth and b/model_weights.pth differ
|
|