# Cell classifier
def finetune_cells(
    token_set = Path('geneformer/token_dictionary.pkl'),
    median_set = Path('geneformer/gene_median_dictionary.pkl'),
    pretrained_model = ".",
    dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
    dataset_split = None,
    filter_cells = .005,
    epochs = 1,
    # NOTE: call-expression defaults are evaluated once, when the function is
    # defined, not on every call.
    cpu_cores = os.cpu_count(),
    geneformer_batch_size = 12,
    optimizer = 'adamw',
    max_lr = 5e-5,
    num_gpus = torch.cuda.device_count(),
    max_input_size = 2 ** 11,
    lr_schedule_fn = "linear",
    warmup_steps = 500,
    freeze_layers = 0,
    emb_extract = False,
    max_cells = 1000,
    emb_layer = 0,
    emb_filter = None,
    emb_dir = 'embeddings',
    overwrite = True,
    label = "cell_type",
    data_filter = None,
    forward_batch = 200,
    model_location = None,
    skip_training = False,
    sample_data = 1,
    inference = False,
    optimize_hyperparameters = False,
    output_dir = None):
    """Interface of the cell classifier (fine-tuning, inference, embeddings).

    Only the signature and the parameter documentation are shown here.

    Primary Parameters
    -------------------
    dataset: path
        Path to fine-tuning/testing dataset for training
    model_location: path
        Path to location of existing model to use for inference and
        embedding extraction. Defaults to None
    pretrained_model: path
        Path to pretrained GeneFormer 30M model before fine-tuning.
        Defaults to "."
    inference: bool
        Chooses whether to perform inference (which causes the function to
        return the list of similarities). Defaults to False
    skip_training: bool
        Chooses whether to skip training the model. Defaults to False
    emb_extract: bool
        Choose whether to extract embeddings and calculate similarities.
        Defaults to False
    optimize_hyperparameters: bool
        Choose whether to optimize model hyperparameters. Defaults to False
    label: str
        The label string in the formatted dataset that contains true class
        labels, i.e. the feature that is read as the classification label.
        Defaults to "cell_type"

    Customization Parameters
    -------------------
    dataset_split: str
        How the dataset should be partitioned (if at all), and what ID
        should be used for partitioning. Defaults to None
    data_filter: list
        (For embeddings and inference) Runs analysis on subsets of the
        dataset by the ID defined by dataset_split. Defaults to None
    emb_layer: int
        What layer embeddings should be extracted and compared from.
        Defaults to 0
    emb_filter: ['cell1', 'cell2'...]
        Allows user to narrow down range of cells that embeddings will be
        extracted from. Defaults to None
    max_cells: int
        How many embeddings from cells should be extracted.
        Defaults to 1000
    freeze_layers: int
        Number of layers that should be permanently frozen during
        fine-tuning (starting from the first layer, 4 brings it up to the
        pretrained model). Defaults to 0
    sample_data: float
        What proportion of the HF dataset should be used. Defaults to 1

    Other Parameters
    -------------------
    These were not described in the original documentation; only their
    signature defaults are listed. NOTE(review): confirm their meaning
    against the implementation before relying on the names alone.

    token_set (default Path('geneformer/token_dictionary.pkl')),
    median_set (default Path('geneformer/gene_median_dictionary.pkl')),
    filter_cells (default .005), epochs (default 1),
    cpu_cores (default os.cpu_count()), geneformer_batch_size (default 12),
    optimizer (default 'adamw'), max_lr (default 5e-5),
    num_gpus (default torch.cuda.device_count()),
    max_input_size (default 2 ** 11), lr_schedule_fn (default "linear"),
    warmup_steps (default 500), emb_dir (default 'embeddings'),
    overwrite (default True), forward_batch (default 200),
    output_dir (default None)
    """


# Gene Classifier
def classify_genes(
    gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv",
    genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
    corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/",
    model = '.',
    max_input_size = 2 ** 11,
    max_lr = 5e-5,
    freeze_layers = 4,
    num_gpus = 1,
    # NOTE(review): num_proc and num_cpus (last parameter) share the same
    # default; whether both are needed cannot be told from the signature.
    num_proc = os.cpu_count(),
    geneformer_batch_size = 9,
    epochs = 1,
    filter_dataset = 50_000,
    emb_extract = True,
    emb_layer = 0,
    forward_batch = 200,
    filter_data = None,
    inference = False,
    k_validate = True,
    model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
    skip_training = False,
    emb_dir = 'gene_emb',
    output_dir = None,
    max_cells = 1000,
    num_cpus = os.cpu_count()):
    """Interface of the gene classifier (training, validation, inference).

    Only the signature and the parameter documentation are shown here.

    Primary Parameters
    -----------
    gene_info: path
        Path to gene mappings
    corpus_30M: path
        Path to 30M Gene Corpus
    model: path
        Path to pretrained GeneFormer model. Defaults to '.'
    genes: path
        Path to csv file containing different columns of genes and the
        column labels
    inference: bool
        Whether the model should be used to run inference. If False, model
        will train with labeled data instead. Defaults to False
    k_validate: bool
        Whether the model should run k-fold validation or simply perform
        regular training/evaluate. Defaults to True
    skip_training: bool
        Whether the model should skip the training portion.
        Defaults to False
    emb_extract: bool
        Whether the model should extract embeddings for a given gene (WIP).
        Defaults to True

    Customization Parameters
    -----------
    freeze_layers: int
        Freezes x number of layers from the model. Default is 4
        (2 non-frozen layers)
    filter_dataset: int
        Number of cells to filter from 30M dataset. Default is 50_000
    emb_layer: int
        What layer embeddings are extracted from. Default is 0
    filter_data: str, list
        Filters down embeddings to a single category. Default is None

    Other Parameters
    -----------
    These were not described in the original documentation; only their
    signature defaults are listed. NOTE(review): confirm their meaning
    against the implementation before relying on the names alone.

    max_input_size (default 2 ** 11), max_lr (default 5e-5),
    num_gpus (default 1), num_proc (default os.cpu_count()),
    geneformer_batch_size (default 9), epochs (default 1),
    forward_batch (default 200),
    model_location (default "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/"),
    emb_dir (default 'gene_emb'), output_dir (default None),
    max_cells (default 1000), num_cpus (default os.cpu_count())
    """