Kevin Fink committed on
Commit a5454ef · 1 Parent(s): 81f28e8
Files changed (1)
  1. app.py +19 -14
app.py CHANGED
@@ -111,46 +111,51 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
         return model_inputs
 
     #max_length = 512
-    # Load the dataset
-    dataset = load_dataset(dataset_name.strip())
+    # Load the dataset
     train_size = len(dataset['train'])
     third_size = train_size // 3
     max_length = model.get_input_embeddings().weight.shape[0]
     try:
         saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
-        if 'validation' in saved_dataset.keys():
-            if 'test' in saved_dataset.keys():
+        try:
+            load_from_disk(f'/data/{hub_id.strip()}_validation_dataset')
+            dataset = load_dataset(dataset_name.strip())
+            try:
+                saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_test_dataset')
                 print("FOUND TEST")
-                dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
+                train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
                 # Create Trainer
                 trainer = Trainer(
                     model=model,
                     args=training_args,
-                    train_dataset=tokenized_train_dataset,
-                    eval_dataset=tokenized_test_dataset,
+                    train_dataset=train_dataset,
+                    eval_dataset=saved_test_dataset,
                     compute_metrics=compute_metrics,
                 )
-            else:
+            except:
+                dataset = load_dataset(dataset_name.strip())
                 print("FOUND VALIDATION")
                 saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset2')
                 third_third = dataset['train'].select(range(third_size*2, train_size))
                 dataset['train'] = third_third
                 tokenized_second_half = dataset.map(tokenize_function, batched=True)
                 dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_second_half['train']])
-                tokenized_train_dataset = dataset['train']
-                tokenized_test_dataset = dataset['test']
-                dataset.save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
+                dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
+                dataset['test'].save_to_disk(f'/data/{hub_id.strip()}_test_dataset')
                 return 'THIRD THIRD LOADED'
-        else:
+        except:
+            dataset = load_dataset(dataset_name.strip())
             second_third = dataset['train'].select(range(third_size, third_size*2))
             dataset['train'] = second_third
             del dataset['test']
             tokenized_sh_fq_dataset = dataset.map(tokenize_function, batched=True)
             dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_sh_fq_dataset['train']])
-            dataset.save_to_disk(f'/data/{hub_id.strip()}_train_dataset2')
+            dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset2')
+            dataset['validation'].save_to_disk(f'/data/{hub_id.strip()}_validation_dataset')
             return 'SECOND THIRD LOADED'
 
-    except:
+    except:
+        dataset = load_dataset(dataset_name.strip())
         # Tokenize the dataset
         first_third = dataset['train'].select(range(third_size))
         dataset['train'] = first_third
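
The change above replaces split-name checks on the saved dataset (`if 'validation' in saved_dataset.keys()`) with nested try/except blocks that probe /data for the artifact each stage writes, and it now persists the tokenized train, validation, and test pieces separately. Below is a minimal, self-contained sketch of that resume-by-probe pattern, assuming a `datasets` DatasetDict with a 'train' split and a `tokenize_function` like the one in app.py; the function name, the 'FIRST THIRD LOADED' return value, and catching FileNotFoundError (the committed code uses bare `except:`) are illustrative, not taken from the commit.

from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk

def tokenize_one_third(dataset_name: str, hub_id: str, tokenize_function) -> str:
    """Tokenize one third of the train split per invocation, persisting to /data.

    Each call probes the disk for the previous stage's output; load_from_disk
    raises FileNotFoundError when that stage has not run yet, so the function
    falls through to the earliest unfinished stage.
    """
    prefix = f"/data/{hub_id.strip()}"
    dataset = load_dataset(dataset_name.strip())
    train_size = len(dataset["train"])
    third = train_size // 3

    def tokenize_slice(start: int, stop: int) -> Dataset:
        # Tokenize only the requested slice of the train split.
        return dataset["train"].select(range(start, stop)).map(tokenize_function, batched=True)

    try:
        done = load_from_disk(f"{prefix}_train_dataset2")  # stages 1+2 finished?
        merged = concatenate_datasets([done, tokenize_slice(2 * third, train_size)])
        merged.save_to_disk(f"{prefix}_train_dataset3")
        return "THIRD THIRD LOADED"
    except FileNotFoundError:
        pass
    try:
        done = load_from_disk(f"{prefix}_train_dataset")  # stage 1 finished?
        merged = concatenate_datasets([done, tokenize_slice(third, 2 * third)])
        merged.save_to_disk(f"{prefix}_train_dataset2")
        return "SECOND THIRD LOADED"
    except FileNotFoundError:
        pass
    tokenize_slice(0, third).save_to_disk(f"{prefix}_train_dataset")
    return "FIRST THIRD LOADED"

Unlike the committed version, this sketch saves plain Datasets rather than DatasetDicts and catches FileNotFoundError instead of using bare `except:`, so a genuine failure inside one stage (for example, a tokenization error) surfaces instead of being silently treated as "stage not reached".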