phoebeklett committed
Commit c3edc15
Parent: 1df9b46

Upload 2 files

Files changed (1)
modeling.py +2 -1
modeling.py CHANGED
@@ -654,7 +654,7 @@ class ExtendedLlamaAttention(nn.Module):
         if not output_attentions:
             attn_weights = None
 
-        if not output_retrieved_memory_idx:
+        if not output_retrieved_memory_idx or (long_range_past_key_value is None and faiss_indexes is None):
             reshaped_idx = None
         return attn_output, attn_weights, past_key_value, reshaped_idx
 
@@ -1568,6 +1568,7 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
                 "attention_mask": attention_mask,
                 "use_external_mind": kwargs.get("use_external_mind"),  # EM: Add config here
                 "topk": kwargs.get("topk"),
+                "output_retrieved_memory_idx": kwargs.get("output_retrieved_memory_idx"),
             }
         )
         return model_inputs
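The first hunk widens the guard at the end of ExtendedLlamaAttention.forward. Previously, reshaped_idx was reset to None only when the caller did not request the retrieved-memory indices; requesting them while no external memory was attached (neither long_range_past_key_value nor faiss_indexes) could therefore surface an unset or stale value. Below is a minimal sketch of the corrected logic; resolve_reshaped_idx is a hypothetical helper name standing in for the inline code.

# Sketch only: resolve_reshaped_idx is a hypothetical name; the real logic
# runs inline at the end of ExtendedLlamaAttention.forward.
def resolve_reshaped_idx(reshaped_idx, output_retrieved_memory_idx,
                         long_range_past_key_value=None, faiss_indexes=None):
    # Indices are only meaningful when the caller asked for them AND some
    # external memory (cached long-range KV or a FAISS index) exists.
    if not output_retrieved_memory_idx or (
        long_range_past_key_value is None and faiss_indexes is None
    ):
        reshaped_idx = None
    return reshaped_idx

# With no external memory attached, indices are suppressed even if requested.
assert resolve_reshaped_idx("stale", output_retrieved_memory_idx=True) is None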
 
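The second hunk threads output_retrieved_memory_idx through prepare_inputs_for_generation alongside the existing use_external_mind and topk keys, so a flag passed to generate() reaches forward() on every decoding step instead of being dropped. A hedged usage sketch follows, assuming a placeholder checkpoint id and that the custom ExtendedLlama* classes are loaded via trust_remote_code, with generate() forwarding the extra keyword arguments as model kwargs:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "your-org/extended-mind-llama"  # placeholder, not the real repo id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("What did the memos say about Q3?", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=32,
    use_external_mind=True,            # EM flag already threaded through above
    topk=2,                            # retrieval breadth; exact semantics per modeling.py
    output_retrieved_memory_idx=True,  # reaches the attention layers after this commit
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note the design choice: kwargs.get keeps the key present in model_inputs (with value None) even when the caller omits the flag, mirroring how use_external_mind and topk are already forwarded, so downstream code can rely on the key existing.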