import logging
import math
import os
from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union, no_type_check

import networkx as nx
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from torch import distributed as dist
from torch.utils.data import DataLoader, Sampler, SequentialSampler

from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.exceptions import LLMDataException
from llm_studio.src.utils.gpu_utils import sync_across_processes
from llm_studio.src.utils.utils import PatchedAttribute, set_seed

logger = logging.getLogger(__name__)


def read_dataframe(
    path: str,
    n_rows: int = -1,
    meta_only: bool = False,
    non_missing_columns: Optional[List[str]] = None,
    verbose: bool = False,
    handling: str = "warn",
    fill_columns: Optional[List[str]] = None,
    fill_value: Any = "",
    mode: str = "",
) -> pd.DataFrame:
    """Reading a dataframe from different file types

    Args:
        path: path of the dataframe
        n_rows: number of rows to limit to
        meta_only: return only meta information
        non_missing_columns: list of columns that cannot contain missing values
        verbose: if warning about dropped rows should be logged
        handling: how to handle missing values
        fill_columns: columns where empty value should be filled (used for empty text)
        fill_value: value to fill empty columns with (used for empty text)
        mode: dataset type, used only for better exception/log information
    Returns:
        dataframe

    """

    non_missing_columns = [] if non_missing_columns is None else non_missing_columns
    fill_columns = [] if fill_columns is None else fill_columns

    path_prefix, file_name = os.path.split(path)
    meta_info_path = os.path.join(path_prefix, "__meta_info__" + file_name + ".csv")
    if meta_only and os.path.exists(meta_info_path):
        path = meta_info_path

    if path.endswith(".csv"):
        df = pd.read_csv(path, lineterminator="\n").reset_index(drop=True)
    elif path.endswith(".pq") or path.endswith(".parquet"):
        try:
            df = pd.read_parquet(path, engine="pyarrow").reset_index(drop=True)
        except Exception:
            df = pd.read_parquet(path, engine="fastparquet").reset_index(drop=True)
    elif path.endswith(".json") or path == "":
        return pd.DataFrame()
    else:
        raise ValueError(
            f"Could not determine type of file {path}: "
            f"CSV (`.csv`) and Parquet (`.pq` and `.parquet`) are supported."
        )

    if fill_columns:
        df[fill_columns] = df[fill_columns].fillna(fill_value)

    if meta_only and os.path.exists(meta_info_path):
        return df

    non_missing_columns = [x for x in non_missing_columns if x in df]
    if len(non_missing_columns):
        orig_size = df.shape[0]
        non_missing_index = df[non_missing_columns].dropna().index
        dropped_index = [idx for idx in df.index if idx not in non_missing_index]
        df = df.loc[non_missing_index].reset_index(drop=True)
        new_size = df.shape[0]
        if new_size < orig_size:
            if verbose:
                logger.warning(
                    f"Dropped {orig_size - new_size} rows when reading dataframe "
                    f"'{path}' due to missing values encountered in one of the "
                    f"following columns: {non_missing_columns} in the following "
                    f"rows: {dropped_index}"
                )

            # raise regardless of verbosity so that all ranks fail consistently
            if handling == "error":
                dropped_str = dropped_index

                if len(dropped_str) > 10:
                    dropped_str = dropped_str[:5] + ["..."] + dropped_str[-5:]

                dropped_str = ", ".join([str(x) for x in dropped_str])
                prefix = f"{mode} " if mode else ""
                error = (
                    f"{prefix}dataset contains {len(dropped_index)} rows with missing "
                    f"values in one of the following columns: {non_missing_columns} "
                    f"in the following rows: {dropped_str}"
                )

                raise ValueError(error.capitalize())

    if n_rows > -1:
        df = df.iloc[sample_indices(len(df), n_indices=n_rows)]

    # create meta information dataframe if it does not exist
    if not os.path.exists(meta_info_path):
        df_meta = pd.DataFrame(columns=df.columns)
        df_meta.to_csv(meta_info_path, index=False)

    return df
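
# A minimal usage sketch of `read_dataframe` (the file name and column names
# below are illustrative, not part of this module): read a CSV, drop rows
# with a missing "prompt", and fill empty "answer" cells with "":
#
#     df = read_dataframe(
#         "train.csv",
#         non_missing_columns=["prompt"],
#         fill_columns=["answer"],
#     )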


def get_fill_columns(cfg: Any) -> List[str]:
    if hasattr(cfg.dataset, "prompt_column"):
        if isinstance(cfg.dataset.prompt_column, (list, tuple)):
            return list(cfg.dataset.prompt_column)
        return [cfg.dataset.prompt_column]

    return []


def read_dataframe_drop_missing_labels(path: str, cfg: Any) -> pd.DataFrame:
    if isinstance(cfg.dataset.prompt_column, tuple):
        input_cols = list(cfg.dataset.prompt_column)
    else:
        input_cols = [cfg.dataset.prompt_column]
    verbose = cfg.environment._local_rank == 0
    fill_columns = get_fill_columns(cfg)
    df = read_dataframe(
        path,
        non_missing_columns=input_cols,
        verbose=verbose,
        fill_columns=fill_columns,
    )
    df[input_cols] = df[input_cols].fillna("").astype(str)
    if (
        hasattr(cfg.dataset, "answer_column")
        and cfg.dataset.answer_column in df.columns
    ):
        df[cfg.dataset.answer_column] = (
            df[cfg.dataset.answer_column].fillna("").astype(str)
        )
    return df


def is_valid_data_frame(path: str, csv_rows: int = 100) -> bool:
    """Checking data frame format

    Args:
        path: path of the dataframe
        csv_rows: number of rows to limit to when checking csv files

    Returns:
        bool

    """
    try:
        if path.endswith(".csv"):
            pd.read_csv(path, nrows=csv_rows, lineterminator="\n")
        elif path.endswith(".pq") or path.endswith(".parquet"):
            pq.ParquetFile(path)
        else:
            raise ValueError(
                f"Could not determine type of file {path}: "
                f"CSV (`.csv`) and Parquet (`.pq` and `.parquet`) are supported."
            )
    except Exception as e:
        logger.error(str(e))
        return False
    return True


def sample_data(cfg: Any, df: pd.DataFrame) -> pd.DataFrame:
    """Sample data from the dataframe"""

    if cfg.dataset.parent_id_column != "None" and "id" in df.columns:
        parent_mapping = df.set_index("id")[cfg.dataset.parent_id_column].to_dict()

        # A recursive function to get the root id for each node
        def get_root(node):
            parent = parent_mapping.get(node)
            if parent is None or pd.isna(parent):
                return node
            return get_root(parent)

        # Apply the function to assign each row the root id
        df["root_id"] = df["id"].apply(get_root)

        # Sample root_ids without replacement
        root_ids = df["root_id"].unique()
        n_sampled_root_ids = int(len(root_ids) * cfg.dataset.data_sample)

        np.random.seed(7331)
        sampled_root_ids = np.random.choice(
            root_ids, size=n_sampled_root_ids, replace=False
        )

        # Filter the dataframe to only include rows with sampled root_ids
        df = df[df["root_id"].isin(sampled_root_ids)].reset_index(drop=True)
        del df["root_id"]
    else:
        # at least 10 observations
        n = max(10, int(len(df) * cfg.dataset.data_sample))
        df = df.sample(n=min(n, len(df)), random_state=7331, replace=False)

    return df
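
# Illustrative sketch of the root-id resolution above (ids are made up):
# with rows id=1 (parent None), id=2 (parent 1) and id=3 (parent 2),
# `get_root` maps all three rows to root_id=1, so the whole conversation
# tree is either kept or dropped together when root ids are sampled.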


def load_mt_bench_data(cfg: Any) -> pd.DataFrame:
    """Loads MT-BENCH data.

    Args:
        cfg: input config

    Returns:
        MT-BENCH DataFrame
    """

    prompt_column = cfg.dataset.prompt_column[0]
    answer_column = cfg.dataset.answer_column

    df = pd.read_json("prompts/mt-bench/question.jsonl", lines=True)
    df = df.rename(columns={"turns": prompt_column, "reference": answer_column})
    df[prompt_column] = df[prompt_column].apply(lambda x: x[0])
    df[answer_column] = (
        df[answer_column].fillna("").apply(lambda x: x[0] if x != "" else x)
    )

    return df


def get_data(cfg: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Prepares train and validation DataFrames.

    Args:
        cfg: input config

    Returns:
        Train and validation DataFrames
    """

    train_df, val_df = load_train_valid_data(cfg)

    if (
        hasattr(cfg.prediction, "metric_gpt_template")
        and cfg.prediction.metric_gpt_template == "mt-bench"
    ):
        if cfg.environment._local_rank == 0:
            logger.info(
                "Overwriting validation data with MT-BENCH data. Please note that "
                "respective metric is an approximation and might not fully match "
                "the original implementation."
            )
        val_df = load_mt_bench_data(cfg)

    if cfg.dataset.data_sample < 1.0:
        if "Train" in cfg.dataset.data_sample_choice:
            train_df = sample_data(cfg, train_df)
        if "Validation" in cfg.dataset.data_sample_choice:
            val_df = sample_data(cfg, val_df)

    if cfg.training.train_validation_data:
        train_df = pd.concat([train_df, val_df], axis=0)

    train_df = cfg.dataset.dataset_class.preprocess_dataframe(
        train_df, cfg, mode="train"
    )
    val_df = cfg.dataset.dataset_class.preprocess_dataframe(
        val_df, cfg, mode="validation"
    )

    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)


def merge_on_common_items(lst: List[List[Any]]) -> List[List[Any]]:
    """Merges lists that share at least one common item.

    Each sublist is turned into a star of graph edges around its first item;
    the connected components of the resulting graph are the merged groups.
    """
    G = nx.Graph()
    for sublst in lst:
        for item in sublst:
            G.add_edge(sublst[0], item)
    return [list(c) for c in nx.connected_components(G)]
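
# Illustrative example (element order within a component may vary):
#
#     merge_on_common_items([[0, 1], [1, 2], [3, 4]])
#     # -> [[0, 1, 2], [3, 4]]; [0, 1] and [1, 2] share item 1 and are
#     # therefore merged into one connected component.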


def load_train_valid_data(cfg: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
    if cfg.dataset.validation_strategy == "custom":
        if cfg.dataset.validation_dataframe == "None":
            raise LLMDataException(
                "No validation dataframe provided. "
                "Please provide a validation dataframe or "
                "choose a different validation strategy."
            )
        train_df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
        val_df = read_dataframe_drop_missing_labels(
            cfg.dataset.validation_dataframe, cfg
        )
    elif cfg.dataset.validation_strategy == "automatic":
        if cfg.environment._local_rank == 0:
            logger.info("Setting up automatic validation split...")
        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
        if cfg.dataset.parent_id_column != "None" and "id" in df.columns:
            # split based on conversation_chain_ids
            # this ensures that all samples from the
            # same conversation are in the same fold
            with PatchedAttribute(cfg.dataset, "limit_chained_samples", True):
                conversation_chain_ids = ConversationChainHandler(
                    df=df, cfg=cfg
                ).conversation_chain_ids
            # Some conversations may have the same parent id, e.g. for OASST
            # 6aa548c6-65ad-4531-9411-76173ae060a3 and
            # 2a164c2a-4f0e-45aa-8990-e7dd3b51c06b
            # have the same parent a8df94e3-cfc7-4736-9587-0ec943d0fec3
            # We need to merge those into a single group
            conversation_chain_ids = merge_on_common_items(conversation_chain_ids)
            conversation_chain_labels = [
                i
                for i, conversation_chain_id in enumerate(conversation_chain_ids)
                for _ in conversation_chain_id
            ]
            group_shuffle_split = GroupShuffleSplit(
                test_size=cfg.dataset.validation_size, n_splits=1, random_state=1337
            )
            train_idx, val_idx = next(
                group_shuffle_split.split(df, groups=conversation_chain_labels)
            )
            # flatten conversation_chain_ids
            flattened_conversation_chain_ids = np.array(
                [
                    idx
                    for conversation_chain_id in conversation_chain_ids
                    for idx in conversation_chain_id
                ]
            )
            train_df = df.iloc[flattened_conversation_chain_ids[train_idx]].reset_index(
                drop=True
            )
            val_df = df.iloc[flattened_conversation_chain_ids[val_idx]].reset_index(
                drop=True
            )
        else:
            train_df, val_df = train_test_split(
                df, test_size=cfg.dataset.validation_size, random_state=1337
            )
    else:
        raise LLMDataException("No valid validation strategy provided.")
    return train_df, val_df
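
# Sketch of the grouped split above (indices are made up): for conversation
# chains [[0, 1], [2], [3, 4, 5]] the group labels are [0, 0, 1, 2, 2, 2],
# so GroupShuffleSplit keeps rows 0 and 1 (one conversation) and rows
# 3, 4, 5 (another conversation) in the same fold, never across folds.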


def worker_init_fn(worker_id: int) -> None:
    """Sets the random seed for each worker.

    Args:
        worker_id: ID of the corresponding worker
    """

    if "PYTHONHASHSEED" in os.environ:
        seed = int(os.environ["PYTHONHASHSEED"]) + worker_id
    else:
        seed = np.random.get_state()[1][0] + worker_id  # type: ignore
    set_seed(seed)
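
# Example: with PYTHONHASHSEED=1234, worker 0 is seeded with 1234, worker 1
# with 1235, and so on, making shuffling reproducible yet distinct per worker.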


def get_train_dataset(train_df: pd.DataFrame, cfg: Any, verbose=True):
    """Prepares train Dataset.

    Args:
        train_df: train DataFrame
        cfg: input config
        verbose: whether to print the logs

    Returns:
        Train Dataset
    """

    if cfg.environment._local_rank == 0 and verbose:
        logger.info("Loading train dataset...")

    train_dataset = cfg.dataset.dataset_class(df=train_df, cfg=cfg, mode="train")
    return train_dataset


def get_train_dataloader(train_ds: Any, cfg: Any, verbose=True):
    """Prepares train DataLoader.

    Args:
        train_ds: train Dataset
        cfg: input config
        verbose: whether to print the logs

    Returns:
        Train Dataloader
    """

    sampler: Sampler
    if cfg.environment._distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_ds,
            num_replicas=cfg.environment._world_size,
            rank=cfg.environment._local_rank,
            shuffle=True,
            seed=cfg.environment._seed,
            drop_last=True,
        )
        sampler_length = len(sampler)
    else:
        sampler = None
        sampler_length = len(train_ds)

    if sampler_length < cfg.training.batch_size and cfg.training.drop_last_batch:
        logger.warning(
            "Training data too small when dropping last batch. Number of rows "
            "should be at least batch size multiplied by number of gpus. "
            "Forcing to keep last batch."
        )
        cfg.training.drop_last_batch = False
    if sampler_length <= 1:
        raise LLMDataException("Data too small to train model.")

    train_dataloader = DataLoader(
        train_ds,
        sampler=sampler,
        shuffle=(sampler is None),
        batch_size=cfg.training.batch_size,
        num_workers=cfg.environment.number_of_workers,
        pin_memory=True,
        collate_fn=train_ds.get_train_collate_fn(),
        drop_last=cfg.training.drop_last_batch,
        worker_init_fn=worker_init_fn,
    )

    if cfg.environment._local_rank == 0 and verbose:
        logger.info(f"Number of observations in train dataset: {len(train_ds)}")

    return train_dataloader
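
# Sketch of the drop-last guard above (numbers are made up): with 6 training
# rows, 2 GPUs and batch size 4, each rank samples 3 rows, which is smaller
# than the batch size; dropping the last batch would leave no data, so
# `drop_last_batch` is forced to False.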


def get_val_dataset(val_df: pd.DataFrame, cfg: Any, verbose: bool = True):
    """Prepares validation Dataset.

    Args:
        val_df: validation DataFrame
        cfg: input config
        verbose: whether to print the logs

    Returns:
        Validation Dataset
    """

    if verbose and cfg.environment._local_rank == 0:
        logger.info("Loading validation dataset...")
    val_dataset = cfg.dataset.dataset_class(df=val_df, cfg=cfg, mode="validation")

    return val_dataset


def get_val_dataloader(val_ds: Any, cfg: Any, verbose: bool = True):
    """Prepares validation DataLoader.

    Args:
        val_ds: validation Dataset
        cfg: input config
        verbose: whether to print the logs

    Returns:
        Validation Dataloader
    """

    sampler: Sampler
    if cfg.environment._distributed and cfg.environment._distributed_inference:
        sampler = OrderedDistributedSampler(
            val_ds,
            num_replicas=cfg.environment._world_size,
            rank=cfg.environment._local_rank,
        )
    else:
        sampler = SequentialSampler(val_ds)

    batch_size = get_inference_batch_size(cfg)

    val_dataloader = DataLoader(
        val_ds,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=cfg.environment.number_of_workers,
        pin_memory=True,
        collate_fn=val_ds.get_validation_collate_fn(),
        worker_init_fn=worker_init_fn,
    )

    if verbose and cfg.environment._local_rank == 0:
        logger.info(f"Number of observations in validation dataset: {len(val_ds)}")

    return val_dataloader


@no_type_check
def cat_batches(
    data: DefaultDict[str, Union[torch.Tensor, np.ndarray]]
) -> DefaultDict[str, Union[torch.Tensor, np.ndarray]]:
    """Concatenates output data from several batches

    Args:
        data: dict mapping keys to lists of per-batch outputs

    Returns:
        Concatenated dict

    """

    for key, value in data.items():
        if len(value[0].shape) == 0:
            if isinstance(value[0], torch.Tensor):
                data[key] = torch.stack(value)
            else:
                data[key] = np.stack(value)
        else:
            if isinstance(value[0], torch.Tensor):
                data[key] = torch.cat(value, dim=0)
            else:
                data[key] = np.concatenate(value, axis=0)

    return data
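
# Illustrative example: given data = {"loss": [tensor(0.5), tensor(0.3)],
# "logits": [two tensors of shape (4, 8)]}, the scalar losses are stacked
# into a shape-(2,) tensor and the logits are concatenated along dim 0
# into shape (8, 8).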


class OrderedDistributedSampler(Sampler):
    """
    Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    Source:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/distributed_sampler.py
    """

    def __init__(
        self,
        dataset: Any,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
    ):
        """
        Args:
            dataset: Dataset used for sampling
            num_replicas: Number of processes participating in distributed training
            rank: Rank of the current process within num_replicas
        """

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += [0] * (self.total_size - len(indices))
        assert len(indices) == self.total_size

        # subsample
        indices = indices[
            self.rank * self.num_samples : self.rank * self.num_samples
            + self.num_samples
        ]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
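
# Illustrative example: for a dataset of length 10 and num_replicas=4,
# num_samples = ceil(10 / 4) = 3 and total_size = 12, so the padded index
# list is [0..9, 0, 0]; rank 0 iterates [0, 1, 2], rank 1 [3, 4, 5],
# rank 2 [6, 7, 8] and rank 3 [9, 0, 0].
#
#     sampler = OrderedDistributedSampler(list(range(10)), num_replicas=4, rank=3)
#     list(sampler)  # -> [9, 0, 0]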


def sample_indices(length: int, n_indices: int = 10, seed: int = 1337) -> np.ndarray:
    """Samples random indices

    Args:
        length: length to sample from
        n_indices: number of indices to sample
        seed: seed for sampling

    Returns:
        sampled indices
    """
    state = np.random.get_state()
    np.random.seed(seed)
    idx = np.random.choice(
        np.arange(length), size=min(length, n_indices), replace=False
    )
    np.random.set_state(state)

    return idx
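
# Example: sample_indices(100, n_indices=10) returns 10 distinct indices in
# [0, 100); sample_indices(5, n_indices=10) returns a permutation of all 5.
# The surrounding get/set of the RNG state keeps the global NumPy stream
# unaffected by the call.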


def get_inference_batch_size(cfg: Any) -> int:
    """Calculates inference batch size

    Args:
        cfg: config with all the hyperparameters
    Returns:
        Inference batch size
    """

    if cfg.prediction.batch_size_inference != 0:
        return cfg.prediction.batch_size_inference
    else:
        return cfg.training.batch_size


def sanity_check(cfg: Any) -> None:
    """
    Perform sanity check on the data
    """

    df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
    cfg.dataset.dataset_class.sanity_check(df=df, cfg=cfg, mode="train")
    valid_filename = cfg.dataset.validation_dataframe
    if isinstance(valid_filename, str) and os.path.exists(valid_filename):
        df = read_dataframe_drop_missing_labels(valid_filename, cfg)
        cfg.dataset.dataset_class.sanity_check(df=df, cfg=cfg, mode="validation")


def batch_padding(
    cfg: Any,
    batch: Dict,
    training: bool = True,
    mask_key: str = "attention_mask",
    pad_keys: List[str] = ["input_ids", "attention_mask", "special_tokens_mask"],
    padding_side: str = "left",
) -> Dict:
    """Pads a batch according to set quantile, or cuts it at maximum length"""
    if cfg.environment.compile_model:
        # logger.warning("Batch padding not functional with torch compile.")
        return batch
    elif batch[mask_key].sum() == 0:
        # continued pretraining
        return batch
    elif cfg.tokenizer.padding_quantile == 0:
        return batch
    elif training and cfg.tokenizer.padding_quantile < 1.0:
        if padding_side == "left":
            lengths = torch.stack(
                [
                    torch.where(batch[mask_key][i] == 1)[0].min()
                    for i in range(batch[mask_key].size(0))
                ]
            ).float()
            quantile = 1 - cfg.tokenizer.padding_quantile
        else:
            lengths = torch.stack(
                [
                    torch.where(batch[mask_key][i] == 1)[0].max()
                    for i in range(batch[mask_key].size(0))
                ]
            ).float()
            quantile = cfg.tokenizer.padding_quantile
        if cfg.environment._distributed:
            lengths = sync_across_processes(
                lengths, cfg.environment._world_size
            )  # type: ignore
        idx = int(torch.floor(torch.quantile(lengths, quantile)))
    else:
        if padding_side == "left":
            idx = int(torch.where(batch[mask_key] == 1)[1].min())
        else:
            idx = int(torch.where(batch[mask_key] == 1)[1].max())

    if padding_side == "left":
        for key in pad_keys:
            if key in batch:
                batch[key] = batch[key][:, idx:].contiguous()
    else:
        idx += 1
        for key in pad_keys:
            if key in batch:
                batch[key] = batch[key][:, :idx].contiguous()

    return batch
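
# Illustrative example of the left-padded, full-length case (tensors are
# made up): for
#
#     attention_mask = torch.tensor([[0, 0, 1, 1],
#                                    [0, 1, 1, 1]])
#
# the earliest non-padding position across the batch is column 1, so all
# pad_keys are sliced to batch[key][:, 1:], trimming one column of pure
# padding.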