Update modeling_mpt.py
Browse files- modeling_mpt.py +1 -1
modeling_mpt.py
CHANGED
@@ -181,7 +181,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
181 |
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
182 |
assert isinstance(self.emb_drop, nn.Module)
|
183 |
x = self.emb_drop(x_shrunk)
|
184 |
-
(attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=
|
185 |
if use_cache and past_key_values is None:
|
186 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
187 |
all_hidden_states = () if output_hidden_states else None
|
|
|
181 |
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
182 |
assert isinstance(self.emb_drop, nn.Module)
|
183 |
x = self.emb_drop(x_shrunk)
|
184 |
+
(attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
|
185 |
if use_cache and past_key_values is None:
|
186 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
187 |
all_hidden_states = () if output_hidden_states else None
|