Update app.py
Browse files
app.py
CHANGED
@@ -276,6 +276,38 @@ def train(net, trainloader, epochs):
|
|
276 |
optimizer.step()
|
277 |
optimizer.zero_grad()
|
278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
def test(net, testloader):
|
280 |
metric = load_metric("accuracy")
|
281 |
net.eval()
|
|
|
276 |
optimizer.step()
|
277 |
optimizer.zero_grad()
|
278 |
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
# class SaveModelStrategy(fl.server.strategy.FedAvg):
|
283 |
+
# def aggregate_fit(
|
284 |
+
# self,
|
285 |
+
# server_round: int,
|
286 |
+
# results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
|
287 |
+
# failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
|
288 |
+
# ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
|
289 |
+
# """Aggregate model weights using weighted average and store checkpoint"""
|
290 |
+
|
291 |
+
# # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
|
292 |
+
# aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
|
293 |
+
|
294 |
+
# if aggregated_parameters is not None:
|
295 |
+
# print(f"Saving round {server_round} aggregated_parameters...")
|
296 |
+
|
297 |
+
# # Convert `Parameters` to `List[np.ndarray]`
|
298 |
+
# aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
|
299 |
+
|
300 |
+
# # Convert `List[np.ndarray]` to PyTorch`state_dict`
|
301 |
+
# params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
|
302 |
+
# state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
|
303 |
+
# net.load_state_dict(state_dict, strict=True)
|
304 |
+
|
305 |
+
# # Save the model
|
306 |
+
# torch.save(net.state_dict(), f"model_round_{server_round}.pth")
|
307 |
+
|
308 |
+
# return aggregated_parameters, aggregated_metrics
|
309 |
+
|
310 |
+
|
311 |
def test(net, testloader):
|
312 |
metric = load_metric("accuracy")
|
313 |
net.eval()
|