alisrbdni commited on
Commit
e3bf7fc
·
verified ·
1 Parent(s): bde5538

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py CHANGED
@@ -276,6 +276,38 @@ def train(net, trainloader, epochs):
276
  optimizer.step()
277
  optimizer.zero_grad()
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  def test(net, testloader):
280
  metric = load_metric("accuracy")
281
  net.eval()
 
276
  optimizer.step()
277
  optimizer.zero_grad()
278
 
279
+
280
+
281
+
282
+ # class SaveModelStrategy(fl.server.strategy.FedAvg):
283
+ # def aggregate_fit(
284
+ # self,
285
+ # server_round: int,
286
+ # results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
287
+ # failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
288
+ # ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
289
+ # """Aggregate model weights using weighted average and store checkpoint"""
290
+
291
+ # # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
292
+ # aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
293
+
294
+ # if aggregated_parameters is not None:
295
+ # print(f"Saving round {server_round} aggregated_parameters...")
296
+
297
+ # # Convert `Parameters` to `List[np.ndarray]`
298
+ # aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
299
+
300
+ # # Convert `List[np.ndarray]` to PyTorch`state_dict`
301
+ # params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
302
+ # state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
303
+ # net.load_state_dict(state_dict, strict=True)
304
+
305
+ # # Save the model
306
+ # torch.save(net.state_dict(), f"model_round_{server_round}.pth")
307
+
308
+ # return aggregated_parameters, aggregated_metrics
309
+
310
+
311
  def test(net, testloader):
312
  metric = load_metric("accuracy")
313
  net.eval()