Kevin Fink committed on
Commit
6397229
·
1 Parent(s): 1888d7d
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -28,8 +28,8 @@ def fine_tune_model(model_name, dataset_name, hub_id, api_key, num_epochs, batch
28
 
29
  max_length = 128
30
  try:
31
- tokenized_train_dataset = load_from_disk(f'{hub_id.strip()}_train_dataset')
32
- tokenized_test_dataset = load_from_disk(f'{hub_id.strip()}_test_dataset')
33
  tokenized_datasets = concatenate_datasets([tokenized_train_dataset, tokenized_test_dataset])
34
  except:
35
  # Tokenize the dataset
@@ -58,8 +58,8 @@ def fine_tune_model(model_name, dataset_name, hub_id, api_key, num_epochs, batch
58
 
59
  tokenized_datasets = dataset.map(tokenize_function, batched=True, batch_size=32)
60
 
61
- tokenized_datasets['train'].save_to_disk(f'{hub_id.strip()}_train_dataset')
62
- tokenized_datasets['test'].save_to_disk(f'{hub_id.strip()}_test_dataset')
63
 
64
 
65
  # Set training arguments
@@ -98,8 +98,7 @@ def fine_tune_model(model_name, dataset_name, hub_id, api_key, num_epochs, batch
98
  eval_dataset=tokenized_datasets['test'],
99
  #callbacks=[LoggingCallback()],
100
  )
101
- for batch in trainer.get_train_dataloader():
102
- print(batch['input_ids'].shape, batch['labels'].shape)
103
  # Fine-tune the model
104
  trainer.train()
105
  trainer.push_to_hub(commit_message="Training complete!")
 
28
 
29
  max_length = 128
30
  try:
31
+ tokenized_train_dataset = load_from_disk(f'data/{hub_id.strip()}_train_dataset')
32
+ tokenized_test_dataset = load_from_disk(f'data/{hub_id.strip()}_test_dataset')
33
  tokenized_datasets = concatenate_datasets([tokenized_train_dataset, tokenized_test_dataset])
34
  except:
35
  # Tokenize the dataset
 
58
 
59
  tokenized_datasets = dataset.map(tokenize_function, batched=True, batch_size=32)
60
 
61
+ tokenized_datasets['train'].save_to_disk(f'data/{hub_id.strip()}_train_dataset')
62
+ tokenized_datasets['test'].save_to_disk(f'data/{hub_id.strip()}_test_dataset')
63
 
64
 
65
  # Set training arguments
 
98
  eval_dataset=tokenized_datasets['test'],
99
  #callbacks=[LoggingCallback()],
100
  )
101
+
 
102
  # Fine-tune the model
103
  trainer.train()
104
  trainer.push_to_hub(commit_message="Training complete!")