add full support for inputs_embeds
#10
by
jxm
- opened
modeling_hf_nomic_bert.py
CHANGED
@@ -983,22 +983,21 @@ class NomicBertEmbeddings(nn.Module):
|
|
983 |
position_ids: (batch, seqlen)
|
984 |
token_type_ids: (batch, seqlen)
|
985 |
"""
|
986 |
-
batch_size, seqlen = input_ids.shape
|
987 |
-
|
988 |
if inputs_embeds is None:
|
989 |
embeddings = self.word_embeddings(input_ids)
|
990 |
else:
|
991 |
embeddings = inputs_embeds
|
992 |
-
|
|
|
993 |
if self.type_vocab_size > 0:
|
994 |
if token_type_ids is None:
|
995 |
-
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=
|
996 |
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
997 |
embeddings = embeddings + token_type_embeddings
|
998 |
|
999 |
if self.max_position_embeddings > 0:
|
1000 |
if position_ids is None:
|
1001 |
-
position_ids = torch.arange(seqlen, dtype=torch.long, device=
|
1002 |
position_embeddings = self.position_embeddings(position_ids)
|
1003 |
embeddings = embeddings + position_embeddings
|
1004 |
return embeddings
|
@@ -1688,8 +1687,6 @@ class NomicBertModel(NomicBertPreTrainedModel):
|
|
1688 |
):
|
1689 |
if input_ids is not None and inputs_embeds is not None:
|
1690 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
1691 |
-
if token_type_ids is None:
|
1692 |
-
token_type_ids = torch.zeros_like(input_ids)
|
1693 |
hidden_states = self.embeddings(
|
1694 |
input_ids=input_ids,
|
1695 |
position_ids=position_ids,
|
@@ -1699,7 +1696,7 @@ class NomicBertModel(NomicBertPreTrainedModel):
|
|
1699 |
hidden_states = self.emb_ln(hidden_states)
|
1700 |
hidden_states = self.emb_drop(hidden_states)
|
1701 |
|
1702 |
-
attention_mask = self.get_extended_attention_mask(attention_mask,
|
1703 |
sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
|
1704 |
|
1705 |
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
|
|
983 |
position_ids: (batch, seqlen)
|
984 |
token_type_ids: (batch, seqlen)
|
985 |
"""
|
|
|
|
|
986 |
if inputs_embeds is None:
|
987 |
embeddings = self.word_embeddings(input_ids)
|
988 |
else:
|
989 |
embeddings = inputs_embeds
|
990 |
+
batch_size, seqlen, _ = embeddings.shape
|
991 |
+
|
992 |
if self.type_vocab_size > 0:
|
993 |
if token_type_ids is None:
|
994 |
+
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=embeddings.device)
|
995 |
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
996 |
embeddings = embeddings + token_type_embeddings
|
997 |
|
998 |
if self.max_position_embeddings > 0:
|
999 |
if position_ids is None:
|
1000 |
+
position_ids = torch.arange(seqlen, dtype=torch.long, device=embeddings.device)
|
1001 |
position_embeddings = self.position_embeddings(position_ids)
|
1002 |
embeddings = embeddings + position_embeddings
|
1003 |
return embeddings
|
|
|
1687 |
):
|
1688 |
if input_ids is not None and inputs_embeds is not None:
|
1689 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
|
|
|
1690 |
hidden_states = self.embeddings(
|
1691 |
input_ids=input_ids,
|
1692 |
position_ids=position_ids,
|
|
|
1696 |
hidden_states = self.emb_ln(hidden_states)
|
1697 |
hidden_states = self.emb_drop(hidden_states)
|
1698 |
|
1699 |
+
attention_mask = self.get_extended_attention_mask(attention_mask, hidden_states.shape[:-1])
|
1700 |
sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
|
1701 |
|
1702 |
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|