ricomnl commited on
Commit
94e8d23
1 Parent(s): b24676d

Fixed issues

Browse files
Files changed (1) hide show
  1. geneformer/tokenizer.py +12 -6
geneformer/tokenizer.py CHANGED
@@ -167,9 +167,11 @@ class TranscriptomeTokenizer:
167
 
168
  def tokenize_anndata(self, adata_file_path, target_sum=10_000, chunk_size=512):
169
  adata = ad.read(adata_file_path, backed="r")
170
- file_cell_metadata = {
171
- attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
172
- }
 
 
173
 
174
  coding_miRNA_loc = np.where(
175
  [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
@@ -208,7 +210,8 @@ class TranscriptomeTokenizer:
208
  idx = filter_pass_loc[i:i+chunk_size]
209
  X = adata[idx].X
210
 
211
- X_norm = (X / X[:, coding_miRNA_loc].sum(axis=1) * target_sum / norm_factor_vector)
 
212
  X_norm = sp.csr_matrix(X_norm)
213
 
214
  tokenized_cells += [
@@ -217,8 +220,11 @@ class TranscriptomeTokenizer:
217
  ]
218
 
219
  # add custom attributes for subview to dict
220
- for k in file_cell_metadata.keys():
221
- file_cell_metadata[k] += adata[idx].obs[k].tolist()
 
 
 
222
 
223
  return tokenized_cells, file_cell_metadata
224
 
 
167
 
168
  def tokenize_anndata(self, adata_file_path, target_sum=10_000, chunk_size=512):
169
  adata = ad.read(adata_file_path, backed="r")
170
+
171
+ if self.custom_attr_name_dict is not None:
172
+ file_cell_metadata = {
173
+ attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
174
+ }
175
 
176
  coding_miRNA_loc = np.where(
177
  [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
 
210
  idx = filter_pass_loc[i:i+chunk_size]
211
  X = adata[idx].X
212
 
213
+ X_view = X[:, coding_miRNA_loc]
214
+ X_norm = (X_view / X_view.sum(axis=1) * target_sum / norm_factor_vector)
215
  X_norm = sp.csr_matrix(X_norm)
216
 
217
  tokenized_cells += [
 
220
  ]
221
 
222
  # add custom attributes for subview to dict
223
+ if self.custom_attr_name_dict is not None:
224
+ for k in file_cell_metadata.keys():
225
+ file_cell_metadata[k] += adata[idx].obs[k].tolist()
226
+ else:
227
+ file_cell_metadata = None
228
 
229
  return tokenized_cells, file_cell_metadata
230