Update geneformer/perturber_utils.py

#354
Files changed (1) hide show
  1. geneformer/perturber_utils.py +16 -7
geneformer/perturber_utils.py CHANGED
@@ -218,26 +218,35 @@ def delete_indices(example):
218
 
219
 
220
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
221
- def overexpress_indices(example):
222
  indices = example["perturb_index"]
223
  if any(isinstance(el, list) for el in indices):
224
  indices = flatten_list(indices)
225
  for index in sorted(indices, reverse=True):
226
- example["input_ids"].insert(0, example["input_ids"].pop(index))
 
 
 
227
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
231
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
 
 
 
 
 
 
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
  if len(example["input_ids"]) > max_len:
 
218
 
219
 
220
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
221
+ def overexpress_indices(example, special_token):
222
  indices = example["perturb_index"]
223
  if any(isinstance(el, list) for el in indices):
224
  indices = flatten_list(indices)
225
  for index in sorted(indices, reverse=True):
226
+ if special_token:
227
+ example["input_ids"].insert(1, example["input_ids"].pop(index))
228
+ else:
229
+ example["input_ids"].insert(0, example["input_ids"].pop(index))
230
 
231
  example["length"] = len(example["input_ids"])
232
  return example
233
 
234
 
235
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
236
+ def overexpress_tokens(example, max_len, special_token):
237
  # -100 indicates tokens to overexpress are not present in rank value encoding
238
  if example["perturb_index"] != [-100]:
239
  example = delete_indices(example)
240
+ if special_token:
241
+ [
242
+ example["input_ids"].insert(1, token)
243
+ for token in example["tokens_to_perturb"][::-1]
244
+ ]
245
+ else:
246
+ [
247
+ example["input_ids"].insert(0, token)
248
+ for token in example["tokens_to_perturb"][::-1]
249
+ ]
250
 
251
  # truncate to max input size, must also truncate original emb to be comparable
252
  if len(example["input_ids"]) > max_len: