alisrbdni committed
Commit ea1705a · verified · 1 Parent(s): 9dc118b

Update app.py

Files changed (1): app.py +28 -45
app.py CHANGED
@@ -392,11 +392,11 @@
 # if __name__ == "__main__":
 #     main()
 
-
 import streamlit as st
 import matplotlib.pyplot as plt
 import torch
-from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
+from transformers import T5Tokenizer, T5ForConditionalGeneration
 from datasets import load_dataset, Dataset
 from evaluate import load as load_metric
 from torch.utils.data import DataLoader
@@ -413,35 +413,39 @@ import plotly.graph_objects as go
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
 
-class CustomDataCollator(DataCollatorWithPadding):
+class CustomDataCollator:
+    def __init__(self, pad_token_id=0):
+        self.pad_token_id = pad_token_id
+
     def __call__(self, features):
-        if 'input_ids' in features[0] and isinstance(features[0]['input_ids'][0], int):
-            # Handle byte encoding case
-            max_length = max(len(f["input_ids"]) for f in features)
-            for f in features:
-                f['input_ids'] += [0] * (max_length - len(f['input_ids']))
-        return super().__call__(features)
-
-def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
+        max_length = max(len(f["input_ids"]) for f in features)
+        for f in features:
+            f['input_ids'] += [self.pad_token_id] * (max_length - len(f['input_ids']))
+        batch = {k: torch.tensor([f[k] for f in features]) for k in features[0].keys()}
+        return batch
+
+def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False, model_name="bert-base-uncased"):
     raw_datasets = load_dataset(dataset_name)
     raw_datasets = raw_datasets.shuffle(seed=42)
     del raw_datasets["unsupervised"]
 
-    if not use_utf8:
-        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    if model_name == "google/byt5-small":
+        tokenizer = T5Tokenizer.from_pretrained(model_name)
 
-        def tokenize_function(examples):
-            return tokenizer(examples["text"], truncation=True)
+        def utf8_encode_function(examples):
+            examples["input_ids"] = [tokenizer(text.encode('utf-8'), return_tensors="pt")["input_ids"].squeeze().tolist() for text in examples["text"]]
+            return examples
 
-        tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
+        tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
         tokenized_datasets = tokenized_datasets.remove_columns("text")
         tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
     else:
-        def utf8_encode_function(examples):
-            examples["input_ids"] = [list(text.encode('utf-8')) for text in examples["text"]]
-            return examples
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-        tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
+        def tokenize_function(examples):
+            return tokenizer(examples["text"], truncation=True)
+
+        tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
         tokenized_datasets = tokenized_datasets.remove_columns("text")
         tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
 
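Note: in the new google/byt5-small branch, utf8_encode_function passes bytes into the tokenizer, but Hugging Face tokenizers accept str input, so tokenizer(text.encode('utf-8'), ...) is likely to raise at runtime (ByT5 checkpoints also normally load ByT5Tokenizer rather than T5Tokenizer). A minimal alternative sketch, assuming ByT5's convention of mapping each UTF-8 byte b to token id b + 3 (ids 0-2 are reserved for pad/eos/unk); illustration only, not the committed code:

    # Sketch: encode text straight to ByT5-style byte ids without a tokenizer.
    def utf8_encode_function(examples):
        examples["input_ids"] = [
            [b + 3 for b in text.encode("utf-8")]  # shift past the 3 special ids
            for text in examples["text"]
        ]
        return examples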
@@ -454,7 +458,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
         train_datasets.append(train_dataset)
         test_datasets.append(test_dataset)
 
-    data_collator = CustomDataCollator(tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"))
+    data_collator = CustomDataCollator(pad_token_id=tokenizer.pad_token_id)
 
     return train_datasets, test_datasets, data_collator, raw_datasets
 
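Note: the rewritten collator right-pads input_ids to the longest sequence in the batch and then tensorizes every field, so each feature must contain only numeric fields of matching length; tokenizer outputs that still carry a variable-length attention_mask would fail to tensorize. A toy usage sketch with made-up values:

    # Illustration only: two features of unequal length.
    collator = CustomDataCollator(pad_token_id=0)
    batch = collator([
        {"input_ids": [5, 6, 7], "labels": 1},
        {"input_ids": [8, 9], "labels": 0},
    ])
    print(batch["input_ids"])  # tensor([[5, 6, 7], [8, 9, 0]]) - short row padded
    print(batch["labels"])     # tensor([1, 0])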
@@ -634,15 +638,11 @@ def read_log_file2():
 def main():
     st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
     logs = read_log_file2()
-    # cleanLogs = # Define a pattern to match relevant log entries
     pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)
 
-
-    # Filter the log data
     filtered_logs = [line for line in logs.splitlines() if pattern.search(line)]
     st.markdown(filtered_logs)
 
-    # Provide a download button for the logs
     st.download_button(
         label="Download Logs",
         data="\n".join(filtered_logs),
@@ -650,13 +650,13 @@ def main():
         mime="text/plain"
     )
     dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
-    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
+    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased", "google/byt5-small"])
 
     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
     use_utf8 = st.checkbox("Train on Byte UTF-8 Dataset", value=False)
 
-    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8)
+    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8, model_name=model_name)
 
     trainloaders = []
     testloaders = []
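Note: main() still instantiates every selection with AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) (context line 687 below), and google/byt5-small (an encoder-decoder) or facebook/hubert-base-ls960 (a speech model) may not resolve to a text-classification head on every transformers version. A defensive sketch, not part of this commit:

    # Sketch: fail fast with a readable message if the checkpoint has no
    # sequence-classification mapping in the installed transformers version.
    try:
        net = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        ).to(DEVICE)
    except (ValueError, OSError) as err:
        st.error(f"Cannot load {model_name} for sequence classification: {err}")
        st.stop()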
@@ -684,9 +684,6 @@ def main():
         trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
         testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
 
-        trainloaders.append(trainloader)
-        testloaders.append(testloader)
-
         net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
         client = CustomClient(net, trainloader, testloader, client_id=i+1)
         clients.append(client)
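Note: this hunk drops the only appends into trainloaders and testloaders, while both lists are still initialized above (lines 661-662), so they are now left empty; each client keeps its own loaders instead. If any later code still iterates over the lists, the removed behavior would need restoring (sketch, not part of this commit):

    trainloaders.append(trainloader)
    testloaders.append(testloader)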
@@ -711,18 +708,10 @@ def main():
         st.write(f"### Round {round_num + 1} ✅")
 
         logs = read_log_file2()
-        filtered_log_list = [line for line in logs.splitlines if pattern.search(line)]
+        filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)]
         filtered_logs = "\n".join(filtered_log_list)
 
         st.markdown(filtered_logs)
-        # Provide a download button for the logs
-        # st.download_button(
-        #     label="Download Logs",
-        #     data=logs,
-        #     file_name="./log.txt",
-        #     mime="text/plain"
-        # )
-        # # Extract relevant data
         accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
         loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
 
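Note: the two patterns expect log entries shaped like 'accuracy': {round, value}. A small sketch of what they capture, run on a made-up sample line (illustration only):

    import re
    accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
    sample = "fit round done 'accuracy': {1, 0.8125} 'loss': {1, 0.4321}"
    print(accuracy_pattern.search(sample).groups())  # ('1', '0.8125') -> (round, value)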
@@ -733,21 +722,17 @@ def main():
         accuracies = [float(match[1]) for match in accuracy_matches]
         losses = [float(match[1]) for match in loss_matches]
 
-        # Create accuracy plot
         accuracy_fig = go.Figure()
         accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
         accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
 
-        # Create loss plot
         loss_fig = go.Figure()
         loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
         loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
 
-        # Display plots in Streamlit
         st.plotly_chart(accuracy_fig)
         st.plotly_chart(loss_fig)
 
-        # Display data table
         data = {
             'Round': rounds,
             'Accuracy': accuracies,
@@ -775,7 +760,6 @@ def main():
 
     st.success("Training completed successfully!")
 
-    # Display final metrics
     st.write("## Final Client Metrics")
     for client in clients:
         st.write(f"### Client {client.client_id}")
@@ -788,7 +772,6 @@ def main():
 
     st.write(" ")
 
-    # Display log.txt content
     st.write("## Training Log")
     st.write(read_log_file2())