ricomnl committed
Commit b24676d
1 Parent(s): 5cb733f

Addressed issues for the tokenizer; the anndata tokenizer now uses a fraction of the memory.

Files changed (1):
  geneformer/tokenizer.py  +46 -30
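For context, here is a minimal usage sketch of the tokenizer this commit touches. The metadata mapping, paths, and output prefix are hypothetical example values; use_generator is the flag added in this commit, and the rest of the call follows the diff below.

from geneformer import TranscriptomeTokenizer

# map input metadata columns to output dataset columns (example names)
tk = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=4)

# tokenize every .h5ad file under data/ and write data/my_dataset.dataset;
# use_generator=False keeps the new in-memory Dataset.from_dict path
tk.tokenize_data(
    "data/",        # data_directory (hypothetical)
    "data/",        # output_directory (hypothetical)
    "my_dataset",   # output_prefix (hypothetical)
    file_format="h5ad",
    use_generator=False,
)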
geneformer/tokenizer.py CHANGED

@@ -27,6 +27,7 @@ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
 import anndata as ad
 import loompy as lp
 import numpy as np
+import scipy.sparse as sp
 from datasets import Dataset
 
 logger = logging.getLogger(__name__)
@@ -35,6 +36,15 @@ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
 TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
 
 
+def rank_genes(gene_vector, gene_tokens):
+    """
+    Rank gene expression vector.
+    """
+    # sort by median-scaled gene values
+    sorted_indices = np.argsort(-gene_vector)
+    return gene_tokens[sorted_indices]
+
+
 def tokenize_cell(gene_vector, gene_tokens):
     """
     Convert normalized gene expression vector to tokenized rank value encoding.
@@ -42,11 +52,8 @@ def tokenize_cell(gene_vector, gene_tokens):
     # create array of gene vector with token indices
     # mask undetected genes
     nonzero_mask = np.nonzero(gene_vector)[0]
-    # sort by median-scaled gene values
-    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
-    # tokenize
-    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
-    return sentence_tokens
+    # rank by median-scaled gene values
+    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
 
 
 class TranscriptomeTokenizer:
@@ -101,6 +108,7 @@ class TranscriptomeTokenizer:
         output_directory: Path | str,
         output_prefix: str,
         file_format: Literal["loom", "h5ad"] = "loom",
+        use_generator: bool = False,
     ):
         """
         Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
@@ -115,11 +123,13 @@
             Prefix for output .dataset
         file_format : str
             Format of input files. Can be "loom" or "h5ad".
+        use_generator : bool
+            Whether to use generator or dict for tokenization.
         """
         tokenized_cells, cell_metadata = self.tokenize_files(
             Path(data_directory), file_format
         )
-        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
+        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata, use_generator=use_generator)
 
         output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
         tokenized_dataset.save_to_disk(output_path)
@@ -129,7 +139,7 @@
     ):
         tokenized_cells = []
         if self.custom_attr_name_dict is not None:
-            loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
+            cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
             cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
 
         # loops through directories to tokenize .loom files
@@ -144,7 +154,7 @@
             file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
             tokenized_cells += file_tokenized_cells
             if self.custom_attr_name_dict is not None:
-                for k in loom_cell_attr:
+                for k in cell_attr:
                     cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
             else:
                 cell_metadata = None
@@ -155,8 +165,8 @@
             raise
         return tokenized_cells, cell_metadata
 
-    def tokenize_anndata(self, adata_file_path):
-        adata = ad.read(adata_file_path)
+    def tokenize_anndata(self, adata_file_path, target_sum=10_000, chunk_size=512):
+        adata = ad.read(adata_file_path, backed="r")
         file_cell_metadata = {
             attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
         }
@@ -176,7 +186,7 @@
         )
 
         try:
-            adata.obs["filter_pass"]
+            _ = adata.obs["filter_pass"]
         except KeyError:
             var_exists = False
         else:
@@ -193,24 +203,26 @@
             filter_pass_loc = np.array([i for i in range(adata.shape[0])])
 
         tokenized_cells = []
-        adata_filter = adata[
-            filter_pass_loc, coding_miRNA_loc  # filter cells and genes
-        ]
-
-        X_norm = (adata_filter.X / adata.X.sum(1) * 10_000 / norm_factor_vector).tocsr()
-
-        tokenized_cells += [
-            tokenize_cell(X_norm[i, ...].A.flatten(), coding_miRNA_tokens)
-            for i in range(X_norm.shape[0])
-        ]
-
-        # add custom attributes for subview to dict
-        for k in file_cell_metadata.keys():
-            file_cell_metadata[k] += adata_filter.obs[k].tolist()
+        for i in range(0, len(filter_pass_loc), chunk_size):
+            idx = filter_pass_loc[i:i+chunk_size]
+            X = adata[idx].X
+
+            X_norm = (X / X[:, coding_miRNA_loc].sum(axis=1) * target_sum / norm_factor_vector)
+            X_norm = sp.csr_matrix(X_norm)
+
+            tokenized_cells += [
+                rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
+                for i in range(X_norm.shape[0])
+            ]
+
+            # add custom attributes for subview to dict
+            for k in file_cell_metadata.keys():
+                file_cell_metadata[k] += adata[idx].obs[k].tolist()
 
         return tokenized_cells, file_cell_metadata
 
-    def tokenize_file(self, loom_file_path):
+    def tokenize_file(self, loom_file_path, target_sum=10_000):
         if self.custom_attr_name_dict is not None:
             file_cell_metadata = {
                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
@@ -261,7 +273,7 @@
             subview_norm_array = (
                 subview[:, :]
                 / subview.ca.n_counts
-                * 10_000
+                * target_sum
                 / norm_factor_vector[:, None]
             )
             # tokenize subview gene vectors
@@ -279,21 +291,25 @@
 
         return tokenized_cells, file_cell_metadata
 
-    def create_dataset(self, tokenized_cells, cell_metadata):
+    def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False):
+        print("Creating dataset...")
         # create dict for dataset creation
         dataset_dict = {"input_ids": tokenized_cells}
         if self.custom_attr_name_dict is not None:
             dataset_dict.update(cell_metadata)
 
         # create dataset
-        def dict_generator():
-            for i in range(len(tokenized_cells)):
-                yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
-        output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
+        if use_generator:
+            def dict_generator():
+                for i in range(len(tokenized_cells)):
+                    yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
+            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
+        else:
+            output_dataset = Dataset.from_dict(dataset_dict)
 
         # truncate dataset
         def truncate(example):
-            example["input_ids"] = example["input_ids"][0:2048]
+            example["input_ids"] = example["input_ids"][:2048]
             return example
 
         output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
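The memory savings in tokenize_anndata come from two changes visible above: the file is opened with backed="r", so each chunk of rows is read from disk on demand instead of loading the whole matrix, and each chunk is ranked directly from the CSR row's .data and .indices instead of densifying every cell vector with .A.flatten(). A toy re-creation of the per-chunk ranking step (matrix values and token ids are made up):

import numpy as np
import scipy.sparse as sp

def rank_genes(gene_vector, gene_tokens):
    # sort tokens by descending median-scaled expression, as in the diff
    return gene_tokens[np.argsort(-gene_vector)]

# one toy chunk of 2 cells x 4 genes, already normalized
X_norm = sp.csr_matrix(np.array([[0.2, 0.0, 1.5, 0.0],
                                 [0.0, 0.9, 0.0, 0.3]]))
tokens = np.array([100, 101, 102, 103])  # toy token id per gene column

# .data holds only the nonzero values and .indices their gene columns,
# so undetected genes are skipped without building a dense vector
tokenized_cells = [
    rank_genes(X_norm[i].data, tokens[X_norm[i].indices])
    for i in range(X_norm.shape[0])
]
print(tokenized_cells)  # [array([102, 100]), array([101, 103])]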
 
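create_dataset now defaults to Dataset.from_dict and keeps the generator path behind the new use_generator flag. A minimal sketch of the trade-off between the two paths, using toy input_ids (the column name matches the diff):

from datasets import Dataset

dataset_dict = {"input_ids": [[102, 100], [101, 103]]}

# default path: builds the Arrow table from an in-memory dict
ds = Dataset.from_dict(dataset_dict)

# generator path: streams one example at a time, so the full dict
# never has to be duplicated across worker processes
def dict_generator():
    for i in range(len(dataset_dict["input_ids"])):
        yield {k: dataset_dict[k][i] for k in dataset_dict}

ds_gen = Dataset.from_generator(dict_generator)

print(ds[0])      # {'input_ids': [102, 100]}
print(ds_gen[0])  # same content via the generator path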