Update app.py
app.py
CHANGED
@@ -392,11 +392,11 @@
 # if __name__ == "__main__":
 #     main()
 
-
 import streamlit as st
 import matplotlib.pyplot as plt
 import torch
-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
+from transformers import T5Tokenizer, T5ForConditionalGeneration
 from datasets import load_dataset, Dataset
 from evaluate import load as load_metric
 from torch.utils.data import DataLoader
@@ -413,35 +413,39 @@ import plotly.graph_objects as go
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
 
-class CustomDataCollator
+class CustomDataCollator:
+    def __init__(self, pad_token_id=0):
+        self.pad_token_id = pad_token_id
+
     def __call__(self, features):
-
-
-        max_length
-
-
-
-
-def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
+        max_length = max(len(f["input_ids"]) for f in features)
+        for f in features:
+            f['input_ids'] += [self.pad_token_id] * (max_length - len(f['input_ids']))
+        batch = {k: torch.tensor([f[k] for f in features]) for k in features[0].keys()}
+        return batch
+
+def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False, model_name="bert-base-uncased"):
     raw_datasets = load_dataset(dataset_name)
     raw_datasets = raw_datasets.shuffle(seed=42)
     del raw_datasets["unsupervised"]
 
-    if
-        tokenizer =
+    if model_name == "google/byt5-small":
+        tokenizer = T5Tokenizer.from_pretrained(model_name)
 
-        def
-
+        def utf8_encode_function(examples):
+            examples["input_ids"] = [tokenizer(text.encode('utf-8'), return_tensors="pt")["input_ids"].squeeze().tolist() for text in examples["text"]]
+            return examples
 
-        tokenized_datasets = raw_datasets.map(
+        tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
         tokenized_datasets = tokenized_datasets.remove_columns("text")
         tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
     else:
-
-        examples["input_ids"] = [list(text.encode('utf-8')) for text in examples["text"]]
-        return examples
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-
+        def tokenize_function(examples):
+            return tokenizer(examples["text"], truncation=True)
+
+        tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
         tokenized_datasets = tokenized_datasets.remove_columns("text")
         tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
 
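Note: ByT5 works directly on UTF-8 bytes instead of a learned subword vocabulary, which is why the `google/byt5-small` branch above encodes text with `text.encode('utf-8')`. A minimal sketch of the underlying byte-to-id mapping, independent of the tokenizer classes used in the commit (ByT5's documented convention reserves ids 0–2 for pad/eos/unk, so each byte maps to its value plus 3):

```python
# Byte-level "tokenization" sketch for ByT5-style models (illustration only).
text = "héllo"
raw_bytes = list(text.encode("utf-8"))   # [104, 195, 169, 108, 108, 111]
input_ids = [b + 3 for b in raw_bytes]   # shift past the pad/eos/unk special ids
print(input_ids)                         # [107, 198, 172, 111, 111, 114]
```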
@@ -454,7 +458,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
         train_datasets.append(train_dataset)
         test_datasets.append(test_dataset)
 
-    data_collator = CustomDataCollator(tokenizer
+    data_collator = CustomDataCollator(pad_token_id=tokenizer.pad_token_id)
 
     return train_datasets, test_datasets, data_collator, raw_datasets
 
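Note: the collator pads every `input_ids` list in a batch to that batch's longest sequence before tensorizing, so batches stay rectangular without padding the whole dataset up front. A quick sanity check of `CustomDataCollator` as defined above, assuming features that carry only `input_ids` and `labels`:

```python
collator = CustomDataCollator(pad_token_id=0)
features = [
    {"input_ids": [5, 6, 7], "labels": 1},
    {"input_ids": [8, 9],    "labels": 0},
]
batch = collator(features)
print(batch["input_ids"])  # tensor([[5, 6, 7], [8, 9, 0]]) -- short row padded in place
print(batch["labels"])     # tensor([1, 0])
```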
@@ -634,15 +638,11 @@ def read_log_file2():
 def main():
     st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
     logs = read_log_file2()
-    # cleanLogs = # Define a pattern to match relevant log entries
     pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)
 
-
-    # Filter the log data
     filtered_logs = [line for line in logs.splitlines() if pattern.search(line)]
     st.markdown(filtered_logs)
 
-    # Provide a download button for the logs
     st.download_button(
         label="Download Logs",
         data="\n".join(filtered_logs),
@@ -650,13 +650,13 @@ def main():
         mime="text/plain"
     )
     dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
-    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
+    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased", "google/byt5-small"])
 
     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
     use_utf8 = st.checkbox("Train on Byte UTF-8 Dataset", value=False)
 
-    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8)
+    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8, model_name=model_name)
 
     trainloaders = []
     testloaders = []
@@ -684,9 +684,6 @@ def main():
         trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
         testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
 
-        trainloaders.append(trainloader)
-        testloaders.append(testloader)
-
         net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
         client = CustomClient(net, trainloader, testloader, client_id=i+1)
         clients.append(client)
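Note: `collate_fn` is the hook `torch.utils.data.DataLoader` calls with the list of raw examples in each batch, which is how the per-batch padding above gets applied. A minimal illustration, using a plain list of dicts as a stand-in dataset:

```python
from torch.utils.data import DataLoader

examples = [
    {"input_ids": [1, 2, 3], "labels": 1},
    {"input_ids": [4],       "labels": 0},
    {"input_ids": [5, 6],    "labels": 1},
]
loader = DataLoader(examples, batch_size=2, collate_fn=CustomDataCollator(pad_token_id=0))
for batch in loader:
    print(batch["input_ids"].shape)  # each batch is padded only to its own max length
```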
@@ -711,18 +708,10 @@ def main():
         st.write(f"### Round {round_num + 1} ✅")
 
         logs = read_log_file2()
-        filtered_log_list = [line for line in logs.splitlines if pattern.search(line)]
+        filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)]
         filtered_logs = "\n".join(filtered_log_list)
 
         st.markdown(filtered_logs)
-        # Provide a download button for the logs
-        # st.download_button(
-        #     label="Download Logs",
-        #     data=logs,
-        #     file_name="./log.txt",
-        #     mime="text/plain"
-        # )
-        # # Extract relevant data
         accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
         loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
 
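Note: the two patterns assume the Flower logs render per-round metrics in the form `'accuracy': {round, value}`. A sketch of how the capture groups are consumed downstream; the sample line here is invented for illustration:

```python
import re

accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
sample_log = "INFO ... {'accuracy': {1, 0.72}} {'accuracy': {2, 0.81}}"  # hypothetical line
matches = accuracy_pattern.findall(sample_log)   # [('1', '0.72'), ('2', '0.81')]
rounds = [int(m[0]) for m in matches]
accuracies = [float(m[1]) for m in matches]
print(rounds, accuracies)                        # [1, 2] [0.72, 0.81]
```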
@@ -733,21 +722,17 @@ def main():
         accuracies = [float(match[1]) for match in accuracy_matches]
         losses = [float(match[1]) for match in loss_matches]
 
-        # Create accuracy plot
         accuracy_fig = go.Figure()
         accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
         accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
 
-        # Create loss plot
         loss_fig = go.Figure()
         loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
         loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
 
-        # Display plots in Streamlit
         st.plotly_chart(accuracy_fig)
         st.plotly_chart(loss_fig)
 
-        # Display data table
         data = {
             'Round': rounds,
             'Accuracy': accuracies,
@@ -775,7 +760,6 @@ def main():
 
     st.success("Training completed successfully!")
 
-    # Display final metrics
    st.write("## Final Client Metrics")
     for client in clients:
         st.write(f"### Client {client.client_id}")
@@ -788,7 +772,6 @@ def main():
 
         st.write(" ")
 
-    # Display log.txt content
     st.write("## Training Log")
     st.write(read_log_file2())
 