alisrbdni commited on
Commit
f001673
·
verified ·
1 Parent(s): 4a865a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -234,7 +234,7 @@ import streamlit
234
 
235
  # If you're curious of all the loggers
236
 
237
- DEVICE = torch.device("cpu")
238
  fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
239
 
240
  def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
@@ -510,8 +510,8 @@ def main():
510
  num_clients=NUM_CLIENTS,
511
  config=fl.server.ServerConfig(num_rounds=1),
512
  strategy=strategy,
513
- client_resources={"num_cpus": 1, "num_gpus": 0},
514
- ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": 0}
515
  )
516
 
517
  for i, client in enumerate(clients):
 
234
 
235
  # If you're curious of all the loggers
236
 
237
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
238
  fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
239
 
240
  def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
 
510
  num_clients=NUM_CLIENTS,
511
  config=fl.server.ServerConfig(num_rounds=1),
512
  strategy=strategy,
513
+ client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
514
+ ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
515
  )
516
 
517
  for i, client in enumerate(clients):