alisrbdni committed
Commit 13907ea · verified · 1 Parent(s): d97267e

Update app.py

Files changed (1)
  1. app.py +373 -8
app.py CHANGED
@@ -422,8 +422,379 @@
 
 
 
- ##############NEW
+ # ##############NEW
 
+ # import streamlit as st
+ # import matplotlib.pyplot as plt
+ # import torch
+ # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
+ # from datasets import load_dataset, Dataset
+ # from evaluate import load as load_metric
+ # from torch.utils.data import DataLoader
+ # import pandas as pd
+ # import random
+ # from collections import OrderedDict
+ # import flwr as fl
+ # from logging import INFO, DEBUG
+ # from flwr.common.logger import log
+ # import logging
+ # import re
+ # import plotly.graph_objects as go
+
+ # # If you're curious of all the loggers
+
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
+
+ # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
+ # raw_datasets = load_dataset(dataset_name)
+ # raw_datasets = raw_datasets.shuffle(seed=42)
+ # del raw_datasets["unsupervised"]
+
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+ # def tokenize_function(examples):
+ # return tokenizer(examples["text"], truncation=True)
+
+ # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
+ # tokenized_datasets = tokenized_datasets.remove_columns("text")
+ # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+
+ # train_datasets = []
+ # test_datasets = []
+
+ # for _ in range(num_clients):
+ # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
+ # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
+ # train_datasets.append(train_dataset)
+ # test_datasets.append(test_dataset)
+
+ # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+ # return train_datasets, test_datasets, data_collator, raw_datasets
+
+ # def train(net, trainloader, epochs):
+ # optimizer = AdamW(net.parameters(), lr=5e-5)
+ # net.train()
+ # for _ in range(epochs):
+ # for batch in trainloader:
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
+ # outputs = net(**batch)
+ # loss = outputs.loss
+ # loss.backward()
+ # optimizer.step()
+ # optimizer.zero_grad()
+
+ # def test(net, testloader):
+ # metric = load_metric("accuracy")
+ # net.eval()
+ # loss = 0
+ # for batch in testloader:
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
+ # with torch.no_grad():
+ # outputs = net(**batch)
+ # logits = outputs.logits
+ # loss += outputs.loss.item()
+ # predictions = torch.argmax(logits, dim=-1)
+ # metric.add_batch(predictions=predictions, references=batch["labels"])
+ # loss /= len(testloader)
+ # accuracy = metric.compute()["accuracy"]
+ # return loss, accuracy
+
+ # class CustomClient(fl.client.NumPyClient):
+ # def __init__(self, net, trainloader, testloader, client_id):
+ # self.net = net
+ # self.trainloader = trainloader
+ # self.testloader = testloader
+ # self.client_id = client_id
+ # self.losses = []
+ # self.accuracies = []
+
+ # def get_parameters(self, config):
+ # return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+ # def set_parameters(self, parameters):
+ # params_dict = zip(self.net.state_dict().keys(), parameters)
+ # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+ # self.net.load_state_dict(state_dict, strict=True)
+
+ # def fit(self, parameters, config):
+ # log(INFO, f"Client {self.client_id} is starting fit()")
+ # self.set_parameters(parameters)
+ # train(self.net, self.trainloader, epochs=1)
+ # loss, accuracy = test(self.net, self.testloader)
+ # self.losses.append(loss)
+ # self.accuracies.append(accuracy)
+ # log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
+ # return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}
+
+ # def evaluate(self, parameters, config):
+ # log(INFO, f"Client {self.client_id} is starting evaluate()")
+ # self.set_parameters(parameters)
+ # loss, accuracy = test(self.net, self.testloader)
+ # log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
+ # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
+
+ # def plot_metrics(self, round_num, plot_placeholder):
+ # if self.losses and self.accuracies:
+ # plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
+ # plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
+ # plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
+
+ # fig, ax1 = plt.subplots()
+
+ # color = 'tab:red'
+ # ax1.set_xlabel('Round')
+ # ax1.set_ylabel('Loss', color=color)
+ # ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
+ # ax1.tick_params(axis='y', labelcolor=color)
+
+ # ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
+ # color = 'tab:blue'
+ # ax2.set_ylabel('Accuracy', color=color)
+ # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
+ # ax2.tick_params(axis='y', labelcolor=color)
+
+ # fig.tight_layout()
+ # plot_placeholder.pyplot(fig)
+
+ # def read_log_file(log_path='./log.txt'):
+ # with open(log_path, 'r') as file:
+ # log_lines = file.readlines()
+ # return log_lines
+
+ # def parse_log(log_lines):
+ # rounds = []
+ # clients = {}
+ # memory_usage = []
+
+ # round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
+ # client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
+ # memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
+
+ # current_round = None
+
+ # for line in log_lines:
+ # round_match = round_pattern.search(line)
+ # client_match = client_pattern.search(line)
+ # memory_match = memory_pattern.search(line)
+
+ # if round_match:
+ # current_round = int(round_match.group(1))
+ # rounds.append(current_round)
+ # elif client_match:
+ # client_id = int(client_match.group(1))
+ # log_level = client_match.group(2)
+ # message = client_match.group(3)
+
+ # if client_id not in clients:
+ # clients[client_id] = {'rounds': [], 'messages': []}
+
+ # clients[client_id]['rounds'].append(current_round)
+ # clients[client_id]['messages'].append((log_level, message))
+ # elif memory_match:
+ # memory_usage.append(float(memory_match.group(1)))
+
+ # return rounds, clients, memory_usage
+
+ # def plot_metrics(rounds, clients, memory_usage):
+ # st.write("## Metrics Overview")
+
+ # st.write("### Memory Usage")
+ # plt.figure()
+ # plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
+ # plt.xlabel('Step')
+ # plt.ylabel('Memory Usage (GB)')
+ # plt.legend()
+ # st.pyplot(plt)
+
+ # for client_id, data in clients.items():
+ # st.write(f"### Client {client_id} Metrics")
+
+ # info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
+ # debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
+
+ # st.write("#### INFO Messages")
+ # for msg in info_messages:
+ # st.write(msg)
+
+ # st.write("#### DEBUG Messages")
+ # for msg in debug_messages:
+ # st.write(msg)
+
+ # # Placeholder for actual loss and accuracy values, assuming they're included in the messages
+ # losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
+ # accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
+
+ # if losses:
+ # plt.figure()
+ # plt.plot(data['rounds'], losses, label='Loss')
+ # plt.xlabel('Round')
+ # plt.ylabel('Loss')
+ # plt.legend()
+ # st.pyplot(plt)
+
+ # if accuracies:
+ # plt.figure()
+ # plt.plot(data['rounds'], accuracies, label='Accuracy')
+ # plt.xlabel('Round')
+ # plt.ylabel('Accuracy')
+ # plt.legend()
+ # st.pyplot(plt)
+
+ # def read_log_file2():
+ # with open("./log.txt", "r") as file:
+ # return file.read()
+
+ # def main():
+ # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
+ # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
+ # model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
+
+ # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
+ # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
+
+ # train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
+
+ # trainloaders = []
+ # testloaders = []
+ # clients = []
+
+ # for i in range(NUM_CLIENTS):
+ # st.write(f"### Client {i+1} Datasets")
+
+ # train_df = pd.DataFrame(train_datasets[i])
+ # test_df = pd.DataFrame(test_datasets[i])
+
+ # st.write("#### Train Dataset (Words)")
+ # st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
+ # st.write("#### Train Dataset (Tokens)")
+ # edited_train_df = st.data_editor(train_df, key=f"train_{i}")
+
+ # st.write("#### Test Dataset (Words)")
+ # st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
+ # st.write("#### Test Dataset (Tokens)")
+ # edited_test_df = st.data_editor(test_df, key=f"test_{i}")
+
+ # edited_train_dataset = Dataset.from_pandas(edited_train_df)
+ # edited_test_dataset = Dataset.from_pandas(edited_test_df)
+
+ # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
+ # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
+
+ # trainloaders.append(trainloader)
+ # testloaders.append(testloader)
+
+ # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+ # client = CustomClient(net, trainloader, testloader, client_id=i+1)
+ # clients.append(client)
+
+ # if st.button("Start Training"):
+ # def client_fn(cid):
+ # return clients[int(cid)]
+
+ # def weighted_average(metrics):
+ # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
+ # losses = [num_examples * m["loss"] for num_examples, m in metrics]
+ # examples = [num_examples for num_examples, _ in metrics]
+ # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
+
+ # strategy = fl.server.strategy.FedAvg(
+ # fraction_fit=1.0,
+ # fraction_evaluate=1.0,
+ # evaluate_metrics_aggregation_fn=weighted_average,
+ # )
+
+ # for round_num in range(NUM_ROUNDS):
+ # st.write(f"### Round {round_num + 1} ✅")
+
+ # logs = read_log_file2()
+ # st.markdown(logs)
+ # # Extract relevant data
+ # accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
+ # loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
+
+ # accuracy_matches = accuracy_pattern.findall(logs)
+ # loss_matches = loss_pattern.findall(logs)
+
+ # rounds = [int(match[0]) for match in accuracy_matches]
+ # accuracies = [float(match[1]) for match in accuracy_matches]
+ # losses = [float(match[1]) for match in loss_matches]
+
+ # # Create accuracy plot
+ # accuracy_fig = go.Figure()
+ # accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
+ # accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
+
+ # # Create loss plot
+ # loss_fig = go.Figure()
+ # loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
+ # loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
+
+ # # Display plots in Streamlit
+ # st.plotly_chart(accuracy_fig)
+ # st.plotly_chart(loss_fig)
+
+ # # Display data table
+ # data = {
+ # 'Round': rounds,
+ # 'Accuracy': accuracies,
+ # 'Loss': losses
+ # }
+
+ # df = pd.DataFrame(data)
+ # st.write("## Training Metrics")
+ # st.table(df)
+
+ # plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
+
+ # fl.simulation.start_simulation(
+ # client_fn=client_fn,
+ # num_clients=NUM_CLIENTS,
+ # config=fl.server.ServerConfig(num_rounds=1),
+ # strategy=strategy,
+ # client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
+ # ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
+ # )
+
+ # for i, client in enumerate(clients):
+ # client.plot_metrics(round_num + 1, plot_placeholders[i])
+ # st.write(" ")
+
+ # st.success("Training completed successfully!")
+
+ # # Display final metrics
+ # st.write("## Final Client Metrics")
+ # for client in clients:
+ # st.write(f"### Client {client.client_id}")
+ # if client.losses and client.accuracies:
+ # st.write(f"Final Loss: {client.losses[-1]:.4f}")
+ # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
+ # client.plot_metrics(NUM_ROUNDS, st.empty())
+ # else:
+ # st.write("No metrics available.")
+
+ # st.write(" ")
+
+ # # Display log.txt content
+ # st.write("## Training Log")
+ # st.write(read_log_file2())
+
+ # st.write("## Training Log Analysis")
+ # log_lines = read_log_file()
+ # rounds, clients, memory_usage = parse_log(log_lines)
+
+ # plot_metrics(rounds, clients, memory_usage)
+
+ # else:
+ # st.write("Click the 'Start Training' button to start the training process.")
+
+ # if __name__ == "__main__":
+ # main()
+
+
+
+ # #################
 import streamlit as st
 import matplotlib.pyplot as plt
 import torch
@@ -441,8 +812,6 @@ import logging
 import re
 import plotly.graph_objects as go
 
- # If you're curious of all the loggers
-
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
 
@@ -568,7 +937,7 @@ def parse_log(log_lines):
 clients = {}
 memory_usage = []
 
- round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
+ round_pattern = re.compile(r'ROUND (\d+)')
 client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
 memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
 
@@ -792,7 +1161,3 @@ def main():
 if __name__ == "__main__":
 main()
 
-
-
- #################
-
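
Note on the parse_log change above: the old round_pattern, r'ROUND(\d+)ROUND (\d+)', appears to be two variants fused into one and only matches a line that repeats the literal text, which normal log output is unlikely to contain, so round numbers would not be extracted; the replacement r'ROUND (\d+)' captures the number directly. Below is a minimal sketch of the fixed pattern in use; the sample log line is an assumed illustration only, not taken from an actual Flower run.

import re

# Pattern as introduced by this commit: capture the integer that follows "ROUND ".
round_pattern = re.compile(r'ROUND (\d+)')

# Hypothetical server log line, for illustration only; real log formatting may differ.
sample_line = "INFO flwr | [ROUND 3]"

match = round_pattern.search(sample_line)
if match:
    current_round = int(match.group(1))  # parse_log would record 3 as the current round
    print(current_round)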