Truncate to 8k by default
#5
by
Jackmin108
- opened
- modeling_bert.py +3 -2
modeling_bert.py
CHANGED
@@ -1195,7 +1195,9 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1195 |
inverse_permutation = np.argsort(permutation)
|
1196 |
sentences = [sentences[idx] for idx in permutation]
|
1197 |
|
1198 |
-
padding = tokenizer_kwargs.
|
|
|
|
|
1199 |
|
1200 |
all_embeddings = []
|
1201 |
|
@@ -1214,7 +1216,6 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1214 |
encoded_input = self.tokenizer(
|
1215 |
sentences[i : i + batch_size],
|
1216 |
return_tensors='pt',
|
1217 |
-
padding=padding,
|
1218 |
**tokenizer_kwargs,
|
1219 |
).to(self.device)
|
1220 |
token_embs = self.forward(**encoded_input)[0]
|
|
|
1195 |
inverse_permutation = np.argsort(permutation)
|
1196 |
sentences = [sentences[idx] for idx in permutation]
|
1197 |
|
1198 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
1199 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 8192)
|
1200 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
1201 |
|
1202 |
all_embeddings = []
|
1203 |
|
|
|
1216 |
encoded_input = self.tokenizer(
|
1217 |
sentences[i : i + batch_size],
|
1218 |
return_tensors='pt',
|
|
|
1219 |
**tokenizer_kwargs,
|
1220 |
).to(self.device)
|
1221 |
token_embs = self.forward(**encoded_input)[0]
|