Rename precious3_gpt_multi_modalX.py to precious3_gpt_multi_modal.py
precious3_gpt_multi_modalX.py → precious3_gpt_multi_modal.py
RENAMED
@@ -13,12 +13,12 @@ from transformers import PreTrainedTokenizerFast
 import os
 import torch.nn.functional as F
 
-from modeling_mpt import MPTModel, MPTForCausalLM, gen_attention_mask_in_length
-from configuration_mpt import MPTConfig
-from blocks import MPTBlock
-from norm import NORM_CLASS_REGISTRY
-from custom_embedding import SharedEmbedding
-from attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
+from mpt_7b.modeling_mpt import MPTModel, MPTForCausalLM, gen_attention_mask_in_length
+from mpt_7b.configuration_mpt import MPTConfig
+from mpt_7b.blocks import MPTBlock
+from mpt_7b.norm import NORM_CLASS_REGISTRY
+from mpt_7b.custom_embedding import SharedEmbedding
+from mpt_7b.attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
 
 import logging
 log = logging.getLogger(__name__)
@@ -85,10 +85,10 @@ class Custom_MptModel(MPTModel): # MptModel
 
 
         self.modality2_embedding_projection = nn.ModuleList([nn.Linear(modality2_dim, config.d_model),
-                                             #
+                                             # nn.BatchNorm1d(config.d_model),
                                              nn.ReLU(),
                                              nn.Linear(config.d_model, config.d_model),
-                                             #
+                                             # nn.BatchNorm1d(config.d_model),
                                              nn.ReLU(),
                                              nn.Linear(config.d_model, config.d_model)])# nn.Linear(modality0_dim, self.hidden_size)
 
@@ -351,4 +351,4 @@ class Custom_MPTForCausalLM(MPTForCausalLM):
         _labels = torch.roll(labels, shifts=-1)
         _labels[:, -1] = -100
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
-        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
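For context, the second hunk only touches comments inside a small projection MLP that maps an external modality-2 embedding of size modality2_dim into the model's hidden size config.d_model, with the BatchNorm1d layers left commented out. A minimal standalone sketch of the same shape transformation (not the repository's code: it uses nn.Sequential instead of the nn.ModuleList above, and placeholder dimensions) looks like this:

# Standalone sketch, assuming placeholder sizes for d_model and modality2_dim.
import torch
import torch.nn as nn

d_model, modality2_dim = 2048, 1536  # assumed sizes for illustration

modality2_projection = nn.Sequential(
    nn.Linear(modality2_dim, d_model),
    # nn.BatchNorm1d(d_model),  # kept commented out, as in the diff
    nn.ReLU(),
    nn.Linear(d_model, d_model),
    # nn.BatchNorm1d(d_model),
    nn.ReLU(),
    nn.Linear(d_model, d_model),
)

x = torch.randn(4, modality2_dim)      # a batch of modality-2 embeddings
print(modality2_projection(x).shape)   # torch.Size([4, 2048])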
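The last hunk leaves the causal-LM loss itself unchanged: labels are rolled left by one position so each position holds the next token, the final column (which would wrap around) is masked with -100, and logits and labels are flattened before F.cross_entropy, which skips -100 targets by default. A self-contained sketch with assumed shapes:

# Standalone sketch of the shifted-label cross-entropy shown above.
# Batch, sequence, and vocabulary sizes are assumptions for illustration only.
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 50432
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

_labels = torch.roll(labels, shifts=-1)   # each position now holds the next token
_labels[:, -1] = -100                     # mask the wrapped position; -100 is cross_entropy's default ignore_index
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                       _labels.to(logits.device).view(-1))
print(loss.item())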