# UniMax Language Dataset Sampler with DDP support

This repository contains an unofficial PyTorch implementation of the UNIMAX sampling algorithm, introduced in ["UniMax: Fairer and more Effective Language Sampling for Large-Scale Multilingual Pretraining" by HW Chung et al. (ICLR 2023)](https://arxiv.org/abs/2304.09151). The algorithm generates a sampling distribution over languages from their character counts, a total character budget, and a specified number of epochs per language. This is useful when training language models on datasets with an imbalanced language distribution.

## Contents

1. `unimax_sampler.py`: contains the `UnimaxSampler` class, a PyTorch `Sampler` that uses the UNIMAX algorithm.
2. `test_unimax_sampler.py`: contains a unit test that checks the `UnimaxSampler` class works correctly.

## Usage

```python
from torch.utils.data import Dataset, DataLoader
from unimax_sampler import UnimaxSampler

# Define your parameters
language_character_counts = [100, 200, 300, 400, 500]
total_character_budget = 1000
num_epochs = 2

# Create the UnimaxSampler
unimax_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
```

Then pass the sampler as the `sampler` argument when creating a `DataLoader`.

```python
# Disable shuffle when using custom sampler...
data_loader = DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=unimax_sampler)
```

For DDP, choose the sampler depending on whether the process group has been initialized:

```python
import torch.distributed

# Assumes DistributedUnimaxSampler and UnimaxSampler are already imported from this repository.
if torch.distributed.is_initialized():
    sampler = DistributedUnimaxSampler(...)
else:
    sampler = UnimaxSampler(...)
```

## Note

The initial version of this code was created by [ChatGPT (GPT-4)](https://chat.openai.com/), based on the pseudocode provided in the [UNIMAX](https://arxiv.org/abs/2304.09151) paper. The code was then manually revised to support the PyTorch Distributed Data Parallel ([DDP](https://pytorch.org/docs/stable/notes/ddp.html)) framework.
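
For readers who want to see what the pseudocode mentioned above computes, here is a minimal pure-Python sketch of the budget allocation as it is commonly described: languages are visited from smallest to largest, each one receives an equal share of the remaining budget, and that share is capped at `num_epochs` passes over the language's data. The function name `unimax_distribution` is hypothetical, and this is not necessarily how `unimax_sampler.py` implements it.

```python
from typing import List, Sequence


def unimax_distribution(
    char_counts: Sequence[float],
    total_budget: float,
    num_epochs: float,
) -> List[float]:
    """Return per-language sampling probabilities, in the input order.

    Illustrative sketch only; assumes non-negative counts and a positive budget.
    """
    if not char_counts:
        raise ValueError("char_counts must not be empty")

    # Visit languages from the smallest corpus to the largest,
    # keeping track of each language's original position.
    order = sorted(range(len(char_counts)), key=lambda i: char_counts[i])

    remaining = float(total_budget)
    allocated = [0.0] * len(char_counts)
    for rank, idx in enumerate(order):
        # Split what is left of the budget evenly among unvisited languages.
        share = remaining / (len(order) - rank)
        # Never allocate more than num_epochs passes over this language.
        cap = char_counts[idx] * num_epochs
        allocated[idx] = min(share, cap)
        remaining -= allocated[idx]

    total = sum(allocated)
    if total == 0:
        raise ValueError("nothing was allocated; check the counts and budget")
    # Normalize the allocated budgets into a probability distribution.
    return [a / total for a in allocated]
```

With the parameters from the Usage section (`[100, 200, 300, 400, 500]`, a budget of 1000, and 2 epochs) no language hits its cap, so the sketch returns a uniform distribution. The cap only takes effect when a language is small relative to its equal share of the budget, and the unused budget is then redistributed to the larger languages.
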
The `DistributedSamplerWrapper` implementation is derived from an earlier version found in the [Catalyst](https://github.com/catalyst-team/catalyst) project.

## License

This project is licensed under the MIT License.