Update app.py
Browse files
app.py
CHANGED
@@ -422,8 +422,379 @@
|
|
422 |
|
423 |
|
424 |
|
425 |
-
##############NEW
|
426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
import streamlit as st
|
428 |
import matplotlib.pyplot as plt
|
429 |
import torch
|
@@ -441,8 +812,6 @@ import logging
|
|
441 |
import re
|
442 |
import plotly.graph_objects as go
|
443 |
|
444 |
-
# If you're curious of all the loggers
|
445 |
-
|
446 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
447 |
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
|
448 |
|
@@ -568,7 +937,7 @@ def parse_log(log_lines):
|
|
568 |
clients = {}
|
569 |
memory_usage = []
|
570 |
|
571 |
-
round_pattern = re.compile(r'ROUND
|
572 |
client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
|
573 |
memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
|
574 |
|
@@ -792,7 +1161,3 @@ def main():
|
|
792 |
if __name__ == "__main__":
|
793 |
main()
|
794 |
|
795 |
-
|
796 |
-
|
797 |
-
#################
|
798 |
-
|
|
|
422 |
|
423 |
|
424 |
|
425 |
+
# ##############NEW
|
426 |
|
427 |
+
# import streamlit as st
|
428 |
+
# import matplotlib.pyplot as plt
|
429 |
+
# import torch
|
430 |
+
# from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
|
431 |
+
# from datasets import load_dataset, Dataset
|
432 |
+
# from evaluate import load as load_metric
|
433 |
+
# from torch.utils.data import DataLoader
|
434 |
+
# import pandas as pd
|
435 |
+
# import random
|
436 |
+
# from collections import OrderedDict
|
437 |
+
# import flwr as fl
|
438 |
+
# from logging import INFO, DEBUG
|
439 |
+
# from flwr.common.logger import log
|
440 |
+
# import logging
|
441 |
+
# import re
|
442 |
+
# import plotly.graph_objects as go
|
443 |
+
|
444 |
+
# # If you're curious of all the loggers
|
445 |
+
|
446 |
+
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
447 |
+
# fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
|
448 |
+
|
449 |
+
# def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
|
450 |
+
# raw_datasets = load_dataset(dataset_name)
|
451 |
+
# raw_datasets = raw_datasets.shuffle(seed=42)
|
452 |
+
# del raw_datasets["unsupervised"]
|
453 |
+
|
454 |
+
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
455 |
+
|
456 |
+
# def tokenize_function(examples):
|
457 |
+
# return tokenizer(examples["text"], truncation=True)
|
458 |
+
|
459 |
+
# tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
|
460 |
+
# tokenized_datasets = tokenized_datasets.remove_columns("text")
|
461 |
+
# tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
|
462 |
+
|
463 |
+
# train_datasets = []
|
464 |
+
# test_datasets = []
|
465 |
+
|
466 |
+
# for _ in range(num_clients):
|
467 |
+
# train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
|
468 |
+
# test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
|
469 |
+
# train_datasets.append(train_dataset)
|
470 |
+
# test_datasets.append(test_dataset)
|
471 |
+
|
472 |
+
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
473 |
+
|
474 |
+
# return train_datasets, test_datasets, data_collator, raw_datasets
|
475 |
+
|
476 |
+
# def train(net, trainloader, epochs):
|
477 |
+
# optimizer = AdamW(net.parameters(), lr=5e-5)
|
478 |
+
# net.train()
|
479 |
+
# for _ in range(epochs):
|
480 |
+
# for batch in trainloader:
|
481 |
+
# batch = {k: v.to(DEVICE) for k, v in batch.items()}
|
482 |
+
# outputs = net(**batch)
|
483 |
+
# loss = outputs.loss
|
484 |
+
# loss.backward()
|
485 |
+
# optimizer.step()
|
486 |
+
# optimizer.zero_grad()
|
487 |
+
|
488 |
+
# def test(net, testloader):
|
489 |
+
# metric = load_metric("accuracy")
|
490 |
+
# net.eval()
|
491 |
+
# loss = 0
|
492 |
+
# for batch in testloader:
|
493 |
+
# batch = {k: v.to(DEVICE) for k, v in batch.items()}
|
494 |
+
# with torch.no_grad():
|
495 |
+
# outputs = net(**batch)
|
496 |
+
# logits = outputs.logits
|
497 |
+
# loss += outputs.loss.item()
|
498 |
+
# predictions = torch.argmax(logits, dim=-1)
|
499 |
+
# metric.add_batch(predictions=predictions, references=batch["labels"])
|
500 |
+
# loss /= len(testloader)
|
501 |
+
# accuracy = metric.compute()["accuracy"]
|
502 |
+
# return loss, accuracy
|
503 |
+
|
504 |
+
# class CustomClient(fl.client.NumPyClient):
|
505 |
+
# def __init__(self, net, trainloader, testloader, client_id):
|
506 |
+
# self.net = net
|
507 |
+
# self.trainloader = trainloader
|
508 |
+
# self.testloader = testloader
|
509 |
+
# self.client_id = client_id
|
510 |
+
# self.losses = []
|
511 |
+
# self.accuracies = []
|
512 |
+
|
513 |
+
# def get_parameters(self, config):
|
514 |
+
# return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
|
515 |
+
|
516 |
+
# def set_parameters(self, parameters):
|
517 |
+
# params_dict = zip(self.net.state_dict().keys(), parameters)
|
518 |
+
# state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
|
519 |
+
# self.net.load_state_dict(state_dict, strict=True)
|
520 |
+
|
521 |
+
# def fit(self, parameters, config):
|
522 |
+
# log(INFO, f"Client {self.client_id} is starting fit()")
|
523 |
+
# self.set_parameters(parameters)
|
524 |
+
# train(self.net, self.trainloader, epochs=1)
|
525 |
+
# loss, accuracy = test(self.net, self.testloader)
|
526 |
+
# self.losses.append(loss)
|
527 |
+
# self.accuracies.append(accuracy)
|
528 |
+
# log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
|
529 |
+
# return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}
|
530 |
+
|
531 |
+
# def evaluate(self, parameters, config):
|
532 |
+
# log(INFO, f"Client {self.client_id} is starting evaluate()")
|
533 |
+
# self.set_parameters(parameters)
|
534 |
+
# loss, accuracy = test(self.net, self.testloader)
|
535 |
+
# log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
|
536 |
+
# return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
|
537 |
+
|
538 |
+
# def plot_metrics(self, round_num, plot_placeholder):
|
539 |
+
# if self.losses and self.accuracies:
|
540 |
+
# plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
|
541 |
+
# plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
|
542 |
+
# plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
|
543 |
+
|
544 |
+
# fig, ax1 = plt.subplots()
|
545 |
+
|
546 |
+
# color = 'tab:red'
|
547 |
+
# ax1.set_xlabel('Round')
|
548 |
+
# ax1.set_ylabel('Loss', color=color)
|
549 |
+
# ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
|
550 |
+
# ax1.tick_params(axis='y', labelcolor=color)
|
551 |
+
|
552 |
+
# ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
|
553 |
+
# color = 'tab:blue'
|
554 |
+
# ax2.set_ylabel('Accuracy', color=color)
|
555 |
+
# ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
|
556 |
+
# ax2.tick_params(axis='y', labelcolor=color)
|
557 |
+
|
558 |
+
# fig.tight_layout()
|
559 |
+
# plot_placeholder.pyplot(fig)
|
560 |
+
|
561 |
+
# def read_log_file(log_path='./log.txt'):
|
562 |
+
# with open(log_path, 'r') as file:
|
563 |
+
# log_lines = file.readlines()
|
564 |
+
# return log_lines
|
565 |
+
|
566 |
+
# def parse_log(log_lines):
|
567 |
+
# rounds = []
|
568 |
+
# clients = {}
|
569 |
+
# memory_usage = []
|
570 |
+
|
571 |
+
# round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
|
572 |
+
# client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
|
573 |
+
# memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
|
574 |
+
|
575 |
+
# current_round = None
|
576 |
+
|
577 |
+
# for line in log_lines:
|
578 |
+
# round_match = round_pattern.search(line)
|
579 |
+
# client_match = client_pattern.search(line)
|
580 |
+
# memory_match = memory_pattern.search(line)
|
581 |
+
|
582 |
+
# if round_match:
|
583 |
+
# current_round = int(round_match.group(1))
|
584 |
+
# rounds.append(current_round)
|
585 |
+
# elif client_match:
|
586 |
+
# client_id = int(client_match.group(1))
|
587 |
+
# log_level = client_match.group(2)
|
588 |
+
# message = client_match.group(3)
|
589 |
+
|
590 |
+
# if client_id not in clients:
|
591 |
+
# clients[client_id] = {'rounds': [], 'messages': []}
|
592 |
+
|
593 |
+
# clients[client_id]['rounds'].append(current_round)
|
594 |
+
# clients[client_id]['messages'].append((log_level, message))
|
595 |
+
# elif memory_match:
|
596 |
+
# memory_usage.append(float(memory_match.group(1)))
|
597 |
+
|
598 |
+
# return rounds, clients, memory_usage
|
599 |
+
|
600 |
+
# def plot_metrics(rounds, clients, memory_usage):
|
601 |
+
# st.write("## Metrics Overview")
|
602 |
+
|
603 |
+
# st.write("### Memory Usage")
|
604 |
+
# plt.figure()
|
605 |
+
# plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
|
606 |
+
# plt.xlabel('Step')
|
607 |
+
# plt.ylabel('Memory Usage (GB)')
|
608 |
+
# plt.legend()
|
609 |
+
# st.pyplot(plt)
|
610 |
+
|
611 |
+
# for client_id, data in clients.items():
|
612 |
+
# st.write(f"### Client {client_id} Metrics")
|
613 |
+
|
614 |
+
# info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
|
615 |
+
# debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
|
616 |
+
|
617 |
+
# st.write("#### INFO Messages")
|
618 |
+
# for msg in info_messages:
|
619 |
+
# st.write(msg)
|
620 |
+
|
621 |
+
# st.write("#### DEBUG Messages")
|
622 |
+
# for msg in debug_messages:
|
623 |
+
# st.write(msg)
|
624 |
+
|
625 |
+
# # Placeholder for actual loss and accuracy values, assuming they're included in the messages
|
626 |
+
# losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
|
627 |
+
# accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
|
628 |
+
|
629 |
+
# if losses:
|
630 |
+
# plt.figure()
|
631 |
+
# plt.plot(data['rounds'], losses, label='Loss')
|
632 |
+
# plt.xlabel('Round')
|
633 |
+
# plt.ylabel('Loss')
|
634 |
+
# plt.legend()
|
635 |
+
# st.pyplot(plt)
|
636 |
+
|
637 |
+
# if accuracies:
|
638 |
+
# plt.figure()
|
639 |
+
# plt.plot(data['rounds'], accuracies, label='Accuracy')
|
640 |
+
# plt.xlabel('Round')
|
641 |
+
# plt.ylabel('Accuracy')
|
642 |
+
# plt.legend()
|
643 |
+
# st.pyplot(plt)
|
644 |
+
|
645 |
+
# def read_log_file2():
|
646 |
+
# with open("./log.txt", "r") as file:
|
647 |
+
# return file.read()
|
648 |
+
|
649 |
+
# def main():
|
650 |
+
# st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
|
651 |
+
# dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
|
652 |
+
# model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
|
653 |
+
|
654 |
+
# NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
|
655 |
+
# NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
|
656 |
+
|
657 |
+
# train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
|
658 |
+
|
659 |
+
# trainloaders = []
|
660 |
+
# testloaders = []
|
661 |
+
# clients = []
|
662 |
+
|
663 |
+
# for i in range(NUM_CLIENTS):
|
664 |
+
# st.write(f"### Client {i+1} Datasets")
|
665 |
+
|
666 |
+
# train_df = pd.DataFrame(train_datasets[i])
|
667 |
+
# test_df = pd.DataFrame(test_datasets[i])
|
668 |
+
|
669 |
+
# st.write("#### Train Dataset (Words)")
|
670 |
+
# st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
|
671 |
+
# st.write("#### Train Dataset (Tokens)")
|
672 |
+
# edited_train_df = st.data_editor(train_df, key=f"train_{i}")
|
673 |
+
|
674 |
+
# st.write("#### Test Dataset (Words)")
|
675 |
+
# st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
|
676 |
+
# st.write("#### Test Dataset (Tokens)")
|
677 |
+
# edited_test_df = st.data_editor(test_df, key=f"test_{i}")
|
678 |
+
|
679 |
+
# edited_train_dataset = Dataset.from_pandas(edited_train_df)
|
680 |
+
# edited_test_dataset = Dataset.from_pandas(edited_test_df)
|
681 |
+
|
682 |
+
# trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
|
683 |
+
# testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
|
684 |
+
|
685 |
+
# trainloaders.append(trainloader)
|
686 |
+
# testloaders.append(testloader)
|
687 |
+
|
688 |
+
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
689 |
+
# client = CustomClient(net, trainloader, testloader, client_id=i+1)
|
690 |
+
# clients.append(client)
|
691 |
+
|
692 |
+
# if st.button("Start Training"):
|
693 |
+
# def client_fn(cid):
|
694 |
+
# return clients[int(cid)]
|
695 |
+
|
696 |
+
# def weighted_average(metrics):
|
697 |
+
# accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
|
698 |
+
# losses = [num_examples * m["loss"] for num_examples, m in metrics]
|
699 |
+
# examples = [num_examples for num_examples, _ in metrics]
|
700 |
+
# return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
|
701 |
+
|
702 |
+
# strategy = fl.server.strategy.FedAvg(
|
703 |
+
# fraction_fit=1.0,
|
704 |
+
# fraction_evaluate=1.0,
|
705 |
+
# evaluate_metrics_aggregation_fn=weighted_average,
|
706 |
+
# )
|
707 |
+
|
708 |
+
# for round_num in range(NUM_ROUNDS):
|
709 |
+
# st.write(f"### Round {round_num + 1} ✅")
|
710 |
+
|
711 |
+
# logs = read_log_file2()
|
712 |
+
# st.markdown(logs)
|
713 |
+
# # Extract relevant data
|
714 |
+
# accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
|
715 |
+
# loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
|
716 |
+
|
717 |
+
# accuracy_matches = accuracy_pattern.findall(logs)
|
718 |
+
# loss_matches = loss_pattern.findall(logs)
|
719 |
+
|
720 |
+
# rounds = [int(match[0]) for match in accuracy_matches]
|
721 |
+
# accuracies = [float(match[1]) for match in accuracy_matches]
|
722 |
+
# losses = [float(match[1]) for match in loss_matches]
|
723 |
+
|
724 |
+
# # Create accuracy plot
|
725 |
+
# accuracy_fig = go.Figure()
|
726 |
+
# accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
|
727 |
+
# accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
|
728 |
+
|
729 |
+
# # Create loss plot
|
730 |
+
# loss_fig = go.Figure()
|
731 |
+
# loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
|
732 |
+
# loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
|
733 |
+
|
734 |
+
# # Display plots in Streamlit
|
735 |
+
# st.plotly_chart(accuracy_fig)
|
736 |
+
# st.plotly_chart(loss_fig)
|
737 |
+
|
738 |
+
# # Display data table
|
739 |
+
# data = {
|
740 |
+
# 'Round': rounds,
|
741 |
+
# 'Accuracy': accuracies,
|
742 |
+
# 'Loss': losses
|
743 |
+
# }
|
744 |
+
|
745 |
+
# df = pd.DataFrame(data)
|
746 |
+
# st.write("## Training Metrics")
|
747 |
+
# st.table(df)
|
748 |
+
|
749 |
+
# plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
|
750 |
+
|
751 |
+
# fl.simulation.start_simulation(
|
752 |
+
# client_fn=client_fn,
|
753 |
+
# num_clients=NUM_CLIENTS,
|
754 |
+
# config=fl.server.ServerConfig(num_rounds=1),
|
755 |
+
# strategy=strategy,
|
756 |
+
# client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
|
757 |
+
# ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
|
758 |
+
# )
|
759 |
+
|
760 |
+
# for i, client in enumerate(clients):
|
761 |
+
# client.plot_metrics(round_num + 1, plot_placeholders[i])
|
762 |
+
# st.write(" ")
|
763 |
+
|
764 |
+
# st.success("Training completed successfully!")
|
765 |
+
|
766 |
+
# # Display final metrics
|
767 |
+
# st.write("## Final Client Metrics")
|
768 |
+
# for client in clients:
|
769 |
+
# st.write(f"### Client {client.client_id}")
|
770 |
+
# if client.losses and client.accuracies:
|
771 |
+
# st.write(f"Final Loss: {client.losses[-1]:.4f}")
|
772 |
+
# st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
|
773 |
+
# client.plot_metrics(NUM_ROUNDS, st.empty())
|
774 |
+
# else:
|
775 |
+
# st.write("No metrics available.")
|
776 |
+
|
777 |
+
# st.write(" ")
|
778 |
+
|
779 |
+
# # Display log.txt content
|
780 |
+
# st.write("## Training Log")
|
781 |
+
# st.write(read_log_file2())
|
782 |
+
|
783 |
+
# st.write("## Training Log Analysis")
|
784 |
+
# log_lines = read_log_file()
|
785 |
+
# rounds, clients, memory_usage = parse_log(log_lines)
|
786 |
+
|
787 |
+
# plot_metrics(rounds, clients, memory_usage)
|
788 |
+
|
789 |
+
# else:
|
790 |
+
# st.write("Click the 'Start Training' button to start the training process.")
|
791 |
+
|
792 |
+
# if __name__ == "__main__":
|
793 |
+
# main()
|
794 |
+
|
795 |
+
|
796 |
+
|
797 |
+
# #################
|
798 |
import streamlit as st
|
799 |
import matplotlib.pyplot as plt
|
800 |
import torch
|
|
|
812 |
import re
|
813 |
import plotly.graph_objects as go
|
814 |
|
|
|
|
|
815 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
816 |
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
|
817 |
|
|
|
937 |
clients = {}
|
938 |
memory_usage = []
|
939 |
|
940 |
+
round_pattern = re.compile(r'ROUND (\d+)')
|
941 |
client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
|
942 |
memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
|
943 |
|
|
|
1161 |
if __name__ == "__main__":
|
1162 |
main()
|
1163 |
|
|
|
|
|
|
|
|