Kevin Fink committed on
Commit
782c88b
·
1 Parent(s): 3270ada
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -115,7 +115,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
115
  max_length = model.get_input_embeddings().weight.shape[0]
116
  try:
117
  saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
118
- if os.access(f'/data/{hub_id.strip()}_validation_dataset'):
119
  dataset = load_dataset(dataset_name.strip())
120
  train_size = len(dataset['train'])
121
  third_size = train_size // 3
@@ -133,7 +133,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
133
  dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
134
  return 'THIRD THIRD LOADED'
135
 
136
- if not os.access(f'/data/{hub_id.strip()}_train_dataset3'):
137
  train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
138
  if len(dataset['train']) == len(train_dataset['train']):
139
  dataset = load_dataset(dataset_name.strip())
@@ -154,7 +154,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
154
  eval_dataset=saved_test_dataset,
155
  compute_metrics=compute_metrics,
156
  )
157
- if os.access(f'/data/{hub_id.strip()}_train_dataset' and not os.access(f'/data/{hub_id.strip()}_train_dataset3')):
158
  dataset = load_dataset(dataset_name.strip())
159
  train_size = len(dataset['train'])
160
  third_size = train_size // 3
 
115
  max_length = model.get_input_embeddings().weight.shape[0]
116
  try:
117
  saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
118
+ if os.path.isfile(f'/data/{hub_id.strip()}_validation_dataset'):
119
  dataset = load_dataset(dataset_name.strip())
120
  train_size = len(dataset['train'])
121
  third_size = train_size // 3
 
133
  dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
134
  return 'THIRD THIRD LOADED'
135
 
136
+ if not os.path.isfile(f'/data/{hub_id.strip()}_train_dataset3'):
137
  train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
138
  if len(dataset['train']) == len(train_dataset['train']):
139
  dataset = load_dataset(dataset_name.strip())
 
154
  eval_dataset=saved_test_dataset,
155
  compute_metrics=compute_metrics,
156
  )
157
+ if os.path.isfile(f'/data/{hub_id.strip()}_train_dataset' and not os.access(f'/data/{hub_id.strip()}_train_dataset3')):
158
  dataset = load_dataset(dataset_name.strip())
159
  train_size = len(dataset['train'])
160
  third_size = train_size // 3