Kevin Fink committed on
Commit 42338b1 · 1 Parent(s): 5500a71
deve
app.py CHANGED
@@ -32,33 +32,33 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
         torch.cuda.empty_cache()
         torch.nn.CrossEntropyLoss()
         rouge_metric = evaluate.load("rouge", cache_dir='/data/cache')
-        def compute_metrics(eval_preds):
-            preds, labels = eval_preds
-            if isinstance(preds, tuple):
-                preds = preds[0]
-            from pprint import pprint as pp
-            pp(preds)
-            # Replace -100s used for padding as we can't decode them
-            preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
-            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+        #def compute_metrics(eval_preds):
+            #preds, labels = eval_preds
+            #if isinstance(preds, tuple):
+                #preds = preds[0]
+            #from pprint import pprint as pp
+            #pp(preds)
+            ## Replace -100s used for padding as we can't decode them
+            #preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
+            #labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

-            # Decode predictions and labels
-            decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+            ## Decode predictions and labels
+            #decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+            #decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

-            # Compute ROUGE metrics
-            result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
-            result = {k: round(v * 100, 4) for k, v in result.items()}
+            ## Compute ROUGE metrics
+            #result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
+            #result = {k: round(v * 100, 4) for k, v in result.items()}

-            # Calculate accuracy
-            accuracy = accuracy_score(decoded_labels, decoded_preds)
-            result["eval_accuracy"] = round(accuracy * 100, 4)
+            ## Calculate accuracy
+            #accuracy = accuracy_score(decoded_labels, decoded_preds)
+            #result["eval_accuracy"] = round(accuracy * 100, 4)

-            # Calculate average generation length
-            prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-            result["gen_len"] = np.mean(prediction_lens)
+            ## Calculate average generation length
+            #prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+            #result["gen_len"] = np.mean(prediction_lens)

-            return result
+            #return result

         login(api_key.strip())
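Note on the hunk above: the commit disables compute_metrics rather than deleting it. For reference, a minimal self-contained sketch of the same metric logic follows. It assumes the names defined elsewhere in app.py (tokenizer, rouge_metric, np) and adds an argmax step as a guess at one failure mode: a plain Trainer hands compute_metrics raw logits, which cannot be batch-decoded directly. This is a sketch, not the code the commit removed.

import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(eval_preds):
    # Sketch only; `tokenizer` and `rouge_metric = evaluate.load("rouge")`
    # are assumed to be in scope, as they are in app.py.
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    if preds.ndim == 3:
        # A plain Trainer passes logits; reduce to token ids before decoding.
        preds = preds.argmax(axis=-1)
    # Replace the -100 padding sentinel, which cannot be decoded.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    # accuracy_score on decoded strings is exact sequence match, not token accuracy.
    result["eval_accuracy"] = round(accuracy_score(decoded_labels, decoded_preds) * 100, 4)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = float(np.mean(prediction_lens))
    return result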
@@ -135,118 +135,112 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
         # Load the dataset
         column_names = ['text', 'target']

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
+            if os.access(f'/data/{hub_id.strip()}_test_dataset', os.R_OK):
+                train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
+                saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_validation_dataset')
+                dataset = load_dataset(dataset_name.strip())
+                print("FOUND TEST")
+                # Create Trainer
+                data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+                trainer = Trainer(
+                    model=model,
+                    args=training_args,
+                    train_dataset=train_dataset,
+                    eval_dataset=saved_test_dataset,
                     #compute_metrics=compute_metrics,
                     #data_collator=data_collator,
-
-
+                    #processing_class=tokenizer,
+                )

-
-
-
-
-
-
-
-
-
+            elif os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
+                dataset = load_dataset(dataset_name.strip())
+                #dataset['test'] = dataset['test'].select(range(700))
+                dataset['test'] = dataset['test'].select(range(50))
+                del dataset['train']
+                del dataset['validation']
+                test_set = dataset.map(tokenize_function, batched=True)
+                test_set['test'].save_to_disk(f'/data/{hub_id.strip()}_test_dataset')
+                return 'TRAINING DONE'

-
-
-
-
-
-
-
-
-
-
-
-        #
-
-
-
-
-        #return 'THIRD THIRD LOADED'
+            elif os.access(f'/data/{hub_id.strip()}_validation_dataset', os.R_OK):
+                dataset = load_dataset(dataset_name.strip())
+                dataset['train'] = dataset['train'].select(range(15000))
+                train_size = len(dataset['train'])
+                third_size = train_size // 3
+                del dataset['test']
+                del dataset['validation']
+                print("FOUND VALIDATION")
+                saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset2')
+                third_third = dataset['train'].select(range(third_size*2, train_size))
+                dataset['train'] = third_third
+                #tokenized_second_half = tokenize_function(third_third)
+                tokenized_second_half = dataset.map(tokenize_function, batched=True)
+                dataset['train'] = concatenate_datasets([saved_dataset, tokenized_second_half['train']])
+                dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
+                return 'THIRD THIRD LOADED'


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        #return 'SECOND THIRD LOADED'
+            if os.access(f'/data/{hub_id.strip()}_train_dataset', os.R_OK) and not os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
+                dataset = load_dataset(dataset_name.strip())
+                dataset['train'] = dataset['train'].select(range(15000))
+                dataset['validation'] = dataset['validation'].select(range(2000))
+                train_size = len(dataset['train'])
+                third_size = train_size // 3
+                second_third = dataset['train'].select(range(third_size, third_size*2))
+                dataset['train'] = second_third
+                del dataset['test']
+                tokenized_sh_fq_dataset = dataset.map(tokenize_function, batched=True,)
+                saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
+                dataset['train'] = concatenate_datasets([saved_dataset, tokenized_sh_fq_dataset['train']])
+                dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset2')
+                dataset['validation'].save_to_disk(f'/data/{hub_id.strip()}_validation_dataset')
+                return 'SECOND THIRD LOADED'

-
-
-
-
-
-
-        #
-
-
-
-
-
-        #tokenized_first_third = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
+        except Exception as e:
+            print(f"An error occurred: {str(e)}, TB: {traceback.format_exc()}")
+            dataset = load_dataset(dataset_name.strip())
+            dataset['train'] = dataset['train'].select(range(15000))
+            train_size = len(dataset['train'])
+            third_size = train_size // 3
+            # Tokenize the dataset
+            first_third = dataset['train'].select(range(third_size))
+            dataset['train'] = first_third
+            del dataset['test']
+            del dataset['validation']
+            tokenized_first_third = dataset.map(tokenize_function, batched=True,)

-
-
-
+            tokenized_first_third.save_to_disk(f'/data/{hub_id.strip()}_train_dataset')
+            print('DONE')
+            return 'RUN AGAIN TO LOAD REST OF DATA'
         dataset = load_dataset(dataset_name.strip())

-        #dataset['train'] = dataset['train'].select(range(
-        dataset['
-
-        train_set = dataset.map(tokenize_function, batched=True)
-        #valid_set = dataset['validation'].map(tokenize_function, batched=True)
+        #dataset['train'] = dataset['train'].select(range(4000))
+        #dataset['validation'] = dataset['validation'].select(range(200))
+        #train_set = dataset.map(tokenize_function, batched=True)


         #print(train_set.keys())
         print('DONE')


-        data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_set['train'],
-            eval_dataset=train_set['validation'],
-
-
-
-        )
+        #data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+        #trainer = Trainer(
+            #model=model,
+            #args=training_args,
+            #train_dataset=train_set['train'],
+            #eval_dataset=train_set['validation'],
+            ##compute_metrics=compute_metrics,
+            ##data_collator=data_collator,
+            ##processing_class=tokenizer,
+        #)

         # Fine-tune the model
-
-
-
-
-        #train_result = trainer.train()
+        if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):
+            train_result = trainer.train(resume_from_checkpoint=True)
+        else:
+            train_result = trainer.train()
         trainer.push_to_hub(commit_message="Training complete!")
     except Exception as e:
         return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"
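Note on the hunk above: the new control flow tokenizes a 15,000-example training slice in thirds across successive runs, caching each stage to the persistent /data volume and returning early; only once every stage exists does a run reach the Trainer. The initial run takes the except branch (the first load_from_disk fails) and produces the first third. A linearized sketch of the same stage-per-run pattern, with illustrative names rather than the app.py control flow:

import os
from datasets import concatenate_datasets, load_dataset, load_from_disk

def tokenize_in_stages(dataset_name, tokenize_function, cache_root):
    # cache_root is illustrative, e.g. f'/data/{hub_id.strip()}' in app.py terms.
    paths = [f'{cache_root}_train_dataset{s}' for s in ('', '2', '3')]
    train = load_dataset(dataset_name)['train'].select(range(15000))
    third = len(train) // 3
    slices = [range(0, third), range(third, 2 * third), range(2 * third, len(train))]
    for i, path in enumerate(paths):
        if os.access(path, os.R_OK):
            continue  # this stage was cached by an earlier run
        chunk = train.select(slices[i]).map(tokenize_function, batched=True)
        if i > 0:
            # Each stage stores everything tokenized so far, as in the commit.
            chunk = concatenate_datasets([load_from_disk(paths[i - 1]), chunk])
        chunk.save_to_disk(path)
        return f'STAGE {i + 1} OF 3 CACHED; RUN AGAIN'
    return load_from_disk(paths[-1])  # fully tokenized training set

One design note: staging the work this way keeps each run's memory and wall-clock footprint bounded, at the cost of requiring several runs before training can start.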
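Note on the resume logic at the end of the hunk: a non-empty output_dir does not guarantee a usable checkpoint, and trainer.train(resume_from_checkpoint=True) raises if no checkpoint-* directory is found there. A slightly more defensive variant, sketched under the assumption of the same training_args and trainer objects, uses the transformers helper that locates the newest checkpoint:

import os
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    # Returns the newest checkpoint-* subdirectory, or None if there is none.
    last_checkpoint = get_last_checkpoint(training_args.output_dir)

# Passing the explicit path (or None) avoids failing on a directory that is
# non-empty but holds no valid checkpoint.
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)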