Error Generating State Embeddings Dictionary
Hello,
I am having an issue with the last part of the embedding step. I keep getting a KeyError on the state_key when running the embex.get_state_embs() function. The code block and the error output are below.
cell_states_to_model={"state_key": "cell_type",
"start_state": "P",
"goal_state": "B",
"alt_states": ["A", "Other"]}
state_embs_dict = embex.get_state_embs(cell_states_to_model,
"Geneformer/geneformer-12L-30M",
"Geneformer_outputs/240709111232/cm_classifier_test_labeled_test.dataset",
output_dir,
output_prefix)
Here is the full traceback output from this code block:
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/multiprocess/pool.py", line 125, in worker
result = (True, func(*args, **kwds))
File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 678, in _write_generator_to_queue
for i, result in enumerate(func(**kwargs)):
File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3552, in _map_single
batch = apply_function_on_filtered_inputs(
File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3421, in apply_function_on_filtered_inputs
processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 6486, in get_indices_from_mask_function
mask.append(function(example, *additional_args, **fn_kwargs))
File "/mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/perturber_utils.py", line 43, in filter_data_by_criteria
return example[key] in value
KeyError: 'cell_type'
"""
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
Cell In[80], line 1
----> 1 state_embs_dict = embex.get_state_embs(cell_states_to_model,
2 "Geneformer/geneformer-12L-30M",
3 "Geneformer_outputs/240709111232/cm_classifier_test_labeled_test.dataset",
4 output_dir,
5 output_prefix)
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/emb_extractor.py:679, in EmbExtractor.get_state_embs(self, cell_states_to_model, model_directory, input_data_file, output_directory, output_prefix, output_torch_embs)
677 continue
678 elif (k == "start_state") or (k == "goal_state"):
--> 679 state_embs_dict[v] = self.extract_embs(
680 model_directory,
681 input_data_file,
682 output_directory,
683 output_prefix,
684 output_torch_embs,
685 cell_state={state_key: v},
686 )
687 else: # k == "alt_states"
688 for alt_state in v:
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/emb_extractor.py:558, in EmbExtractor.extract_embs(self, model_directory, input_data_file, output_directory, output_prefix, output_torch_embs, cell_state)
554 filtered_input_data = pu.load_and_filter(
555 self.filter_data, self.nproc, input_data_file
556 )
557 if cell_state is not None:
--> 558 filtered_input_data = pu.filter_by_dict(
559 filtered_input_data, cell_state, self.nproc
560 )
561 downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
562 model = pu.load_model(
563 self.model_type, self.num_classes, model_directory, mode="eval"
564 )
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/geneformer/perturber_utils.py:45, in filter_by_dict(data, filter_data, nproc)
42 def filter_data_by_criteria(example):
43 return example[key] in value
---> 45 data = data.filter(filter_data_by_criteria, num_proc=nproc)
46 if len(data) == 0:
47 logger.error("No cells remain after filtering. Check filtering criteria.")
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:567, in transmit_format.<locals>.wrapper(*args, **kwargs)
560 self_format = {
561 "type": self._format_type,
562 "format_kwargs": self._format_kwargs,
563 "columns": self._format_columns,
564 "output_all_columns": self._output_all_columns,
565 }
566 # apply actual function
--> 567 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
568 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
569 # re-apply format to the output
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/fingerprint.py:482, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
478 validate_fingerprint(kwargs[fingerprint_name])
480 # Call actual function
--> 482 out = func(dataset, *args, **kwargs)
484 # Update fingerprint of in-place transforms + update in-place history of transforms
486 if inplace: # update after calling func so that the fingerprint doesn't change if the function fails
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:3714, in Dataset.filter(self, function, with_indices, with_rank, input_columns, batched, batch_size, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3711 if len(self) == 0:
3712 return self
-> 3714 indices = self.map(
3715 function=partial(
3716 get_indices_from_mask_function,
3717 function,
3718 batched,
3719 with_indices,
3720 with_rank,
3721 input_columns,
3722 self._indices,
3723 ),
3724 with_indices=True,
3725 with_rank=True,
3726 features=Features({"indices": Value("uint64")}),
3727 batched=True,
3728 batch_size=batch_size,
3729 remove_columns=self.column_names,
3730 keep_in_memory=keep_in_memory,
3731 load_from_cache_file=load_from_cache_file,
3732 cache_file_name=cache_file_name,
3733 writer_batch_size=writer_batch_size,
3734 fn_kwargs=fn_kwargs,
3735 num_proc=num_proc,
3736 suffix_template=suffix_template,
3737 new_fingerprint=new_fingerprint,
3738 input_columns=input_columns,
3739 desc=desc or "Filter",
3740 )
3741 new_dataset = copy.deepcopy(self)
3742 new_dataset._indices = indices.data
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:602, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
600 self: "Dataset" = kwargs.pop("self")
601 # apply actual function
--> 602 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
603 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
604 for dataset in datasets:
605 # Remove task templates if a column mapping of the template is no longer valid
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:567, in transmit_format.<locals>.wrapper(*args, **kwargs)
560 self_format = {
561 "type": self._format_type,
562 "format_kwargs": self._format_kwargs,
563 "columns": self._format_columns,
564 "output_all_columns": self._output_all_columns,
565 }
566 # apply actual function
--> 567 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
568 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
569 # re-apply format to the output
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/arrow_dataset.py:3253, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3247 logger.info(f"Spawning {num_proc} processes")
3248 with hf_tqdm(
3249 unit=" examples",
3250 total=pbar_total,
3251 desc=(desc or "Map") + f" (num_proc={num_proc})",
3252 ) as pbar:
-> 3253 for rank, done, content in iflatmap_unordered(
3254 pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
3255 ):
3256 if done:
3257 shards_done += 1
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/utils/py_utils.py:718, in iflatmap_unordered(pool, func, kwargs_iterable)
715 finally:
716 if not pool_changed:
717 # we get the result in case there's an error to raise
--> 718 [async_result.get(timeout=0.05) for async_result in async_results]
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/datasets/utils/py_utils.py:718, in <listcomp>(.0)
715 finally:
716 if not pool_changed:
717 # we get the result in case there's an error to raise
--> 718 [async_result.get(timeout=0.05) for async_result in async_results]
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/multiprocess/pool.py:774, in ApplyResult.get(self, timeout)
772 return self._value
773 else:
--> 774 raise self._value
KeyError: 'cell_type'
Using a different state key gives the same error, so it seems that the column name in the input data at "Geneformer_outputs/240709111232/cm_classifier_test_labeled_test.dataset" is different from what's expected. I've also double-checked the rest of my code, and the key "cell_type" is used in the previous steps. A further indication that there might be an issue with the stored data is that the .arrow file located in that directory does not appear to be readable as a standalone Arrow IPC file. The full traceback from that attempt is below.
---------------------------------------------------------------------------
ArrowInvalid Traceback (most recent call last)
Cell In[6], line 9
7 # Open the .arrow file
8 with pa.memory_map(file_path, 'r') as source:
----> 9 reader = ipc.RecordBatchFileReader(source)
10 table = reader.read_all()
12 # Convert to pandas DataFrame for easier inspection
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/ipc.py:110, in RecordBatchFileReader.__init__(self, source, footer_offset, options, memory_pool)
107 def __init__(self, source, footer_offset=None, *, options=None,
108 memory_pool=None):
109 options = _ensure_default_ipc_read_options(options)
--> 110 self._open(source, footer_offset=footer_offset,
111 options=options, memory_pool=memory_pool)
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/ipc.pxi:1085, in pyarrow.lib._RecordBatchFileReader._open()
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/error.pxi:154, in pyarrow.lib.pyarrow_internal_check_status()
File /mnt/scratch/Public/caleb/Miniconda3/miniconda3/envs/Geneformer5/lib/python3.10/site-packages/pyarrow/error.pxi:91, in pyarrow.lib.check_status()
ArrowInvalid: Not an Arrow file
Any help with solving these errors or where to look would be greatly appreciated.
Thanks
Thank you for your questions! Firstly, the datasets are in the Hugging Face Datasets format (which is why reading the .arrow file directly with pyarrow's RecordBatchFileReader fails) - please see the Hugging Face Datasets documentation. To open a file:
from datasets import load_from_disk
data = load_from_disk("/path/to/.dataset")
Secondly, if you are using a .dataset file generated by the classifier module, the column used for classification will have been renamed to "label" for cell classification (or "labels" for gene classification). You can check this by loading the dataset as above and inspecting its column names. If the column is now "label", you can provide that name as the state_key to downstream applications — or, if you no longer intend to use the .dataset for training, you can simply rename the column back to "cell_type".