numb3r3 committed
Commit cdef62a · 1 Parent(s): 228fc3b

chore: init readme

Files changed (1): modeling_bert.py (+54, -1)
modeling_bert.py CHANGED
@@ -271,7 +271,7 @@ class JinaBertSelfAttention(nn.Module):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
-
+
         self.attn_implementation = config.attn_implementation
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)

@@ -1945,6 +1945,8 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
 
+        self._name_or_path = config._name_or_path
+
         self.bert = JinaBertModel(config)
         classifier_dropout = (
             config.classifier_dropout

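For context: _name_or_path is the attribute that PretrainedConfig.from_pretrained fills in with the Hub repo id or local path a checkpoint was loaded from; storing it on the model is what lets the compute_score method added below load a matching tokenizer lazily. A minimal sketch, with a placeholder repo id:

from transformers import AutoConfig

# _name_or_path records where the config came from (a Hub repo id or a local
# directory); "your-org/your-reranker" is a placeholder, not the actual repo.
config = AutoConfig.from_pretrained("your-org/your-reranker", trust_remote_code=True)
print(config._name_or_path)  # -> "your-org/your-reranker"
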
@@ -2042,6 +2044,57 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
             attentions=outputs.attentions,
         )
 
+    @torch.inference_mode()
+    def compute_score(
+        self,
+        sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
+        batch_size: int = 32,
+        device: Optional[torch.device] = None,
+        **tokenizer_kwargs,
+    ):
+        assert isinstance(sentence_pairs, (list, tuple))
+        if isinstance(sentence_pairs[0], str):
+            sentence_pairs = [sentence_pairs]
+
+        if not hasattr(self, 'tokenizer'):
+            from transformers import AutoTokenizer
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self._name_or_path)
+
+        is_training = self.training
+        self.eval()
+
+        if device is not None:
+            self.to(device)
+
+        all_scores = []
+        for start_index in range(
+            0, len(sentence_pairs), batch_size
+        ):
+            sentences_batch = sentence_pairs[
+                start_index : start_index + batch_size
+            ]
+            inputs = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                **tokenizer_kwargs,
+            ).to(self.device)
+
+            scores = (
+                self.forward(**inputs, return_dict=True)
+                .logits.view(
+                    -1,
+                )
+                .float()
+            )
+            all_scores.extend(scores.cpu().numpy().tolist())
+
+        self.train(is_training)
+        if len(all_scores) == 1:
+            return all_scores[0]
+        return all_scores
+
 
     @add_start_docstrings(
         """
 