Files changed (4)
  1. README.md +3 -6
  2. config.json +2 -2
  3. configuration_minicpm.py +0 -2
  4. modeling_minicpm.py +21 -42
README.md CHANGED
@@ -344,6 +344,7 @@ When running evaluation on BEIR and C-MTEB/Retrieval, we use instructions in `in
 
 ```
 transformers==4.37.2
+flash-attn>2.3.5
 ```
 
 ### 示例脚本 Demo
@@ -357,9 +358,7 @@ import torch.nn.functional as F
 
 model_name = "openbmb/MiniCPM-Embedding"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16).to("cuda")
-# You can also use the following line to enable the Flash Attention 2 implementation
-# model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
+model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
 model.eval()
 
 # 由于在 `model.forward` 中缩放了最终隐层表示,此处的 mean pooling 实际上起到了 weighted mean pooling 的作用
@@ -403,9 +402,7 @@ import torch
 from sentence_transformers import SentenceTransformer
 
 model_name = "openbmb/MiniCPM-Embedding"
-model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={ "torch_dtype": torch.float16})
-# You can also use the following line to enable the Flash Attention 2 implementation
-# model = SentenceTransformer(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", model_kwargs={ "torch_dtype": torch.float16})
+model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": torch.float16})
 
 queries = ["中国的首都是哪里?"]
 passages = ["beijing", "shanghai"]
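For reviewers trying the updated README path end to end, a minimal sketch of the Transformers demo after this change is shown below. It assumes flash-attn is installed and a CUDA device is available, omits the retrieval instruction prefix the README describes for BEIR/C-MTEB evaluation, and the `encode` helper is illustrative rather than copied from the README. As the README comment notes, because `model.forward` scales the final hidden states, the plain masked mean pooling here effectively performs weighted mean pooling.

```python
# Illustrative sketch only; assumes flash-attn>2.3.5 and a CUDA device.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_name = "openbmb/MiniCPM-Embedding"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # now the default path shown in the README
    torch_dtype=torch.float16,
).to("cuda")
model.eval()


@torch.no_grad()
def encode(texts):
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to("cuda")
    outputs = model(**batch)
    # model.forward already rescales the final hidden states, so a plain
    # mask-averaged pooling here behaves like weighted mean pooling.
    mask = batch["attention_mask"].unsqueeze(-1).float()
    embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(embeddings, p=2, dim=1)


query_embs = encode(["中国的首都是哪里?"])
passage_embs = encode(["beijing", "shanghai"])
print(query_embs @ passage_embs.T)  # cosine similarities after normalization
```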
config.json CHANGED
@@ -17,7 +17,7 @@
   "initializer_range": 0.1,
   "intermediate_size": 5760,
   "is_causal": false,
-  "max_position_embeddings": 512,
+  "max_position_embeddings": 2048,
   "num_attention_heads": 36,
   "num_hidden_layers": 40,
   "num_key_value_heads": 36,
@@ -25,7 +25,7 @@
   "rope_scaling": null,
   "torch_dtype": "bfloat16",
   "transformers_version": "4.36.0",
-  "use_cache": false,
+  "use_cache": true,
   "vocab_size": 122753,
   "scale_emb": 12,
   "dim_model_base": 256,
configuration_minicpm.py CHANGED
@@ -140,7 +140,6 @@ class MiniCPMConfig(PretrainedConfig):
         dim_model_base=1,
         scale_depth=1,
         is_causal=True,
-        adapt_mean_pooling=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -169,7 +168,6 @@ class MiniCPMConfig(PretrainedConfig):
         self.dim_model_base = dim_model_base
         self.scale_depth = scale_depth
         self.is_causal = is_causal
-        self.adapt_mean_pooling = adapt_mean_pooling
 
         super().__init__(
             pad_token_id=pad_token_id,
modeling_minicpm.py CHANGED
@@ -21,16 +21,12 @@
 import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict
-import os
-from tqdm import tqdm
+
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-import numpy as np
-from copy import deepcopy
-from transformers import AutoTokenizer
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -39,7 +35,6 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
-    _prepare_4d_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -325,6 +320,9 @@ class MiniCPMAttention(nn.Module):
         self.rope_theta = config.rope_theta
 
         self.is_causal = config.is_causal
+
+        logger.info(f"self.is_causal = {self.is_causal}")
+
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -981,8 +979,6 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
-        self.is_causal = config.is_causal
-        self.adapt_mean_pooling = config.adapt_mean_pooling
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1004,7 +1000,6 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        adapt_mean_pooling: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1049,35 +1044,24 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb
 
         _attention_mask = attention_mask
+
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         elif self._use_sdpa and not output_attentions:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-            if self.is_causal:
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            else:
-                attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    attention_mask,
-                    inputs_embeds.dtype,
-                )
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
         else:
             # 4d mask is passed through the layers
-            if self.is_causal:
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-                )
-            else:
-                attention_mask = _prepare_4d_attention_mask(
-                    attention_mask,
-                    inputs_embeds.dtype,
-                )
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1125,18 +1109,14 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-
         # gen weight before mean pooling
-        if adapt_mean_pooling is None:
-            adapt_mean_pooling = self.adapt_mean_pooling
-        if adapt_mean_pooling:
-            attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
-            s = hidden_states * attention_mask_.unsqueeze(-1).float()
-            d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() /_attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
-
-            hidden_states = s / d
+        attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
+        s = hidden_states * attention_mask_.unsqueeze(-1).float()
+        d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() / _attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
 
+        hidden_states = s / d
+
+        next_cache = None
         if use_cache:
             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
@@ -1147,8 +1127,7 @@
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
-
+
 
 class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
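To make the modeling change concrete: `MiniCPMModel.forward` now always rescales the final hidden states (previously gated by `adapt_mean_pooling`) so that a plain masked mean over `last_hidden_state` equals a position-weighted mean of the original states, with token weights growing linearly with position inside the non-padded span. A self-contained toy check of that equivalence, with variable names mirroring the diff:

```python
# Toy check of the position-weighted pooling now applied unconditionally in forward.
import torch

batch_size, seq_len, hidden_size = 2, 5, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
_attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 1]])

# What forward now does before returning last_hidden_state:
attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)  # weights 1, 2, 3, ... per real token
s = hidden_states * attention_mask_.unsqueeze(-1).float()
d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() / _attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
scaled = s / d

# A plain masked mean over the scaled states ...
mask = _attention_mask.unsqueeze(-1).float()
plain_mean_of_scaled = (scaled * mask).sum(dim=1) / mask.sum(dim=1)

# ... equals a weighted mean of the original states (weights = token positions).
weighted_mean = (
    (hidden_states * attention_mask_.unsqueeze(-1).float()).sum(dim=1)
    / attention_mask_.sum(dim=1, keepdim=True).float()
)

print(torch.allclose(plain_mean_of_scaled, weighted_mean, atol=1e-6))  # True
```

This equivalence is why the README demos, and the removal of the `adapt_mean_pooling` switch, can rely on ordinary mean pooling downstream.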