Error Generating State Embeddings Dictionary

#370
by cparsonage - opened

Hello,

I am having an issue with the last part of the embedding step. I keep getting a KeyError on the state_key when running the embex.get_state_embs() function. The code block and the error output are below.

cell_states_to_model={"state_key": "cell_type", 
                      "start_state": "P", 
                      "goal_state": "B", 
                      "alt_states": ["A", "Other"]}

state_embs_dict = embex.get_state_embs(cell_states_to_model,
                                       "Geneformer/geneformer-12L-30M",
                                       "Geneformer_outputs/240709111232/cm_classifier_test_labeled_test.dataset",
                                       output_dir,
                                       output_prefix) 

Here is the full traceback output from this code block:

RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/multiprocess/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 678, in _write_generator_to_queue
    for i, result in enumerate(func(**kwargs)):
  File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3552, in _map_single
    batch = apply_function_on_filtered_inputs(
  File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3421, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
  File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 6486, in get_indices_from_mask_function
    mask.append(function(example, *additional_args, **fn_kwargs))
  File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/perturber_utils.py", line 43, in filter_data_by_criteria
    return example[key] in value
KeyError: 'cell_type'
"""

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In[80], line 1
----> 1 state_embs_dict = embex.get_state_embs(cell_states_to_model,
      2                                        "Geneformer/geneformer-12L-30M",
      3                                        "Geneformer_outputs/240709111232/cm_classifier_test_labeled_test.dataset",
      4                                        output_dir,
      5                                        output_prefix)

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/emb_extractor.py:679, in EmbExtractor.get_state_embs(self, cell_states_to_model, model_directory, input_data_file, output_directory, output_prefix, output_torch_embs)
    677     continue
    678 elif (k == "start_state") or (k == "goal_state"):
--> 679     state_embs_dict[v] = self.extract_embs(
    680         model_directory,
    681         input_data_file,
    682         output_directory,
    683         output_prefix,
    684         output_torch_embs,
    685         cell_state={state_key: v},
    686     )
    687 else:  # k == "alt_states"
    688     for alt_state in v:

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/emb_extractor.py:558, in EmbExtractor.extract_embs(self, model_directory, input_data_file, output_directory, output_prefix, output_torch_embs, cell_state)
    554 filtered_input_data = pu.load_and_filter(
    555     self.filter_data, self.nproc, input_data_file
    556 )
    557 if cell_state is not None:
--> 558     filtered_input_data = pu.filter_by_dict(
    559         filtered_input_data, cell_state, self.nproc
    560     )
    561 downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
    562 model = pu.load_model(
    563     self.model_type, self.num_classes, model_directory, mode="eval"
    564 )

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/perturber_utils.py:45, in filter_by_dict(data, filter_data, nproc)
     42     def filter_data_by_criteria(example):
     43         return example[key] in value
---> 45     data = data.filter(filter_data_by_criteria, num_proc=nproc)
     46 if len(data) == 0:
     47     logger.error("No cells remain after filtering. Check filtering criteria.")

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:567, in transmit_format.<locals>.wrapper(*args, **kwargs)
    560 self_format = {
    561     "type": self._format_type,
    562     "format_kwargs": self._format_kwargs,
    563     "columns": self._format_columns,
    564     "output_all_columns": self._output_all_columns,
    565 }
    566 # apply actual function
--> 567 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    568 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    569 # re-apply format to the output

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/fingerprint.py:482, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
    478             validate_fingerprint(kwargs[fingerprint_name])
    480 # Call actual function
--> 482 out = func(dataset, *args, **kwargs)
    484 # Update fingerprint of in-place transforms + update in-place history of transforms
    486 if inplace:  # update after calling func so that the fingerprint doesn't change if the function fails

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:3714, in Dataset.filter(self, function, with_indices, with_rank, input_columns, batched, batch_size, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3711 if len(self) == 0:
   3712     return self
-> 3714 indices = self.map(
   3715     function=partial(
   3716         get_indices_from_mask_function,
   3717         function,
   3718         batched,
   3719         with_indices,
   3720         with_rank,
   3721         input_columns,
   3722         self._indices,
   3723     ),
   3724     with_indices=True,
   3725     with_rank=True,
   3726     features=Features({"indices": Value("uint64")}),
   3727     batched=True,
   3728     batch_size=batch_size,
   3729     remove_columns=self.column_names,
   3730     keep_in_memory=keep_in_memory,
   3731     load_from_cache_file=load_from_cache_file,
   3732     cache_file_name=cache_file_name,
   3733     writer_batch_size=writer_batch_size,
   3734     fn_kwargs=fn_kwargs,
   3735     num_proc=num_proc,
   3736     suffix_template=suffix_template,
   3737     new_fingerprint=new_fingerprint,
   3738     input_columns=input_columns,
   3739     desc=desc or "Filter",
   3740 )
   3741 new_dataset = copy.deepcopy(self)
   3742 new_dataset._indices = indices.data

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:602, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    600     self: "Dataset" = kwargs.pop("self")
    601 # apply actual function
--> 602 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    603 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    604 for dataset in datasets:
    605     # Remove task templates if a column mapping of the template is no longer valid

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:567, in transmit_format.<locals>.wrapper(*args, **kwargs)
    560 self_format = {
    561     "type": self._format_type,
    562     "format_kwargs": self._format_kwargs,
    563     "columns": self._format_columns,
    564     "output_all_columns": self._output_all_columns,
    565 }
    566 # apply actual function
--> 567 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    568 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    569 # re-apply format to the output

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:3253, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3247 logger.info(f"Spawning {num_proc} processes")
   3248 with hf_tqdm(
   3249     unit=" examples",
   3250     total=pbar_total,
   3251     desc=(desc or "Map") + f" (num_proc={num_proc})",
   3252 ) as pbar:
-> 3253     for rank, done, content in iflatmap_unordered(
   3254         pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
   3255     ):
   3256         if done:
   3257             shards_done += 1

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/utils/py_utils.py:718, in iflatmap_unordered(pool, func, kwargs_iterable)
    715 finally:
    716     if not pool_changed:
    717         # we get the result in case there's an error to raise
--> 718         [async_result.get(timeout=0.05) for async_result in async_results]

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/utils/py_utils.py:718, in <listcomp>(.0)
    715 finally:
    716     if not pool_changed:
    717         # we get the result in case there's an error to raise
--> 718         [async_result.get(timeout=0.05) for async_result in async_results]

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/multiprocess/pool.py:774, in ApplyResult.get(self, timeout)
    772     return self._value
    773 else:
--> 774     raise self._value

KeyError: 'cell_type'

Using a different state key gives the same error, so it seems that the key name in the input data at "Geneformer_outputs/240709111232/cm_classifier_test_labeled_test.dataset" differs from what is expected. I've also double-checked the rest of my code, and the key "cell_type" is used in the previous steps. Further suggesting an issue with the stored data, the .arrow file located in that directory does not appear to be a valid .arrow file. The full traceback from that is below.

---------------------------------------------------------------------------
ArrowInvalid                              Traceback (most recent call last)
Cell In[6], line 9
      7 # Open the .arrow file
      8 with pa.memory_map(file_path, 'r') as source:
----> 9     reader = ipc.RecordBatchFileReader(source)
     10     table = reader.read_all()
     12 # Convert to pandas DataFrame for easier inspection

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/ipc.py:110, in RecordBatchFileReader.__init__(self, source, footer_offset, options, memory_pool)
    107 def __init__(self, source, footer_offset=None, *, options=None,
    108              memory_pool=None):
    109     options = _ensure_default_ipc_read_options(options)
--> 110     self._open(source, footer_offset=footer_offset,
    111                options=options, memory_pool=memory_pool)

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/ipc.pxi:1085, in pyarrow.lib._RecordBatchFileReader._open()

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/error.pxi:154, in pyarrow.lib.pyarrow_internal_check_status()

File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/error.pxi:91, in pyarrow.lib.check_status()

ArrowInvalid: Not an Arrow file

Any help with solving these errors or where to look would be greatly appreciated.

Thanks

Thank you for your questions! Firstly, the datasets are in the Hugging Face datasets format - please see their documentation. To open a file:

from datasets import load_from_disk
data = load_from_disk("/path/to/.dataset")

Secondly, if you are using a .dataset file generated by the classifier module, the column used for classification is renamed to "label" for cell classification (or "labels" for gene classification labels). You can check this by loading the dataset as above. If it is now "label", you can provide that label to downstream applications (or, if you no longer intend to use the .dataset for training, you can simply rename it back to "cell_type").

ctheodoris changed discussion status to closed
