Kevin Fink committed · Commit ee2912e · Parent(s): aeb71da

dev

app.py CHANGED
@@ -121,89 +121,109 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
     # Load the dataset
     column_names = ['text', 'target']
 
-    try:
-        saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
-        if os.access(f'/data/{hub_id.strip()}_test_dataset', os.R_OK):
-            train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
-            saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_validation_dataset')
-            dataset = load_dataset(dataset_name.strip())
-            print("FOUND TEST")
-            # Create Trainer
-            data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
-            trainer = Trainer(
-                model=model,
-                args=training_args,
-                train_dataset=train_dataset,
-                eval_dataset=saved_test_dataset['input_ids'],
-                compute_metrics=compute_metrics,
-                data_collator=data_collator,
-                #processing_class=tokenizer,
-            )
-
-        elif os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
-            dataset = load_dataset(dataset_name.strip())
-            #dataset['test'] = dataset['test'].select(range(700))
-            dataset['test'] = dataset['test'].select(range(50))
-            del dataset['train']
-            del dataset['validation']
-            test_set = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
-            test_set['test'].save_to_disk(f'/data/{hub_id.strip()}_test_dataset')
-            return 'TRAINING DONE'
-
-        elif os.access(f'/data/{hub_id.strip()}_validation_dataset', os.R_OK):
-            dataset = load_dataset(dataset_name.strip())
-            dataset['train'] = dataset['train'].select(range(8000))
-            dataset['train'] = dataset['train'].select(range(1000))
-            train_size = len(dataset['train'])
-            third_size = train_size // 3
-            del dataset['test']
-            del dataset['validation']
-            print("FOUND VALIDATION")
-            saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset2')
-            third_third = dataset['train'].select(range(third_size*2, train_size))
-            dataset['train'] = third_third
-            #tokenized_second_half = tokenize_function(third_third)
-            tokenized_second_half = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
-            dataset['train'] = concatenate_datasets([saved_dataset, tokenized_second_half['train']])
-            dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
-            return 'THIRD THIRD LOADED'
-
-
-        if os.access(f'/data/{hub_id.strip()}_train_dataset', os.R_OK) and not os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
-            dataset = load_dataset(dataset_name.strip())
-            dataset['train'] = dataset['train'].select(range(1000))
-            dataset['validation'] = dataset['validation'].select(range(100))
-            #dataset['train'] = dataset['train'].select(range(8000))
-            #dataset['validation'] = dataset['validation'].select(range(300))
-            train_size = len(dataset['train'])
-            third_size = train_size // 3
-            second_third = dataset['train'].select(range(third_size, third_size*2))
-            dataset['train'] = second_third
-            del dataset['test']
-            tokenized_sh_fq_dataset = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
-            dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_sh_fq_dataset['train']])
-            dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset2')
-            dataset['validation'].save_to_disk(f'/data/{hub_id.strip()}_validation_dataset')
-            return 'SECOND THIRD LOADED'
-
-    except Exception as e:
-        print(f"An error occurred: {str(e)}, TB: {traceback.format_exc()}")
-        dataset = load_dataset(dataset_name.strip())
-        #dataset['train'] = dataset['train'].select(range(8000))
-        dataset['train'] = dataset['train'].select(range(1000))
-        train_size = len(dataset['train'])
-        third_size = train_size // 3
-        # Tokenize the dataset
-        first_third = dataset['train'].select(range(third_size))
-        dataset['train'] = first_third
-        del dataset['test']
-        del dataset['validation']
-        tokenized_first_third = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
-
-        tokenized_first_third.save_to_disk(f'/data/{hub_id.strip()}_train_dataset')
-        print('DONE')
-        return 'RUN AGAIN TO LOAD REST OF DATA'
+    #try:
+        #saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
+        #if os.access(f'/data/{hub_id.strip()}_test_dataset', os.R_OK):
+            #train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
+            #saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_validation_dataset')
+            #dataset = load_dataset(dataset_name.strip())
+            #print("FOUND TEST")
+            ## Create Trainer
+            #data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+            #trainer = Trainer(
+                #model=model,
+                #args=training_args,
+                #train_dataset=train_dataset,
+                #eval_dataset=saved_test_dataset['input_ids'],
+                #compute_metrics=compute_metrics,
+                #data_collator=data_collator,
+                ##processing_class=tokenizer,
+            #)
+
+        #elif os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
+            #dataset = load_dataset(dataset_name.strip())
+            ##dataset['test'] = dataset['test'].select(range(700))
+            #dataset['test'] = dataset['test'].select(range(50))
+            #del dataset['train']
+            #del dataset['validation']
+            #test_set = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
+            #test_set['test'].save_to_disk(f'/data/{hub_id.strip()}_test_dataset')
+            #return 'TRAINING DONE'
+
+        #elif os.access(f'/data/{hub_id.strip()}_validation_dataset', os.R_OK):
+            #dataset = load_dataset(dataset_name.strip())
+            #dataset['train'] = dataset['train'].select(range(8000))
+            #dataset['train'] = dataset['train'].select(range(1000))
+            #train_size = len(dataset['train'])
+            #third_size = train_size // 3
+            #del dataset['test']
+            #del dataset['validation']
+            #print("FOUND VALIDATION")
+            #saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset2')
+            #third_third = dataset['train'].select(range(third_size*2, train_size))
+            #dataset['train'] = third_third
+            ##tokenized_second_half = tokenize_function(third_third)
+            #tokenized_second_half = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
+            #dataset['train'] = concatenate_datasets([saved_dataset, tokenized_second_half['train']])
+            #dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
+            #return 'THIRD THIRD LOADED'
+
+
+        #if os.access(f'/data/{hub_id.strip()}_train_dataset', os.R_OK) and not os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
+            #dataset = load_dataset(dataset_name.strip())
+            #dataset['train'] = dataset['train'].select(range(1000))
+            #dataset['validation'] = dataset['validation'].select(range(100))
+            ##dataset['train'] = dataset['train'].select(range(8000))
+            ##dataset['validation'] = dataset['validation'].select(range(300))
+            #train_size = len(dataset['train'])
+            #third_size = train_size // 3
+            #second_third = dataset['train'].select(range(third_size, third_size*2))
+            #dataset['train'] = second_third
+            #del dataset['test']
+            #tokenized_sh_fq_dataset = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
+            #dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_sh_fq_dataset['train']])
+            #dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset2')
+            #dataset['validation'].save_to_disk(f'/data/{hub_id.strip()}_validation_dataset')
+            #return 'SECOND THIRD LOADED'
+
+    #except Exception as e:
+        #print(f"An error occurred: {str(e)}, TB: {traceback.format_exc()}")
+        #dataset = load_dataset(dataset_name.strip())
+        ##dataset['train'] = dataset['train'].select(range(8000))
+        #dataset['train'] = dataset['train'].select(range(1000))
+        #train_size = len(dataset['train'])
+        #third_size = train_size // 3
+        ## Tokenize the dataset
+        #first_third = dataset['train'].select(range(third_size))
+        #dataset['train'] = first_third
+        #del dataset['test']
+        #del dataset['validation']
+        #tokenized_first_third = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
+
+        #tokenized_first_third.save_to_disk(f'/data/{hub_id.strip()}_train_dataset')
+        #print('DONE')
+        #return 'RUN AGAIN TO LOAD REST OF DATA'
+    dataset = load_dataset(dataset_name.strip())
+    #dataset['train'] = dataset['train'].select(range(8000))
+    dataset['train'] = dataset['train'].select(range(1000))
+    dataset['validation'] = dataset['validation'].select(range(100))
+    tokenized_first_third = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
+
+
+    print('DONE')
+
+
+    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['validation'],
+        compute_metrics=compute_metrics,
+        data_collator=data_collator,
+        #processing_class=tokenizer,
+    )
+
     # Fine-tune the model
     trainer.evaluate()
     #if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):
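
For reference, the path this commit leaves active loads the dataset, trims the train and validation splits, tokenizes them, and runs an evaluation pass. Below is a minimal self-contained sketch of that flow; the run_eval wrapper is hypothetical, and tokenize_function, compute_metrics, training_args, tokenizer, and model stand in for the definitions that appear earlier in app.py. One caveat the sketch also corrects: the committed code stores the tokenized splits in tokenized_first_third but still passes the raw dataset['train'] and dataset['validation'] to the Trainer; the sketch wires the tokenized splits through, which is presumably the intended behavior.

from datasets import load_dataset
from transformers import DataCollatorForSeq2Seq, Trainer

def run_eval(model, tokenizer, training_args, dataset_name,
             tokenize_function, compute_metrics):
    # Columns holding raw text; dropped after tokenization so the
    # collator only sees model inputs.
    column_names = ['text', 'target']

    dataset = load_dataset(dataset_name.strip())
    dataset['train'] = dataset['train'].select(range(1000))
    dataset['validation'] = dataset['validation'].select(range(100))

    # Tokenize every split in batches of 50, as the commit does.
    tokenized = dataset.map(tokenize_function, batched=True, batch_size=50,
                            remove_columns=column_names)

    # Dynamic per-batch padding for a seq2seq model.
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized['train'],      # tokenized splits, not raw text
        eval_dataset=tokenized['validation'],
        compute_metrics=compute_metrics,
        data_collator=data_collator,
    )
    return trainer.evaluate()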