Kevin Fink
commited on
Commit
·
782c88b
1
Parent(s):
3270ada
dev
Browse files
app.py
CHANGED
@@ -115,7 +115,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
115 |
max_length = model.get_input_embeddings().weight.shape[0]
|
116 |
try:
|
117 |
saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
|
118 |
-
if os.
|
119 |
dataset = load_dataset(dataset_name.strip())
|
120 |
train_size = len(dataset['train'])
|
121 |
third_size = train_size // 3
|
@@ -133,7 +133,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
133 |
dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
|
134 |
return 'THIRD THIRD LOADED'
|
135 |
|
136 |
-
if not os.
|
137 |
train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
|
138 |
if len(dataset['train']) == len(train_dataset['train']):
|
139 |
dataset = load_dataset(dataset_name.strip())
|
@@ -154,7 +154,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
154 |
eval_dataset=saved_test_dataset,
|
155 |
compute_metrics=compute_metrics,
|
156 |
)
|
157 |
-
if os.
|
158 |
dataset = load_dataset(dataset_name.strip())
|
159 |
train_size = len(dataset['train'])
|
160 |
third_size = train_size // 3
|
|
|
115 |
max_length = model.get_input_embeddings().weight.shape[0]
|
116 |
try:
|
117 |
saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
|
118 |
+
if os.path.isfile(f'/data/{hub_id.strip()}_validation_dataset'):
|
119 |
dataset = load_dataset(dataset_name.strip())
|
120 |
train_size = len(dataset['train'])
|
121 |
third_size = train_size // 3
|
|
|
133 |
dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
|
134 |
return 'THIRD THIRD LOADED'
|
135 |
|
136 |
+
if not os.path.isfile(f'/data/{hub_id.strip()}_train_dataset3'):
|
137 |
train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
|
138 |
if len(dataset['train']) == len(train_dataset['train']):
|
139 |
dataset = load_dataset(dataset_name.strip())
|
|
|
154 |
eval_dataset=saved_test_dataset,
|
155 |
compute_metrics=compute_metrics,
|
156 |
)
|
157 |
+
if os.path.isfile(f'/data/{hub_id.strip()}_train_dataset' and not os.access(f'/data/{hub_id.strip()}_train_dataset3')):
|
158 |
dataset = load_dataset(dataset_name.strip())
|
159 |
train_size = len(dataset['train'])
|
160 |
third_size = train_size // 3
|