hchen725 committed on
Commit fb901a0
1 Parent(s): b07f4b1

Add function for summing of Ensembl IDs

Files changed (1)
  1. geneformer/tokenizer.py +135 -4
geneformer/tokenizer.py CHANGED
@@ -36,14 +36,21 @@ Geneformer tokenizer.
 
 from __future__ import annotations
 
+import os
 import logging
 import pickle
+import sys
 import warnings
 from pathlib import Path
 from typing import Literal
+from tqdm import tqdm
+from collections import Counter
 
-import anndata as ad
 import numpy as np
+import scanpy as sc
+import loompy as lp
+import pandas as pd
+import anndata as ad
 import scipy.sparse as sp
 from datasets import Dataset
 
@@ -52,7 +59,7 @@ import loompy as lp # noqa
 
 logger = logging.getLogger(__name__)
 
-from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
+from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_MAPPING_FILE
 
 
 def rank_genes(gene_vector, gene_tokens):
@@ -74,6 +81,115 @@ def tokenize_cell(gene_vector, gene_tokens):
     # rank by median-scaled gene values
     return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
 
+def sum_ensembl_ids(data_directory,
+                    gene_mapping_dict,
+                    file_format = "loom",
+                    chunk_size = 512):
+    if file_format == "loom":
+        """
+        Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
+        """
+        with lp.connect(data_directory) as data:
+            assert "ensembl_id" in data.ra.keys(), "'ensembl_id' column missing from data.ra.keys()"
+            gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id]
+
+            if len(set(gene_ids_collapsed)) == len(set(data.ra.ensembl_id)):
+                return data_directory
+
+            else:
+                dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
+                dup_genes = [idx for idx, count in Counter(data.ra["ensembl_id"]).items() if count > 1]
+                num_chunks = int(np.ceil(data.shape[1] / chunk_size))
+                first_chunk = True
+                for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
+                    def process_chunk(view, duplic_genes):
+                        data_count_view = pd.DataFrame(view, index=data.ra["ensembl_id"])
+                        unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
+                        dup_data_df = data_count_view.loc[data_count_view.index.isin(duplic_genes)]
+                        summed_data = dup_data_df.groupby(dup_data_df.index).sum()
+                        if not summed_data.index.is_unique:
+                            raise ValueError("Error: summed data frame non-unique.")
+                        data_count_view = pd.concat([unique_data_df, summed_data], axis=0)
+                        if not data_count_view.index.is_unique:
+                            raise ValueError("Error: final data frame non-unique.")
+                        return data_count_view
+                    processed_chunk = process_chunk(view[:, :], dup_genes)
+                    processed_array = processed_chunk.to_numpy()
+                    new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
+
+                    ra_keys = [k for k in data.ra.keys() if k != "ensembl_id"]
+                    for ra_value in ra_keys:
+                        mapping_dict = dict(zip(data.ra["ensembl_id"], data.ra[ra_value]))
+                        values_new = [mapping_dict[i] for i in processed_chunk.index]
+                        new_row_attrs[ra_value] = np.array(values_new)
+
+                    if "n_counts" not in view.ca.keys():
+                        total_count_view = np.sum(view[:,:], axis=0).astype(int)
+                        view.ca["n_counts"] = total_count_view
+
+                    if first_chunk:  # Create the Loom file with the first chunk
+                        lp.create(f"{dedup_filename}", processed_array, row_attrs=new_row_attrs, col_attrs=view.ca)
+                        first_chunk = False
+                    else:  # Append subsequent chunks
+                        with lp.connect(dedup_filename, mode='r+') as dsout:
+                            dsout.add_columns(processed_array, col_attrs=view.ca)
+                return dedup_filename
+
+    elif file_format == "h5ad":
+        """
+        Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
+        Returns adata object with deduplicated Ensembl IDs.
+        """
+
+        data = sc.read_h5ad(str(data_directory))
+
+        assert "ensembl_id" in data.var.columns, "'ensembl_id' column missing from data.var"
+
+        gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id]
+
+        if len(set(gene_ids_collapsed)) == len(set(data.var.ensembl_id)):
+            return data
+
+        else:
+            data.var["gene_ids_collapsed"] = gene_ids_collapsed
+            data.var_names = gene_ids_collapsed
+            data = data[:, ~data.var.index.isna()]
+            dup_genes = [idx for idx, count in Counter(data.var_names).items() if count > 1]
+
+            num_chunks = int(np.ceil(data.shape[0] / chunk_size))
+
+            processed_genes = []
+            for i in tqdm(range(num_chunks)):
+
+                start_idx = i * chunk_size
+                end_idx = min((i + 1) * chunk_size, data.shape[0])
+                data_chunk = data[start_idx:end_idx, :]
+
+                processed_chunks = []
+                for dup_gene in dup_genes:
+                    data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
+                    df = pd.DataFrame.sparse.from_spmatrix(data_dup_gene.X,
+                                                           index=data_dup_gene.obs_names,
+                                                           columns=data_dup_gene.var_names)
+                    df_sum = pd.DataFrame(df.sum(axis=1))
+                    df_sum.columns = [dup_gene]
+                    df_sum.index = data_dup_gene.obs.index
+                    processed_chunks.append(df_sum)
+
+                processed_chunks = pd.concat(processed_chunks, axis=1)
+                processed_genes.append(processed_chunks)
+            processed_genes = pd.concat(processed_genes, axis = 0)
+            var_df = pd.DataFrame({"gene_ids_collapsed" : processed_genes.columns})
+            var_df.index = processed_genes.columns
+            processed_genes = sc.AnnData(X = processed_genes,
                                         obs = data.obs,
                                         var = var_df)
+
+            data_dedup = data[:, ~data.var.index.isin(dup_genes)]  # Deduplicated data
+            data_dedup = sc.concat([data_dedup, processed_genes], axis = 1)
+            data_dedup.obs = data.obs
+            data_dedup.var = data_dedup.var.rename(columns = {"gene_ids_collapsed" : "ensembl_id"})
+            return data_dedup
 
 class TranscriptomeTokenizer:
     def __init__(
@@ -85,6 +201,7 @@ class TranscriptomeTokenizer:
         special_token=False,
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
+        gene_mapping_file=ENSEMBL_MAPPING_FILE,
     ):
         """
         Initialize tokenizer.
@@ -103,11 +220,15 @@ class TranscriptomeTokenizer:
             | Max input size of model to truncate input to.
         special_token : bool = False
             | Adds CLS token before and EOS token after rank value encoding.
+        collapse_gene_ids : bool = False
+            | Whether to collapse gene IDs based on gene mapping dictionary.
         gene_median_file : Path
             | Path to pickle file containing dictionary of non-zero median
             | gene expression values across Genecorpus-30M.
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl IDs:token).
+        gene_mapping_file : Path
+            | Path to pickle file containing dictionary for collapsing gene IDs.
 
         """
         # dictionary of custom attributes {output dataset column name: input .loom column name}
@@ -134,6 +255,10 @@ class TranscriptomeTokenizer:
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
+        # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
+        with open(gene_mapping_file, "rb") as f:
+            self.gene_mapping_dict = pickle.load(f)
+
         # gene keys for full vocabulary
         self.gene_keys = list(self.gene_token_dict.keys())
 
@@ -214,7 +339,7 @@ class TranscriptomeTokenizer:
         return tokenized_cells, cell_metadata
 
     def tokenize_anndata(self, adata_file_path, target_sum=10_000):
-        adata = ad.read(adata_file_path, backed="r")
+        adata = sum_ensembl_ids(adata_file_path, self.gene_mapping_dict, file_format = "h5ad", chunk_size = self.chunk_size)
 
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
@@ -256,7 +381,8 @@ class TranscriptomeTokenizer:
             idx = filter_pass_loc[i : i + self.chunk_size]
 
             n_counts = adata[idx].obs["n_counts"].values[:, None]
-            X_view = adata[idx, coding_miRNA_loc].X
+            X_view0 = adata[idx,:].X
+            X_view = X_view0[:, coding_miRNA_loc]
             X_norm = X_view / n_counts * target_sum / norm_factor_vector
             X_norm = sp.csr_matrix(X_norm)
 
@@ -280,6 +406,8 @@ class TranscriptomeTokenizer:
                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
             }
 
+        loom_file_path = sum_ensembl_ids(loom_file_path, self.gene_mapping_dict, file_format = "loom", chunk_size = self.chunk_size)
+
         with lp.connect(str(loom_file_path)) as data:
             # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
             coding_miRNA_loc = np.where(
@@ -341,6 +469,9 @@ class TranscriptomeTokenizer:
         else:
             file_cell_metadata = None
 
+        if "__dedup" in str(loom_file_path):
+            os.remove(str(loom_file_path))
+
         return tokenized_cells, file_cell_metadata
 
     def create_dataset(
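
For orientation, here is a minimal, self-contained sketch of the collapsing step that both the loom and h5ad branches of `sum_ensembl_ids` perform: gene IDs are passed through the mapping dictionary, rows whose mapped IDs collide are summed, and already-unique rows pass through unchanged. The gene IDs, mapping dictionary, and counts below are toy values for illustration, not taken from the repository.

```python
from collections import Counter

import pandas as pd

# Toy mapping: two input IDs collapse to the same Ensembl ID after mapping.
gene_mapping_dict = {"ENSG_A1": "ENSG_A", "ENSG_A2": "ENSG_A", "ENSG_B": "ENSG_B"}

# Toy counts matrix: rows are genes, columns are cells (loom orientation).
counts = pd.DataFrame(
    [[1, 0, 2],   # ENSG_A1
     [3, 1, 0],   # ENSG_A2
     [5, 5, 5]],  # ENSG_B
    index=["ENSG_A1", "ENSG_A2", "ENSG_B"],
    columns=["cell0", "cell1", "cell2"],
)

# Map every gene ID through the dictionary (the commit upper-cases IDs first).
counts.index = [gene_mapping_dict.get(g.upper()) for g in counts.index]

# Sum the rows that are now duplicated; leave unique rows untouched.
dup_genes = [g for g, n in Counter(counts.index).items() if n > 1]
unique_df = counts.loc[~counts.index.isin(dup_genes)]
summed_df = counts.loc[counts.index.isin(dup_genes)].groupby(level=0).sum()
deduped = pd.concat([unique_df, summed_df], axis=0)

print(deduped)
# ENSG_B stays [5, 5, 5]; ENSG_A1 + ENSG_A2 are summed into ENSG_A = [4, 1, 2].
```

In the committed code this happens chunk by chunk (`chunk_size` columns per loom `scan` pass, or `chunk_size` cells of the AnnData at a time), so the full matrix never has to be held in memory; for the loom path, the remaining row attributes are then re-attached to the deduplicated rows.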
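From the caller's side, the only interface change is the new `gene_mapping_file` keyword, which defaults to the package's `ENSEMBL_MAPPING_FILE`, so existing tokenizer invocations keep working. A hedged sketch, assuming the package-level `TranscriptomeTokenizer` export; the custom-attribute mapping and pickle path below are placeholders:

```python
from geneformer import TranscriptomeTokenizer

# The default mapping ships with the package; pass a path only to override it.
tk = TranscriptomeTokenizer(
    custom_attr_name_dict={"cell_type": "cell_type"},  # placeholder attribute mapping
    special_token=False,
    gene_mapping_file="my_ensembl_mapping.pkl",        # placeholder custom mapping pickle
)
```

Loom inputs that contain duplicate IDs are first rewritten to a `*__dedup.loom` copy, which is deleted again at the end of `tokenize_loom`; h5ad inputs are deduplicated in memory.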