"""Identifiers and catalogues shared by the fine-tuning configuration code.

This module defines:

* string IDs for the configurable fields (model selection, 4-bit
  quantization, dataset, padding, LoRA, training and optimizer settings);
* ``FTDataSet`` entries for the selectable datasets (collected in
  ``ft_datasets``);
* ``Model`` entries for the selectable base models and their versions
  (collected in ``models``).
"""
from typing import List, Optional

# NOTE(review): most values below look like they mirror keyword-argument names
# of Hugging Face config classes (e.g. ``TrainingArguments``,
# ``BitsAndBytesConfig``, ``LoraConfig``). Confirm against the callers before
# renaming any *value*; the constant names are what other modules import.

# --- Model selection and 4-bit quantization ---------------------------------
MODEL_SELECTION_ID: str = "model_selection"
MODEL_VERSION_SELECTION_ID: str = "model_version_selection"
LOAD_IN_4_BIT_ID: str = "load_in_4bit"
# NOTE(review): the three BNB_* names below lack the ``_ID`` suffix used by
# their neighbours; kept as-is because other modules may import them.
BNB_4BIT_QUANT_TYPE: str = "bnb_4bit_quant_type"
BNB_4BIT_COMPUTE_DTYPE: str = "bnb_4bit_compute_dtype"
BNB_4BIT_USE_DOUBLE_QUANT: str = "bnb_4bit_use_double_quant"

# --- Dataset ----------------------------------------------------------------
DATASET_SELECTION_ID: str = "dataset_selection"
DATASET_SHUFFLING_SEED: str = "dataset_seed"  # name lacks the ``_ID`` suffix

# --- Attention and padding --------------------------------------------------
FLASH_ATTENTION_ID: str = "flash_attention"
PAD_SIDE_ID: str = "pad_side"
PAD_VALUE_ID: str = "pad_value"

# --- LoRA -------------------------------------------------------------------
LORA_R_ID: str = "lora_r"
LORA_ALPHA_ID: str = "lora_alpha"
LORA_DROPOUT_ID: str = "lora_dropout"
LORA_BIAS_ID: str = "lora_bias"

# --- Training loop ----------------------------------------------------------
NUM_TRAIN_EPOCHS_ID: str = "num_train_epochs"
# NOTE(review): unlike every sibling, this value carries an ``_id`` suffix
# ("max_steps_id" rather than "max_steps"). If callers pass these IDs straight
# through as keyword names, this one will not match. The value is kept
# unchanged for compatibility; confirm before fixing.
MAX_STEPS_ID: str = "max_steps_id"
LOGGING_STEPS_ID: str = "logging_steps"
PER_DEVICE_TRAIN_BATCH_SIZE: str = "per_device_train_batch_size"  # no ``_ID`` suffix
SAVE_STRATEGY_ID: str = "save_strategy"
GRADIENT_ACCUMULATION_STEPS_ID: str = "gradient_accumulation_steps"
GRADIENT_CHECKPOINTING_ID: str = "gradient_checkpointing"
LEARNING_RATE_ID: str = "learning_rate"
MAX_GRAD_NORM_ID: str = "max_grad_norm"
WARMUP_RATIO_ID: str = "warmup_ratio"
LR_SCHEDULER_TYPE_ID: str = "lr_scheduler_type"

# --- Output and publishing --------------------------------------------------
OUTPUT_DIR_ID: str = "output_dir"
PUSH_TO_HUB_ID: str = "push_to_hub"
REPOSITORY_NAME_ID: str = "repo_id"
REPORT_TO_ID: str = "report_to"
README_ID: str = "readme"

# --- Sequence handling ------------------------------------------------------
MAX_SEQ_LENGTH_ID: str = "max_seq_length"
PACKING_ID: str = "packing"

# --- Optimizer --------------------------------------------------------------
OPTIMIZER_ID: str = "optim"
BETA1_ID: str = "adam_beta1"
BETA2_ID: str = "adam_beta2"
EPSILON_ID: str = "adam_epsilon"
WEIGHT_DECAY_ID: str = "weight_decay"


class FTDataSet:
    """A dataset that can be selected for fine-tuning.

    Args:
        path: Dataset identifier. This is also what ``str()`` returns.
        dataset_split: Name of the split to use, or ``None`` if not given.
    """

    def __init__(self, path: str, dataset_split: Optional[str] = None):
        self.path = path
        self.dataset_split = dataset_split

    def __str__(self) -> str:
        """Return the dataset path (e.g. for display in a selection list)."""
        return self.path

    def __repr__(self) -> str:
        """Return an unambiguous, constructor-like representation."""
        return (
            f"{type(self).__name__}(path={self.path!r}, "
            f"dataset_split={self.dataset_split!r})"
        )


deita_dataset = FTDataSet(path="HuggingFaceH4/deita-10k-v0-sft", dataset_split="train_sft")
dolly = FTDataSet(path="philschmid/dolly-15k-oai-style", dataset_split="train")
ultrachat_200k = FTDataSet(path="HuggingFaceH4/ultrachat_200k", dataset_split="train_sft")

# All selectable datasets, in display order.
ft_datasets = [deita_dataset, dolly, ultrachat_200k]


class Model:
    """A base-model family together with the version strings offered for it.

    Args:
        name: Model family identifier. This is also what ``str()`` returns.
        versions: Version strings selectable for this family. The list is
            stored by reference, not copied.
    """

    def __init__(self, name: str, versions: List[str]):
        self.name = name
        self.versions = versions

    def __str__(self) -> str:
        """Return the model family name (e.g. for display in a selection list)."""
        return self.name

    def __repr__(self) -> str:
        """Return an unambiguous, constructor-like representation."""
        return f"{type(self).__name__}(name={self.name!r}, versions={self.versions!r})"


gemma = Model(name="google/gemma", versions=["7b", "2b"])
# Commented-out variant in the original list: "7b-instruct".
falcon = Model(name="tiiuae/falcon", versions=["7b"])
phi = Model(name="microsoft/phi", versions=["1_5", "1", "2"])
# Commented-out variants in the original list: "7b-chat", "7b-chat-hf".
llama = Model(name="meta-llama/Llama-2", versions=["7b", "7b-hf"])
# Commented-out variant in the original list: "7B-Instruct-v0.1".
mistral = Model(name="mistralai/Mistral", versions=["7B-v0.1"])
tinyLlama = Model(
    name="TinyLlama/TinyLlama-1.1B",
    versions=[
        "intermediate-step-1431k-3T",
        "step-50K-105b",
        "intermediate-step-240k-503b",
        "intermediate-step-715k-1.5T",
        "intermediate-step-1195k-token-2.5T",
    ],
)

# All selectable base models, in display order.
models: List[Model] = [gemma, falcon, phi, llama, mistral, tinyLlama]