alisrbdni committed
Commit e44dff4 · verified · 1 Parent(s): 11edc3b

Update app.py

Files changed (1)
  1. app.py +266 -104

app.py CHANGED
@@ -1,3 +1,4 @@
 # # %%writefile app.py

 # import streamlit as st
@@ -11,8 +12,15 @@
 # import random
 # from collections import OrderedDict
 # import flwr as fl

- # DEVICE = torch.device("cpu")

 # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
 # raw_datasets = load_dataset(dataset_name)
@@ -39,10 +47,8 @@

 # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

- # return train_datasets, test_datasets, data_collator
- # def read_log_file():
- # with open("./log.txt", "r") as file:
- # return file.read()
 # def train(net, trainloader, epochs):
 # optimizer = AdamW(net.parameters(), lr=5e-5)
 # net.train()
@@ -55,6 +61,38 @@
 # optimizer.step()
 # optimizer.zero_grad()

 # def test(net, testloader):
 # metric = load_metric("accuracy")
 # net.eval()
@@ -89,17 +127,21 @@
 # self.net.load_state_dict(state_dict, strict=True)

 # def fit(self, parameters, config):
 # self.set_parameters(parameters)
 # train(self.net, self.trainloader, epochs=1)
 # loss, accuracy = test(self.net, self.testloader)
 # self.losses.append(loss)
 # self.accuracies.append(accuracy)
- # return self.get_parameters(config={}), len(self.trainloader.dataset), {}

 # def evaluate(self, parameters, config):
 # self.set_parameters(parameters)
 # loss, accuracy = test(self.net, self.testloader)
- # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

 # def plot_metrics(self, round_num, plot_placeholder):
 # if self.losses and self.accuracies:
@@ -123,16 +165,107 @@

 # fig.tight_layout()
 # plot_placeholder.pyplot(fig)

 # def main():
 # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
 # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
- # model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"])

 # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
 # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)

- # train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)

 # trainloaders = []
 # testloaders = []
@@ -144,9 +277,14 @@
 # train_df = pd.DataFrame(train_datasets[i])
 # test_df = pd.DataFrame(test_datasets[i])

- # st.write("#### Train Dataset")
 # edited_train_df = st.data_editor(train_df, key=f"train_{i}")
- # st.write("#### Test Dataset")
 # edited_test_df = st.data_editor(test_df, key=f"test_{i}")

 # edited_train_dataset = Dataset.from_pandas(edited_train_df)
@@ -179,21 +317,73 @@
 # )

 # for round_num in range(NUM_ROUNDS):
- # st.write(f"### Round {round_num + 1}")
- # plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
- # fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")

 # fl.simulation.start_simulation(
 # client_fn=client_fn,
 # num_clients=NUM_CLIENTS,
 # config=fl.server.ServerConfig(num_rounds=1),
 # strategy=strategy,
- # client_resources={"num_cpus": 1, "num_gpus": 0},
- # ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
 # )

 # for i, client in enumerate(clients):
- # st.markdown("LOGS : "+ read_log_file())
 # client.plot_metrics(round_num + 1, plot_placeholders[i])
 # st.write(" ")

@@ -203,18 +393,36 @@
 # st.write("## Final Client Metrics")
 # for client in clients:
 # st.write(f"### Client {client.client_id}")
- # st.write(f"Final Loss: {client.losses[-1]:.4f}")
- # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
- # client.plot_metrics(NUM_ROUNDS, st.empty())
 # st.write(" ")

 # else:
 # st.write("Click the 'Start Training' button to start the training process.")

 # if __name__ == "__main__":
 # main()

- # %%writefile app.py

 import streamlit as st
 import matplotlib.pyplot as plt
@@ -230,7 +438,8 @@ import flwr as fl
 from logging import INFO, DEBUG
 from flwr.common.logger import log
 import logging
- import streamlit

 # If you're curious of all the loggers

@@ -276,38 +485,6 @@ def train(net, trainloader, epochs):
 optimizer.step()
 optimizer.zero_grad()

-
-
-
- # class SaveModelStrategy(fl.server.strategy.FedAvg):
- # def aggregate_fit(
- # self,
- # server_round: int,
- # results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
- # failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
- # ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
- # """Aggregate model weights using weighted average and store checkpoint"""
-
- # # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
- # aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
-
- # if aggregated_parameters is not None:
- # print(f"Saving round {server_round} aggregated_parameters...")
-
- # # Convert `Parameters` to `List[np.ndarray]`
- # aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
-
- # # Convert `List[np.ndarray]` to PyTorch`state_dict`
- # params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
- # state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
- # net.load_state_dict(state_dict, strict=True)
-
- # # Save the model
- # torch.save(net.state_dict(), f"model_round_{server_round}.pth")
-
- # return aggregated_parameters, aggregated_metrics
-
-
 def test(net, testloader):
 metric = load_metric("accuracy")
 net.eval()
@@ -380,8 +557,6 @@ class CustomClient(fl.client.NumPyClient):

 fig.tight_layout()
 plot_placeholder.pyplot(fig)
- import matplotlib.pyplot as plt
- import re

 def read_log_file(log_path='./log.txt'):
 with open(log_path, 'r') as file:
@@ -392,18 +567,18 @@ def parse_log(log_lines):
 rounds = []
 clients = {}
 memory_usage = []
-
- round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
 client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
 memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
-
 current_round = None
-
 for line in log_lines:
 round_match = round_pattern.search(line)
 client_match = client_pattern.search(line)
 memory_match = memory_pattern.search(line)
-
 if round_match:
 current_round = int(round_match.group(1))
 rounds.append(current_round)
@@ -411,20 +586,20 @@ def parse_log(log_lines):
 client_id = int(client_match.group(1))
 log_level = client_match.group(2)
 message = client_match.group(3)
-
 if client_id not in clients:
 clients[client_id] = {'rounds': [], 'messages': []}
-
 clients[client_id]['rounds'].append(current_round)
 clients[client_id]['messages'].append((log_level, message))
 elif memory_match:
 memory_usage.append(float(memory_match.group(1)))
-
 return rounds, clients, memory_usage

 def plot_metrics(rounds, clients, memory_usage):
 st.write("## Metrics Overview")
-
 st.write("### Memory Usage")
 plt.figure()
 plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
@@ -432,25 +607,25 @@ def plot_metrics(rounds, clients, memory_usage):
 plt.ylabel('Memory Usage (GB)')
 plt.legend()
 st.pyplot(plt)
-
 for client_id, data in clients.items():
 st.write(f"### Client {client_id} Metrics")
-
 info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
 debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
-
 st.write("#### INFO Messages")
 for msg in info_messages:
 st.write(msg)
-
 st.write("#### DEBUG Messages")
 for msg in debug_messages:
 st.write(msg)
-
 # Placeholder for actual loss and accuracy values, assuming they're included in the messages
 losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
 accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
-
 if losses:
 plt.figure()
 plt.plot(data['rounds'], losses, label='Loss')
@@ -458,7 +633,7 @@ def plot_metrics(rounds, clients, memory_usage):
 plt.ylabel('Loss')
 plt.legend()
 st.pyplot(plt)
-
 if accuracies:
 plt.figure()
 plt.plot(data['rounds'], accuracies, label='Accuracy')
@@ -467,12 +642,11 @@ def plot_metrics(rounds, clients, memory_usage):
 plt.legend()
 st.pyplot(plt)

-
 def read_log_file2():
 with open("./log.txt", "r") as file:
 return file.read()
- def main():

 st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
 dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
 model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
@@ -534,59 +708,44 @@ def main():
 for round_num in range(NUM_ROUNDS):
 st.write(f"### Round {round_num + 1} ✅")

- st.markdown(print(st.logger._loggers))
- st.markdown(read_log_file2())
 logs = read_log_file2()
- import re
- import plotly.graph_objects as go
- import streamlit as st
- import pandas as pd
-
- # Log data
- log_data = logs
-
 # Extract relevant data
- accuracy_pattern = re.compile(r"'accuracy': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
- loss_pattern = re.compile(r"'loss': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
-
- accuracy_matches = accuracy_pattern.findall(log_data)
- loss_matches = loss_pattern.findall(log_data)
-
 rounds = [int(match[0]) for match in accuracy_matches]
 accuracies = [float(match[1]) for match in accuracy_matches]
 losses = [float(match[1]) for match in loss_matches]
-
 # Create accuracy plot
 accuracy_fig = go.Figure()
 accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
 accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
-
 # Create loss plot
 loss_fig = go.Figure()
 loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
 loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
-
 # Display plots in Streamlit
 st.plotly_chart(accuracy_fig)
 st.plotly_chart(loss_fig)
-
 # Display data table
 data = {
 'Round': rounds,
 'Accuracy': accuracies,
 'Loss': losses
 }
-
 df = pd.DataFrame(data)
 st.write("## Training Metrics")
 st.table(df)

-
-
-
-
-
-
 plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]

 fl.simulation.start_simulation(
@@ -619,12 +778,12 @@ def main():

 # Display log.txt content
 st.write("## Training Log")
- # st.text(read_log_file())
 st.write("## Training Log Analysis")
-
 log_lines = read_log_file()
 rounds, clients, memory_usage = parse_log(log_lines)
-
 plot_metrics(rounds, clients, memory_usage)

 else:
@@ -634,3 +793,6 @@
 main()

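The hunks above swap the round-number regex in parse_log() from the old r'ROUND(\d+)ROUND (\d+)' to r'\[ROUND (\d+)\]'. A minimal standalone sketch of how the updated patterns behave follows; the sample log lines are illustrative only, since the exact format written by the Flower/Ray loggers can vary by version.

import re

# Patterns copied from the updated parse_log() in this commit.
round_pattern = re.compile(r'\[ROUND (\d+)\]')
client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')

# Hypothetical log lines shaped the way these patterns expect (not from a real run).
sample_lines = [
    "INFO flwr | server.py | [ROUND 1]",
    "Client 0 | DEBUG | fit() done, loss=0.6931, accuracy=0.5000",
    "DEBUG flwr | memory used=1.23GB",
]

for line in sample_lines:
    if (m := round_pattern.search(line)):
        print("round:", int(m.group(1)))          # -> round: 1
    elif (m := client_pattern.search(line)):
        print("client:", int(m.group(1)), m.group(2), m.group(3))
    elif (m := memory_pattern.search(line)):
        print("memory GB:", float(m.group(1)))    # -> memory GB: 1.23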
 
 
 
 
 
+
 # # %%writefile app.py

 # import streamlit as st

 # import random
 # from collections import OrderedDict
 # import flwr as fl
+ # from logging import INFO, DEBUG
+ # from flwr.common.logger import log
+ # import logging
+ # import streamlit
+
+ # # If you're curious of all the loggers

+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")

 # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
 # raw_datasets = load_dataset(dataset_name)


 # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

+ # return train_datasets, test_datasets, data_collator, raw_datasets
+
 # def train(net, trainloader, epochs):
 # optimizer = AdamW(net.parameters(), lr=5e-5)
 # net.train()

 # optimizer.step()
 # optimizer.zero_grad()

+
+
+
+ # # class SaveModelStrategy(fl.server.strategy.FedAvg):
+ # # def aggregate_fit(
+ # # self,
+ # # server_round: int,
+ # # results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
+ # # failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+ # # ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+ # # """Aggregate model weights using weighted average and store checkpoint"""
+
+ # # # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
+ # # aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
+
+ # # if aggregated_parameters is not None:
+ # # print(f"Saving round {server_round} aggregated_parameters...")
+
+ # # # Convert `Parameters` to `List[np.ndarray]`
+ # # aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
+
+ # # # Convert `List[np.ndarray]` to PyTorch`state_dict`
+ # # params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
+ # # state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+ # # net.load_state_dict(state_dict, strict=True)
+
+ # # # Save the model
+ # # torch.save(net.state_dict(), f"model_round_{server_round}.pth")
+
+ # # return aggregated_parameters, aggregated_metrics
+
+
 # def test(net, testloader):
 # metric = load_metric("accuracy")
 # net.eval()

 # self.net.load_state_dict(state_dict, strict=True)

 # def fit(self, parameters, config):
+ # log(INFO, f"Client {self.client_id} is starting fit()")
 # self.set_parameters(parameters)
 # train(self.net, self.trainloader, epochs=1)
 # loss, accuracy = test(self.net, self.testloader)
 # self.losses.append(loss)
 # self.accuracies.append(accuracy)
+ # log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
+ # return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}

 # def evaluate(self, parameters, config):
+ # log(INFO, f"Client {self.client_id} is starting evaluate()")
 # self.set_parameters(parameters)
 # loss, accuracy = test(self.net, self.testloader)
+ # log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
+ # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}

 # def plot_metrics(self, round_num, plot_placeholder):
 # if self.losses and self.accuracies:


 # fig.tight_layout()
 # plot_placeholder.pyplot(fig)
+ # import matplotlib.pyplot as plt
+ # import re
+
+ # def read_log_file(log_path='./log.txt'):
+ # with open(log_path, 'r') as file:
+ # log_lines = file.readlines()
+ # return log_lines

+ # def parse_log(log_lines):
+ # rounds = []
+ # clients = {}
+ # memory_usage = []
+
+ # round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
+ # client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
+ # memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
+
+ # current_round = None
+
+ # for line in log_lines:
+ # round_match = round_pattern.search(line)
+ # client_match = client_pattern.search(line)
+ # memory_match = memory_pattern.search(line)
+
+ # if round_match:
+ # current_round = int(round_match.group(1))
+ # rounds.append(current_round)
+ # elif client_match:
+ # client_id = int(client_match.group(1))
+ # log_level = client_match.group(2)
+ # message = client_match.group(3)
+
+ # if client_id not in clients:
+ # clients[client_id] = {'rounds': [], 'messages': []}
+
+ # clients[client_id]['rounds'].append(current_round)
+ # clients[client_id]['messages'].append((log_level, message))
+ # elif memory_match:
+ # memory_usage.append(float(memory_match.group(1)))
+
+ # return rounds, clients, memory_usage
+
+ # def plot_metrics(rounds, clients, memory_usage):
+ # st.write("## Metrics Overview")
+
+ # st.write("### Memory Usage")
+ # plt.figure()
+ # plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
+ # plt.xlabel('Step')
+ # plt.ylabel('Memory Usage (GB)')
+ # plt.legend()
+ # st.pyplot(plt)
+
+ # for client_id, data in clients.items():
+ # st.write(f"### Client {client_id} Metrics")
+
+ # info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
+ # debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
+
+ # st.write("#### INFO Messages")
+ # for msg in info_messages:
+ # st.write(msg)
+
+ # st.write("#### DEBUG Messages")
+ # for msg in debug_messages:
+ # st.write(msg)
+
+ # # Placeholder for actual loss and accuracy values, assuming they're included in the messages
+ # losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
+ # accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
+
+ # if losses:
+ # plt.figure()
+ # plt.plot(data['rounds'], losses, label='Loss')
+ # plt.xlabel('Round')
+ # plt.ylabel('Loss')
+ # plt.legend()
+ # st.pyplot(plt)
+
+ # if accuracies:
+ # plt.figure()
+ # plt.plot(data['rounds'], accuracies, label='Accuracy')
+ # plt.xlabel('Round')
+ # plt.ylabel('Accuracy')
+ # plt.legend()
+ # st.pyplot(plt)
+
+
+ # def read_log_file2():
+ # with open("./log.txt", "r") as file:
+ # return file.read()
 # def main():
+
 # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
 # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
+ # model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])

 # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
 # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)

+ # train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)

 # trainloaders = []
 # testloaders = []

 # train_df = pd.DataFrame(train_datasets[i])
 # test_df = pd.DataFrame(test_datasets[i])

+ # st.write("#### Train Dataset (Words)")
+ # st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
+ # st.write("#### Train Dataset (Tokens)")
 # edited_train_df = st.data_editor(train_df, key=f"train_{i}")
+
+ # st.write("#### Test Dataset (Words)")
+ # st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
+ # st.write("#### Test Dataset (Tokens)")
 # edited_test_df = st.data_editor(test_df, key=f"test_{i}")

 # edited_train_dataset = Dataset.from_pandas(edited_train_df)

 # )

 # for round_num in range(NUM_ROUNDS):
+ # st.write(f"### Round {round_num + 1}")
+
+ # st.markdown(print(st.logger._loggers))
+ # st.markdown(read_log_file2())
+ # logs = read_log_file2()
+ # import re
+ # import plotly.graph_objects as go
+ # import streamlit as st
+ # import pandas as pd
+
+ # # Log data
+ # log_data = logs
+
+ # # Extract relevant data
+ # accuracy_pattern = re.compile(r"'accuracy': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
+ # loss_pattern = re.compile(r"'loss': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
+
+ # accuracy_matches = accuracy_pattern.findall(log_data)
+ # loss_matches = loss_pattern.findall(log_data)
+
+ # rounds = [int(match[0]) for match in accuracy_matches]
+ # accuracies = [float(match[1]) for match in accuracy_matches]
+ # losses = [float(match[1]) for match in loss_matches]

+ # # Create accuracy plot
+ # accuracy_fig = go.Figure()
+ # accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
+ # accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
+
+ # # Create loss plot
+ # loss_fig = go.Figure()
+ # loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
+ # loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
+
+ # # Display plots in Streamlit
+ # st.plotly_chart(accuracy_fig)
+ # st.plotly_chart(loss_fig)
+
+ # # Display data table
+ # data = {
+ # 'Round': rounds,
+ # 'Accuracy': accuracies,
+ # 'Loss': losses
+ # }
+
+ # df = pd.DataFrame(data)
+ # st.write("## Training Metrics")
+ # st.table(df)
+
+
+
+
+
+
+
+ # plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
+
 # fl.simulation.start_simulation(
 # client_fn=client_fn,
 # num_clients=NUM_CLIENTS,
 # config=fl.server.ServerConfig(num_rounds=1),
 # strategy=strategy,
+ # client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
+ # ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
 # )

 # for i, client in enumerate(clients):
 # client.plot_metrics(round_num + 1, plot_placeholders[i])
 # st.write(" ")

 # st.write("## Final Client Metrics")
 # for client in clients:
 # st.write(f"### Client {client.client_id}")
+ # if client.losses and client.accuracies:
+ # st.write(f"Final Loss: {client.losses[-1]:.4f}")
+ # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
+ # client.plot_metrics(NUM_ROUNDS, st.empty())
+ # else:
+ # st.write("No metrics available.")
+
 # st.write(" ")

+ # # Display log.txt content
+ # st.write("## Training Log")
+ # # st.text(read_log_file())
+ # st.write("## Training Log Analysis")
+
+ # log_lines = read_log_file()
+ # rounds, clients, memory_usage = parse_log(log_lines)
+
+ # plot_metrics(rounds, clients, memory_usage)
+
 # else:
 # st.write("Click the 'Start Training' button to start the training process.")

 # if __name__ == "__main__":
 # main()

+
+
+
+
+ ##############NEW

 import streamlit as st
 import matplotlib.pyplot as plt

 from logging import INFO, DEBUG
 from flwr.common.logger import log
 import logging
+ import re
+ import plotly.graph_objects as go

 # If you're curious of all the loggers

 optimizer.step()
 optimizer.zero_grad()

 def test(net, testloader):
 metric = load_metric("accuracy")
 net.eval()


 fig.tight_layout()
 plot_placeholder.pyplot(fig)

 def read_log_file(log_path='./log.txt'):
 with open(log_path, 'r') as file:

 rounds = []
 clients = {}
 memory_usage = []
+
+ round_pattern = re.compile(r'\[ROUND (\d+)\]')
 client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
 memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
+
 current_round = None
+
 for line in log_lines:
 round_match = round_pattern.search(line)
 client_match = client_pattern.search(line)
 memory_match = memory_pattern.search(line)
+
 if round_match:
 current_round = int(round_match.group(1))
 rounds.append(current_round)

 client_id = int(client_match.group(1))
 log_level = client_match.group(2)
 message = client_match.group(3)
+
 if client_id not in clients:
 clients[client_id] = {'rounds': [], 'messages': []}
+
 clients[client_id]['rounds'].append(current_round)
 clients[client_id]['messages'].append((log_level, message))
 elif memory_match:
 memory_usage.append(float(memory_match.group(1)))
+
 return rounds, clients, memory_usage

 def plot_metrics(rounds, clients, memory_usage):
 st.write("## Metrics Overview")
+
 st.write("### Memory Usage")
 plt.figure()
 plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')

 plt.ylabel('Memory Usage (GB)')
 plt.legend()
 st.pyplot(plt)
+
 for client_id, data in clients.items():
 st.write(f"### Client {client_id} Metrics")
+
 info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
 debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
+
 st.write("#### INFO Messages")
 for msg in info_messages:
 st.write(msg)
+
 st.write("#### DEBUG Messages")
 for msg in debug_messages:
 st.write(msg)
+
 # Placeholder for actual loss and accuracy values, assuming they're included in the messages
 losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
 accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
+
 if losses:
 plt.figure()
 plt.plot(data['rounds'], losses, label='Loss')

 plt.ylabel('Loss')
 plt.legend()
 st.pyplot(plt)
+
 if accuracies:
 plt.figure()
 plt.plot(data['rounds'], accuracies, label='Accuracy')

 plt.legend()
 st.pyplot(plt)

 def read_log_file2():
 with open("./log.txt", "r") as file:
 return file.read()

+ def main():
 st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
 dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
 model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])

 for round_num in range(NUM_ROUNDS):
 st.write(f"### Round {round_num + 1} ✅")

 logs = read_log_file2()
+
 # Extract relevant data
+ accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
+ loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
+
+ accuracy_matches = accuracy_pattern.findall(logs)
+ loss_matches = loss_pattern.findall(logs)
+
 rounds = [int(match[0]) for match in accuracy_matches]
 accuracies = [float(match[1]) for match in accuracy_matches]
 losses = [float(match[1]) for match in loss_matches]
+
 # Create accuracy plot
 accuracy_fig = go.Figure()
 accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
 accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
+
 # Create loss plot
 loss_fig = go.Figure()
 loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
 loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
+
 # Display plots in Streamlit
 st.plotly_chart(accuracy_fig)
 st.plotly_chart(loss_fig)
+
 # Display data table
 data = {
 'Round': rounds,
 'Accuracy': accuracies,
 'Loss': losses
 }
+
 df = pd.DataFrame(data)
 st.write("## Training Metrics")
 st.table(df)

 plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]

 fl.simulation.start_simulation(

 # Display log.txt content
 st.write("## Training Log")
+ st.write(read_log_file2())
+
 st.write("## Training Log Analysis")

 log_lines = read_log_file()
 rounds, clients, memory_usage = parse_log(log_lines)
+
 plot_metrics(rounds, clients, memory_usage)

 else:

 main()


+
+ #################
+
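The updated main() pulls per-round accuracy and loss out of log.txt with the regexes added above and tabulates them with pandas before plotting. A minimal sketch of that extraction step, run against a hypothetical log string; the "{round, value}" shape below is an assumption mirroring the patterns in the diff, not a guaranteed format of the real log.

import re
import pandas as pd

# Hypothetical log excerpt; the real content depends on what gets written to ./log.txt.
logs = "'accuracy': {1, 0.6500} 'loss': {1, 0.6931} 'accuracy': {2, 0.7100} 'loss': {2, 0.6012}"

# Patterns copied from the updated main().
accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")

accuracy_matches = accuracy_pattern.findall(logs)   # [('1', '0.6500'), ('2', '0.7100')]
loss_matches = loss_pattern.findall(logs)

rounds = [int(r) for r, _ in accuracy_matches]
accuracies = [float(a) for _, a in accuracy_matches]
losses = [float(l) for _, l in loss_matches]

df = pd.DataFrame({'Round': rounds, 'Accuracy': accuracies, 'Loss': losses})
print(df)   # the app renders this with st.table(df) plus two plotly line charts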