ctheodoris davidjwen commited on
Commit
3a94209
1 Parent(s): 9169bfd

Fixed bug in gen_attention_mask with len > max_len (#158)

Browse files

- Fixed bug in gen_attention_mask with len > max_len (7c77bae654e0d93a27e1988e107b3258902b3d05)


Co-authored-by: David Wen <davidjwen@users.noreply.huggingface.co>

Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +4 -1
geneformer/in_silico_perturber.py CHANGED
@@ -342,7 +342,6 @@ def quant_cos_sims(model,
342
  max_range = min(i+forward_batch_size, total_batch_length)
343
 
344
  perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
345
-
346
  # determine if need to pad or truncate batch
347
  minibatch_length_set = set(perturbation_minibatch["length"])
348
  minibatch_lengths = perturbation_minibatch["length"]
@@ -354,12 +353,14 @@ def quant_cos_sims(model,
354
 
355
  if needs_pad_or_trunc == True:
356
  max_len = min(max(minibatch_length_set),model_input_size)
 
357
  def pad_or_trunc_example(example):
358
  example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
359
  pad_token_id,
360
  max_len)
361
  return example
362
  perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
 
363
  perturbation_minibatch.set_format(type="torch")
364
 
365
  input_data_minibatch = perturbation_minibatch["input_ids"]
@@ -570,6 +571,8 @@ def gen_attention_mask(minibatch_encoding, max_len = None):
570
  original_lens = minibatch_encoding["length"]
571
  attention_mask = [[1]*original_len
572
  +[0]*(max_len - original_len)
 
 
573
  for original_len in original_lens]
574
  return torch.tensor(attention_mask).to("cuda")
575
 
 
342
  max_range = min(i+forward_batch_size, total_batch_length)
343
 
344
  perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
 
345
  # determine if need to pad or truncate batch
346
  minibatch_length_set = set(perturbation_minibatch["length"])
347
  minibatch_lengths = perturbation_minibatch["length"]
 
353
 
354
  if needs_pad_or_trunc == True:
355
  max_len = min(max(minibatch_length_set),model_input_size)
356
+ print(max_len)
357
  def pad_or_trunc_example(example):
358
  example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
359
  pad_token_id,
360
  max_len)
361
  return example
362
  perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
363
+
364
  perturbation_minibatch.set_format(type="torch")
365
 
366
  input_data_minibatch = perturbation_minibatch["input_ids"]
 
571
  original_lens = minibatch_encoding["length"]
572
  attention_mask = [[1]*original_len
573
  +[0]*(max_len - original_len)
574
+ if original_len <= max_len
575
+ else [1]*max_len
576
  for original_len in original_lens]
577
  return torch.tensor(attention_mask).to("cuda")
578