Saving train state of step 5
- checkpoint-5-epoch-0/model.safetensors +3 -0
- checkpoint-5-epoch-0/model_1.safetensors +3 -0
- checkpoint-5-epoch-0/optimizer.bin +3 -0
- checkpoint-5-epoch-0/random_states_0.pkl +3 -0
- checkpoint-5-epoch-0/scheduler.bin +3 -0
- distil-whisper/events.out.tfevents.1715073979.server02.1433788.0 +3 -0
- distil-whisper/events.out.tfevents.1715074029.server02.1434198.0 +3 -0
- distil-whisper/events.out.tfevents.1715095796.server02.1514457.0 +3 -0
- distil-whisper/events.out.tfevents.1715137750.server02.1659182.0 +3 -0
- distil-whisper/events.out.tfevents.1715142860.server02.1688240.0 +3 -0
- distil-whisper/events.out.tfevents.1715144009.server02.1717420.0 +3 -0
- distil-whisper/events.out.tfevents.1715144142.server02.1721266.0 +3 -0
- distil-whisper/events.out.tfevents.1715144248.server02.1724677.0 +3 -0
- distil-whisper/events.out.tfevents.1715144329.server02.1726964.0 +3 -0
- distil-whisper/events.out.tfevents.1715144689.server02.1736871.0 +3 -0
- distil-whisper/events.out.tfevents.1715144766.server02.1739137.0 +3 -0
- distil-whisper/events.out.tfevents.1715145134.server02.1748391.0 +3 -0
- distil-whisper/events.out.tfevents.1715152989.server02.1776687.0 +3 -0
- distil-whisper/events.out.tfevents.1715153425.server02.1778557.0 +3 -0
- distil-whisper/events.out.tfevents.1715153634.server02.1779609.0 +3 -0
- distil-whisper/events.out.tfevents.1715153723.server02.1780155.0 +3 -0
- distil-whisper/events.out.tfevents.1715154461.server02.1782973.0 +3 -0
- distil-whisper/events.out.tfevents.1715160495.server02.1805047.0 +3 -0
- run_distillation.py +36 -8
- test_partial_function.py +41 -0
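Note: the checkpoint-5-epoch-0/ contents below (model.safetensors, model_1.safetensors, optimizer.bin, scheduler.bin, random_states_0.pkl) match the layout Hugging Face Accelerate writes when a training script calls accelerator.save_state(), which is presumably how this "train state of step 5" was produced. A minimal sketch of that pattern, using a toy model in place of the student/teacher Whisper models:

import torch
from accelerate import Accelerator

# Toy model, optimizer and scheduler standing in for the objects the real
# distillation script prepares (student model, teacher model, optimizer, LR scheduler).
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

accelerator = Accelerator()
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

# save_state() writes model weights, optimizer state, scheduler state and RNG
# states (model.safetensors, optimizer.bin, scheduler.bin, random_states_0.pkl).
accelerator.save_state("checkpoint-5-epoch-0")

# Resuming later restores all of the above from the same directory.
accelerator.load_state("checkpoint-5-epoch-0")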
checkpoint-5-epoch-0/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a21e3711ac40e9335e1f3f3996f60b973cd257c3f524366cf6b834e59d49f13
+size 3025686376

checkpoint-5-epoch-0/model_1.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b395c8a7e2bda655c415580106288d0387c227efd641bf4e11c1cd735fdb37a
+size 4361070048

checkpoint-5-epoch-0/optimizer.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d29f0667d7e38b9abb98596a5a9348d8f95ae4e4a7715159e01a41ac9d2f620
+size 955539578

checkpoint-5-epoch-0/random_states_0.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85d573cec64fffbd3f22840ac5142a2d5238117a2d0f909e2a3a64155fe22435
+size 14344

checkpoint-5-epoch-0/scheduler.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61c54c0f7915329263989409611568f153678f74fb6fe4366f23ad24844d158f
+size 1064
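Each of the large binary files in this commit is stored through Git LFS, so the diff only adds a three-line pointer (spec version, sha256 oid, byte size) rather than the binary payload itself. A minimal sketch of reading such a pointer, assuming the file exists locally at the path shown in the diff:

from pathlib import Path

def read_lfs_pointer(path):
    # A Git LFS pointer is a tiny text file of "key value" lines:
    # version, oid (sha256:<hex digest>) and size (bytes of the real object).
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = read_lfs_pointer("checkpoint-5-epoch-0/model.safetensors")
print(pointer["oid"], pointer["size"])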
distil-whisper/events.out.tfevents.1715073979.server02.1433788.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5acad6483f543c7a7c6c1db549ee743b5cd298b504a7d47ab30b9f233fb919c4
+size 88

distil-whisper/events.out.tfevents.1715074029.server02.1434198.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b67893905a27526433ed6485bb3eaffe1e82b3c7da45cbaa27d9266c53433144
+size 88

distil-whisper/events.out.tfevents.1715095796.server02.1514457.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:979ff30080df67a4dcab044cf870600f43073d71992160952446f78b19dcf897
+size 88

distil-whisper/events.out.tfevents.1715137750.server02.1659182.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c806ece5e794b1027c129222cdd44b22a92e26332c899a6e4dc8583f757f7dc
+size 88

distil-whisper/events.out.tfevents.1715142860.server02.1688240.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e1d8a8e923945bfdcb5f29a077c4f7009484ae9a917bf9ea970492efde3c5aa
+size 88

distil-whisper/events.out.tfevents.1715144009.server02.1717420.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4825318ef78b792ae93dc0cd60918637895c9833607e1918cfad58d83bff016
+size 88

distil-whisper/events.out.tfevents.1715144142.server02.1721266.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:980b48300850fc0e5190e9466fd1749a1ed461b5ff2fe918d3e3dfb3644625ef
+size 88

distil-whisper/events.out.tfevents.1715144248.server02.1724677.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4bdc23e6b0151e29c67e9069845f1218f686f7667fe3f3bdd1663eea19240cc6
+size 88

distil-whisper/events.out.tfevents.1715144329.server02.1726964.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:551fe2b2ff53ab0e30742564ca1935299589000ee897fb65612d7706002e701c
+size 88

distil-whisper/events.out.tfevents.1715144689.server02.1736871.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:631771d076472b08d53a7585ff812de4dd8e4e500b011da3a34c74ea6cc65d33
+size 88

distil-whisper/events.out.tfevents.1715144766.server02.1739137.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57ce892831e97b587c4573c575dc1ba0e11317517f0b5e3ba4b41822d4eea0e6
+size 88

distil-whisper/events.out.tfevents.1715145134.server02.1748391.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89c5cccd2e339637b9564d94fb6abf49e6dcd9e481292d4d12deaa5367ac49bb
+size 88

distil-whisper/events.out.tfevents.1715152989.server02.1776687.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02fc9bf530ac7fe72cc27f5be66569b22a8d0f20634adea1ffd8b9b8e084cefe
+size 88

distil-whisper/events.out.tfevents.1715153425.server02.1778557.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f66ee5f168a300d24bffa01d463604a5486270b89177b5579501bd69da02f864
+size 88

distil-whisper/events.out.tfevents.1715153634.server02.1779609.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eb46c3608e41a15b89df7edb7aa506521d7c2f5f9528f58b44be97b6a7a4b90
+size 88

distil-whisper/events.out.tfevents.1715153723.server02.1780155.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:697804e781696b9e3f46d180576ac88978f77f5493d86bc6f8591928a313daa1
+size 88

distil-whisper/events.out.tfevents.1715154461.server02.1782973.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3790ffdcde3556f8cb2531ac6ade1add4b4cce83bda933fa6f5cb5cdd68f3566
+size 88

distil-whisper/events.out.tfevents.1715160495.server02.1805047.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f4403d07051c2a9e92defb2b8ba4d895beda85e4d2a8b1c8a2fa816d0183ffb
+size 392
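The distil-whisper/events.out.tfevents.* files are TensorBoard event logs from individual launch attempts on server02; at 88 bytes most appear to contain little beyond the file header. A sketch of inspecting them with TensorBoard's event reader, assuming the tensorboard package is installed and the files have been pulled from LFS:

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point the accumulator at the directory holding the tfevents files.
ea = EventAccumulator("distil-whisper")
ea.Reload()
# Lists whatever scalar tags (loss, learning rate, ...) were actually logged.
print(ea.Tags()["scalars"])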
run_distillation.py
CHANGED
@@ -855,6 +855,9 @@ def main():
         )
         raw_datasets_train_features = list(raw_datasets["train"].features.keys())
 
+
+    print(f'858 raw_datasets["train"] : {raw_datasets["train"] }')
+
     if training_args.do_eval:
         dataset_names_dict = convert_dataset_str_to_list(
             data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
@@ -1074,6 +1077,7 @@ def main():
             else raw_datasets["train"].select(range(data_args.max_train_samples))
         )
 
+    #if we want to select first n samples , not entire validation set
     if training_args.do_eval and data_args.max_eval_samples is not None:
         for eval_split in all_eval_splits:
             raw_datasets[eval_split] = (
@@ -1101,6 +1105,13 @@ def main():
         function=is_wer_in_range,
         input_columns=["text", "whisper_transcript"],
     )
+
+
+
+
+    print(f' raw_datasets["train"].filter : {raw_datasets["train"].filter}')
+    print(f' raw_datasets["train"] : {raw_datasets["train"]}')
+
 
     if wer_threshold is not None and use_pseudo_labels:
         with accelerator.main_process_first():
@@ -1217,6 +1228,7 @@ def main():
                     if not data_args.streaming
                     else map_fn_eval()
                 )
+
 
     # 10.5: Filter training data with inputs longer than `max_input_length`
     def is_audio_in_length_range(length):
@@ -1266,6 +1278,8 @@ def main():
     # 11. Define Evaluation Metrics
     def compute_metrics(preds, labels):
         # replace padded labels by the padding token
+        print(f" preds : {preds}")
+        print(f" labels : {labels}")
         for idx in range(len(labels)):
             labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
 
@@ -1289,7 +1303,7 @@ def main():
 
     # 12. Define Training Schedule
     # Store some constants
-    per_device_train_batch_size = int(training_args.per_device_train_batch_size)
+    per_device_train_batch_size = int(training_args.per_device_train_batch_size)
     train_batch_size = per_device_train_batch_size * accelerator.num_processes
     gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
     per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
@@ -1306,8 +1320,8 @@ def main():
             num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
         else:
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
-            num_epochs = sys.maxsize
-            steps_per_epoch = total_train_steps
+            num_epochs = sys.maxsize #num_epochs as much as possible
+            steps_per_epoch = total_train_steps
     else:
         raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
 
@@ -1318,7 +1332,9 @@ def main():
         eval_steps = steps_per_epoch
     else:
         eval_steps = training_args.eval_steps
-
+
+    print(f" num_epochs : {num_epochs}")
+    print(f" steps_per_epoch = total_train_steps : {steps_per_epoch}")
     # 13. Define optimizer, LR scheduler, collator
     decay_parameters = get_parameter_names(
         student_model,
@@ -1350,7 +1366,7 @@ def main():
         num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
         num_training_steps=total_train_steps * accelerator.num_processes,
     )
-
+    print()
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
         processor=processor,
         decoder_start_token_id=decoder_start_token_id,
@@ -1382,11 +1398,16 @@ def main():
         }
     )
     print(f" gen_kwargs : {gen_kwargs}")
+    print(f" raw_datasets['eval']: {raw_datasets['eval']}")
+
     #15. Prepare everything with accelerate
     student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
         student_model, teacher_model, optimizer, lr_scheduler
     )
 
+
+
+
     def kl_divergence(target_distribution, log_predicted_distribution, labels):
         kl_loss = nn.KLDivLoss(reduction="none")
         divergence = kl_loss(log_predicted_distribution, target_distribution)
@@ -1415,8 +1436,8 @@ def main():
                 teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
             else:
                 # do the full forward pass for the teacher model (encoder + decoder)
-                teacher_outputs = teacher_model(**batch)
-
+                teacher_outputs = teacher_model(**batch)
+
         # CE (data) loss
         ce_loss = student_outputs.loss
         # rescale distribution by temperature to ensure gradients scale correctly
@@ -1519,6 +1540,13 @@ def main():
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
     else:
         resume_step = None
+    print(f" raw_datasets['train'] : {raw_datasets['train']} ")
+    print(f" raw_datasets['eval'] : {raw_datasets['eval']} ")
+
+    print(f" vectorized_datasets['eval'] : {vectorized_datasets['eval']}")
+    print(f" vectorized_datasets['train'] : {vectorized_datasets['train']}")
+
+
 
     for epoch in range(epochs_trained, num_epochs):
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
@@ -1596,7 +1624,7 @@ def main():
             eval_labels = []
             eval_start = time.time()
 
-
+            validation_dataloader = DataLoader(
                 vectorized_datasets[eval_split],
                 collate_fn=data_collator,
                 batch_size=per_device_eval_batch_size,
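Most of the changes to run_distillation.py are temporary debug print() statements around dataset preparation and the training schedule, plus the evaluation DataLoader. Under a multi-process accelerate launch each of those prints fires once per process; a small, hedged sketch of the usual way to limit such debug output to the main process, using accelerate's existing helpers and a placeholder value:

from accelerate import Accelerator

accelerator = Accelerator()
num_epochs = 3  # placeholder value purely for illustration

# accelerator.print() only emits on the main process, so debug output like the
# dataset and schedule summaries added above is not repeated once per GPU.
accelerator.print(f" num_epochs : {num_epochs}")

# Equivalent explicit guard, useful around larger blocks of debug prints.
if accelerator.is_main_process:
    print(f" num_epochs : {num_epochs}")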
test_partial_function.py
ADDED
@@ -0,0 +1,41 @@
+from functools import partial
+
+# Mock dataset in a dictionary form, similar to what you might find in a data processing library
+dataset = {
+    "train": [
+        {"text": "Hello world", "id": 1},
+        {"text": "Partial functions are cool", "id": 2},
+    ]
+}
+
+# Function to preprocess the dataset
+def prepare_train_dataset(example):
+    # Let's say we just transform the text to uppercase for simplicity
+    return {"text": example["text"].upper()}
+
+# Columns to remove from the dataset after the transformation
+columns_to_remove = ['id']
+
+# Creating a mock map function for the dataset
+def dataset_map(batch, function, remove_columns, batched, batch_size):
+    # Process each batch
+    transformed_data = [function(example) for example in batch]
+    # Remove specified columns
+    for item in transformed_data:
+        for column in remove_columns:
+            item.pop(column, None)
+    return transformed_data
+
+# Using partial to pre-configure the map function
+map_fn_train = partial(
+    dataset_map,
+    batch=dataset["train"],
+    function=prepare_train_dataset,
+    remove_columns=columns_to_remove,
+    batched=True,
+    batch_size=2  # Assuming we process all data in one batch for simplicity
+)
+
+# Using the configured function
+transformed_dataset = map_fn_train()
+print(transformed_dataset)
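test_partial_function.py appears to mirror, with plain Python objects, the preprocessing pattern run_distillation.py relies on: Dataset.map pre-configured via functools.partial (map_fn_train / map_fn_eval in the diff above) and then called as a zero-argument function. A rough sketch of the real equivalent, assuming the Hugging Face datasets library and an in-memory stand-in for raw_datasets["train"]:

from functools import partial
from datasets import Dataset

# Tiny in-memory dataset standing in for raw_datasets["train"].
raw_train = Dataset.from_dict(
    {"text": ["Hello world", "Partial functions are cool"], "id": [1, 2]}
)

def prepare_train_dataset(example):
    # Stand-in for the real feature-extraction / tokenisation step.
    return {"text": example["text"].upper()}

# Same pattern as the training script: pre-configure .map() with partial,
# then invoke the resulting callable when needed.
map_fn_train = partial(
    raw_train.map,
    function=prepare_train_dataset,
    remove_columns=["id"],
)

vectorized_train = map_fn_train()
print(vectorized_train[0])  # {'text': 'HELLO WORLD'}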