tiendung commited on
Commit
d3b32d2
1 Parent(s): 1824b52

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: text-classification
3
+ tags:
4
+ - transformers
5
+ - reranker
6
+ - cross-encoder
7
+ - transformers.js
8
+ language:
9
+ - multilingual
10
+ inference: false
11
+ license: cc-by-nc-4.0
12
+ library_name: transformers
13
+ ---
14
+
15
+ <br><br>
16
+
17
+ <p align="center">
18
+ <img src="https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/603763514de52ff951d89793/AFoybzd5lpBQXEBrQHuTt.png?w=200&h=200&f=face" alt="Finetuner logo: Finetuner helps you to create experiments in order to improve embeddings on search tasks. It accompanies you to deliver the last mile of performance-tuning for neural search applications." width="150px">
19
+ </p>
20
+
21
+ <p align="center">
22
+ <b>Trained by <a href="https://jina.ai/"><b>Jina AI</b></a>.</b>
23
+ </p>
24
+
25
+ # jina-reranker-v2-base-multilingual
26
+
27
+ ## Intended Usage & Model Info
28
+
29
+ The **Jina Reranker v2** (`jina-reranker-v2-base-multilingual`) is a transformer-based model that has been fine-tuned for text reranking task, which is a crucial component in many information retrieval systems. It is a cross-encoder model that takes a query and a document pair as input and outputs a score indicating the relevance of the document to the query. The model is trained on a large dataset of query-document pairs and is capable of reranking documents in multiple languages with high accuracy.
30
+
31
+ Compared with the state-of-the-art reranker models, including the previous released `jina-reranker-v1-base-en`, the **Jina Reranker v2** model has demonstrated competitiveness across a series of benchmarks targeting for text retrieval, multilingual capability, function-calling-aware and text-to-SQL-aware reranking, and code retrieval tasks.
32
+
33
+ The `jina-reranker-v2-base-multilingual` model is capable of handling long texts with a context length of up to `1024` tokens, enabling the processing of extensive inputs. To enable the model to handle long texts that exceed 1024 tokens, the model uses a sliding window approach to chunk the input text into smaller pieces and rerank each chunk separately.
34
+
35
+ The model is also equipped with a flash attention mechanism, which significantly improves the model's performance.
36
+
37
+
38
+ # Usage
39
+
40
+ _This model repository is licenced for research and evaluation purposes under CC-BY-NC-4.0. For commercial usage, please refer to Jina AI's APIs, AWS Sagemaker or Azure Marketplace offerings. Please [contact us](https://jina.ai/contact-sales) for any further clarifications._
41
+ 1. The easiest way to use `jina-reranker-v2-base-multilingual` is to call Jina AI's [Reranker API](https://jina.ai/reranker/).
42
+
43
+ ```bash
44
+ curl https://api.jina.ai/v1/rerank \
45
+ -H "Content-Type: application/json" \
46
+ -H "Authorization: Bearer YOUR_API_KEY" \
47
+ -d '{
48
+ "model": "jina-reranker-v2-base-multilingual",
49
+ "query": "Organic skincare products for sensitive skin",
50
+ "documents": [
51
+ "Organic skincare for sensitive skin with aloe vera and chamomile.",
52
+ "New makeup trends focus on bold colors and innovative techniques",
53
+ "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
54
+ "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
55
+ "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
56
+ "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
57
+ "针对敏感肌专门设计的天然有机护肤产品",
58
+ "新的化妆趋势注重鲜艳的颜色和创新的技巧",
59
+ "敏感肌のために特別に設計された天然有機スキンケア製品",
60
+ "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています"
61
+ ],
62
+ "top_n": 3
63
+ }'
64
+ ```
65
+
66
+ 2. You can also use the `transformers` library to interact with the model programmatically.
67
+
68
+ Before you start, install the `transformers` and `einops` libraries:
69
+
70
+ ```bash
71
+ pip install transformers einops
72
+ ```
73
+
74
+ And then:
75
+ ```python
76
+ from transformers import AutoModelForSequenceClassification
77
+
78
+ model = AutoModelForSequenceClassification.from_pretrained(
79
+ 'jinaai/jina-reranker-v2-base-multilingual',
80
+ torch_dtype="auto",
81
+ trust_remote_code=True,
82
+ )
83
+
84
+ model.to('cuda') # or 'cpu' if no GPU is available
85
+ model.eval()
86
+
87
+ # Example query and documents
88
+ query = "Organic skincare products for sensitive skin"
89
+ documents = [
90
+ "Organic skincare for sensitive skin with aloe vera and chamomile.",
91
+ "New makeup trends focus on bold colors and innovative techniques",
92
+ "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
93
+ "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
94
+ "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
95
+ "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
96
+ "针对敏感肌专门设计的天然有机护肤产品",
97
+ "新的化妆趋势注重鲜艳的颜色和创新的技巧",
98
+ "敏感肌のために特別に設計された天然有機スキンケア製品",
99
+ "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
100
+ ]
101
+
102
+ # construct sentence pairs
103
+ sentence_pairs = [[query, doc] for doc in documents]
104
+
105
+ scores = model.compute_score(sentence_pairs, max_length=1024)
106
+ ```
107
+
108
+ The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance.
109
+ For instance the returning scores in this case will be:
110
+ ```bash
111
+ [0.8311430811882019, 0.09401018172502518,
112
+ 0.6334102749824524, 0.08269733935594559,
113
+ 0.7620701193809509, 0.09947021305561066,
114
+ 0.9263036847114563, 0.05834583938121796,
115
+ 0.8418256044387817, 0.11124119907617569]
116
+ ```
117
+
118
+ The model gives high relevance scores to the documents that are most relevant to the query regardless of the language of the document.
119
+
120
+ Note that by default, the `jina-reranker-v2-base-multilingual` model uses [flash attention](https://github.com/Dao-AILab/flash-attention), which requires certain types of GPU hardware to run.
121
+ If you encounter any issues, you can try call `AutoModelForSequenceClassification.from_pretrained()` with `use_flash_attn=False`.
122
+ This will use the standard attention mechanism instead of flash attention.
123
+
124
+ If you want to use flash attention for fast inference, you need to install the following packages:
125
+ ```bash
126
+ pip install ninja # required for flash attention
127
+ pip install flash-attn --no-build-isolation
128
+ ```
129
+ Enjoy the 3x-6x speedup with flash attention! ⚡️⚡️⚡️
130
+
131
+
132
+ 3. You can also use the `transformers.js` library to run the model directly in JavaScript (in-browser, Node.js, Deno, etc.)!
133
+
134
+ If you haven't already, you can install the [Transformers.js](https://huggingface.co/docs/transformers.js) JavaScript library (v3) using:
135
+ ```bash
136
+ npm i xenova/transformers.js#v3
137
+ ```
138
+
139
+ Then, you can use the following code to interact with the model:
140
+ ```js
141
+ import { AutoTokenizer, XLMRobertaModel } from '@xenova/transformers';
142
+
143
+ const model_id = 'jinaai/jina-reranker-v2-base-multilingual';
144
+ const model = await XLMRobertaModel.from_pretrained(model_id, { dtype: 'fp32' });
145
+ const tokenizer = await AutoTokenizer.from_pretrained(model_id);
146
+
147
+ /**
148
+ * Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
149
+ * @param {string} query A single query
150
+ * @param {string[]} documents A list of documents
151
+ * @param {Object} options Options for ranking
152
+ * @param {number} [options.top_k=undefined] Return the top-k documents. If undefined, all documents are returned.
153
+ * @param {number} [options.return_documents=false] If true, also returns the documents. If false, only returns the indices and scores.
154
+ */
155
+ async function rank(query, documents, {
156
+ top_k = undefined,
157
+ return_documents = false,
158
+ } = {}) {
159
+ const inputs = tokenizer(
160
+ new Array(documents.length).fill(query),
161
+ { text_pair: documents, padding: true, truncation: true }
162
+ )
163
+ const { logits } = await model(inputs);
164
+ return logits.sigmoid().tolist()
165
+ .map(([score], i) => ({
166
+ corpus_id: i,
167
+ score,
168
+ ...(return_documents ? { text: documents[i] } : {})
169
+ })).sort((a, b) => b.score - a.score).slice(0, top_k);
170
+ }
171
+
172
+ // Example usage:
173
+ const query = "Organic skincare products for sensitive skin"
174
+ const documents = [
175
+ "Organic skincare for sensitive skin with aloe vera and chamomile.",
176
+ "New makeup trends focus on bold colors and innovative techniques",
177
+ "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
178
+ "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
179
+ "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
180
+ "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
181
+ "针对敏感肌专门设计的天然有机护肤产品",
182
+ "新的化妆趋势注重鲜艳的颜色和创新的技巧",
183
+ "敏感肌のために特別に設計された天然有機スキンケア製品",
184
+ "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
185
+ ]
186
+
187
+ const results = await rank(query, documents, { return_documents: true, top_k: 3 });
188
+ console.log(results);
189
+ ```
190
+
191
+
192
+ That's it! You can now use the `jina-reranker-v2-base-multilingual` model in your projects.
193
+
194
+
195
+ In addition to the `compute_score()` function, the `jina-reranker-v2-base-multilingual` model also provides a `model.rerank()` function that can be used to rerank documents based on a query. You can use it as follows:
196
+
197
+ ```python
198
+ result = model.rerank(
199
+ query,
200
+ documents,
201
+ max_query_length=512,
202
+ max_length=1024,
203
+ top_n=3
204
+ )
205
+ ```
206
+
207
+ Inside the `result` object, you will find the reranked documents along with their scores. You can use this information to further process the documents as needed.
208
+
209
+ The `rerank()` function will automatically chunk the input documents into smaller pieces if they exceed the model's maximum input length. This allows you to rerank long documents without running into memory issues.
210
+ Specifically, the `rerank()` function will split the documents into chunks of size `max_length` and rerank each chunk separately. The scores from all the chunks are then combined to produce the final reranking results. You can control the query length and document length in each chunk by setting the `max_query_length` and `max_length` parameters. The `rerank()` function also supports the `overlap` parameter (default is `80`) which determines how much overlap there is between adjacent chunks. This can be useful when reranking long documents to ensure that the model has enough context to make accurate predictions.
211
+
212
+ 3. Alternatively, `jina-reranker-v2-base-multilingual` has been integrated with `CrossEncoder` from the `sentence-transformers` library.
213
+
214
+ Before you start, install the `sentence-transformers` libraries:
215
+
216
+ ```bash
217
+ pip install sentence-transformers
218
+ ```
219
+
220
+ The [`CrossEncoder`](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html) class supports a [`predict`](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html#sentence_transformers.cross_encoder.CrossEncoder.predict) method to get query-document relevance scores, and a [`rank`](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html#sentence_transformers.cross_encoder.CrossEncoder.rank) method to rank all documents given your query.
221
+
222
+ ```python
223
+ from sentence_transformers import CrossEncoder
224
+
225
+ model = CrossEncoder(
226
+ "jinaai/jina-reranker-v2-base-multilingual",
227
+ automodel_args={"torch_dtype": "auto"},
228
+ trust_remote_code=True,
229
+ )
230
+
231
+ # Example query and documents
232
+ query = "Organic skincare products for sensitive skin"
233
+ documents = [
234
+ "Organic skincare for sensitive skin with aloe vera and chamomile.",
235
+ "New makeup trends focus on bold colors and innovative techniques",
236
+ "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
237
+ "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
238
+ "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
239
+ "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
240
+ "针对敏感肌专门设计的天然有机护肤产品",
241
+ "新的化妆趋势注重鲜艳的颜色和创新的技巧",
242
+ "敏感肌のために特別に設計された天然有機スキンケア製品",
243
+ "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
244
+ ]
245
+
246
+ # construct sentence pairs
247
+ sentence_pairs = [[query, doc] for doc in documents]
248
+
249
+ scores = model.predict(sentence_pairs, convert_to_tensor=True).tolist()
250
+ """
251
+ [0.828125, 0.0927734375, 0.6328125, 0.08251953125, 0.76171875, 0.099609375, 0.92578125, 0.058349609375, 0.84375, 0.111328125]
252
+ """
253
+
254
+ rankings = model.rank(query, documents, return_documents=True, convert_to_tensor=True)
255
+ print(f"Query: {query}")
256
+ for ranking in rankings:
257
+ print(f"ID: {ranking['corpus_id']}, Score: {ranking['score']:.4f}, Text: {ranking['text']}")
258
+ """
259
+ Query: Organic skincare products for sensitive skin
260
+ ID: 6, Score: 0.9258, Text: 针对敏感肌专门设计的天然有机护肤产品
261
+ ID: 8, Score: 0.8438, Text: 敏感肌のために特別に設計された天然有機スキンケア製品
262
+ ID: 0, Score: 0.8281, Text: Organic skincare for sensitive skin with aloe vera and chamomile.
263
+ ID: 4, Score: 0.7617, Text: Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla
264
+ ID: 2, Score: 0.6328, Text: Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille
265
+ ID: 9, Score: 0.1113, Text: 新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています
266
+ ID: 5, Score: 0.0996, Text: Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras
267
+ ID: 1, Score: 0.0928, Text: New makeup trends focus on bold colors and innovative techniques
268
+ ID: 3, Score: 0.0825, Text: Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken
269
+ ID: 7, Score: 0.0583, Text: 新的化妆趋势注重鲜艳的颜色和创新的技巧
270
+ """
271
+ ```
272
+
273
+ # Evaluation
274
+
275
+ We evaluated Jina Reranker v2 on multiple benchmarks to ensure top-tier performance and search relevance.
276
+
277
+ | Model Name | Model Size | MKQA(nDCG@10, 26 langs) | BEIR(nDCG@10, 17 datasets) | MLDR(recall@10, 13 langs) | CodeSearchNet (MRR@10, 3 tasks) | AirBench (nDCG@10, zh/en) | ToolBench (recall@3, 3 tasks) | TableSearch (recall@3) |
278
+ | :-----------------------------: | :----------: | ------------------------- | ---------------------------- | --------------------------- | --------------------------------- | --------------------------- | ------------------------------- | ------------------------ |
279
+ | jina-reranker-v2-multilingual | 278M | 54.83 | 53.17 | 68.95 | 71.36 | 61.33 | 77.75 | 93.31 |
280
+ | bge-reranker-v2-m3 | 568M | 54.17 | 53.65 | 59.73 | 62.86 | 61.28 | 78.46 | 74.86 |
281
+ | mmarco-mMiniLMv2-L12-H384-v1 | 118M | 53.37 | 45.40 | 28.91 | 51.78 | 56.46 | 58.39 | 53.60 |
282
+ | jina-reranker-v1-base-en | 137M | - | 52.45 | - | - | - | 74.13 | 72.89 |
283
+
284
+ Note:
285
+ - NDCG@10 and MRR@10 measure ranking quality, with higher scores indicating better search results
286
+ - recall@3 measures the proportion of relevant documents retrieved, with higher scores indicating better search results
block.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.fx
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+
15
+ from .mha import MHA
16
+ from .mlp import Mlp
17
+
18
+ try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
+ except ImportError:
21
+ layer_norm_fn, RMSNorm = None, None
22
+
23
+
24
+ def stochastic_depth(
25
+ input: Tensor, p: float, mode: str, training: bool = True
26
+ ) -> Tensor:
27
+ """
28
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
29
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
30
+ branches of residual architectures.
31
+ Args:
32
+ input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
33
+ being its batch i.e. a batch with ``N`` rows.
34
+ p (float): probability of the input to be zeroed.
35
+ mode (str): ``"batch"`` or ``"row"``.
36
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
37
+ randomly selected rows from the batch.
38
+ training: apply stochastic depth if is ``True``. Default: ``True``
39
+ Returns:
40
+ Tensor[N, ...]: The randomly zeroed tensor.
41
+ """
42
+ if p < 0.0 or p > 1.0:
43
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
44
+ if mode not in ["batch", "row"]:
45
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
46
+ if not training or p == 0.0:
47
+ return input
48
+
49
+ survival_rate = 1.0 - p
50
+ if mode == "row":
51
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
52
+ else:
53
+ size = [1] * input.ndim
54
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
55
+ noise = noise.bernoulli_(survival_rate)
56
+ if survival_rate > 0.0:
57
+ noise.div_(survival_rate)
58
+ return input * noise
59
+
60
+
61
+ torch.fx.wrap("stochastic_depth")
62
+
63
+
64
+ class StochasticDepth(nn.Module):
65
+ """
66
+ See :func:`stochastic_depth`.
67
+ """
68
+
69
+ def __init__(self, p: float, mode: str) -> None:
70
+ super().__init__()
71
+ self.p = p
72
+ self.mode = mode
73
+
74
+ def forward(self, input: Tensor) -> Tensor:
75
+ return stochastic_depth(input, self.p, self.mode, self.training)
76
+
77
+ def __repr__(self) -> str:
78
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
79
+ return s
80
+
81
+
82
+ class Block(nn.Module):
83
+ def __init__(
84
+ self,
85
+ dim,
86
+ mixer_cls=None,
87
+ mlp_cls=None,
88
+ norm_cls=nn.LayerNorm,
89
+ dropout_cls=nn.Dropout,
90
+ prenorm=True,
91
+ resid_dropout1=0.0,
92
+ resid_dropout2=0.0,
93
+ drop_path1=0.0,
94
+ drop_path2=0.0,
95
+ fused_dropout_add_ln=False,
96
+ return_residual=False,
97
+ residual_in_fp32=False,
98
+ sequence_parallel=False,
99
+ mark_shared_params=False,
100
+ ):
101
+ """
102
+ For prenorm=True, this Block has a slightly different structure compared to a regular
103
+ prenorm Transformer block.
104
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
105
+ [Ref: https://arxiv.org/abs/2002.04745]
106
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
107
+ the hidden_states (output of the MLP) and the residual.
108
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
109
+ The residual needs to be provided (except for the very first block).
110
+
111
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
112
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
113
+
114
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
115
+ This is for performance reason: for post-norm architecture, returning the input allows us
116
+ to fuse the backward of nn.Linear with the residual connection.
117
+ """
118
+ super().__init__()
119
+ self.prenorm = prenorm
120
+ self.fused_dropout_add_ln = fused_dropout_add_ln
121
+ self.return_residual = return_residual
122
+ self.residual_in_fp32 = residual_in_fp32
123
+ if self.residual_in_fp32:
124
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
125
+ if mixer_cls is None:
126
+ mixer_cls = partial(MHA, num_heads=dim // 64)
127
+ if mlp_cls is None:
128
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
129
+ self.mixer = mixer_cls(dim)
130
+ self.dropout1 = dropout_cls(resid_dropout1)
131
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
132
+ self.norm1 = norm_cls(dim)
133
+ self.mlp = mlp_cls(dim)
134
+ if not isinstance(self.mlp, nn.Identity):
135
+ self.dropout2 = dropout_cls(resid_dropout2)
136
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
137
+ self.norm2 = norm_cls(dim)
138
+
139
+ if self.fused_dropout_add_ln:
140
+ assert layer_norm_fn is not None, "Triton is not installed"
141
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
142
+ self.dropout1, nn.Dropout
143
+ )
144
+
145
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
146
+ # then the input to each worker in the tensor parallel group will be different.
147
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
148
+ # For now this is not an issue because we always use sequence_parallel=True during training
149
+ # and only use sequence_parallel=False during inference.
150
+
151
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
152
+ if sequence_parallel:
153
+ for p in self.norm1.parameters():
154
+ p._sequence_parallel = True
155
+ if hasattr(self, "norm2"):
156
+ for p in self.norm2.parameters():
157
+ p._sequence_parallel = True
158
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
159
+ if mark_shared_params:
160
+ for p in self.norm1.parameters():
161
+ p._shared_params = True
162
+ if hasattr(self, "norm2"):
163
+ for p in self.norm2.parameters():
164
+ p._shared_params = True
165
+
166
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
167
+ return self.mixer.allocate_inference_cache(
168
+ batch_size, max_seqlen, dtype=dtype, **kwargs
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states: Tensor,
174
+ residual: Optional[Tensor] = None,
175
+ mixer_subset=None,
176
+ mixer_kwargs=None,
177
+ ):
178
+ r"""Pass the input through the encoder layer.
179
+
180
+ Args:
181
+ hidden_states: the sequence to the encoder layer (required).
182
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
183
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
184
+ before applying the query projection. Useful for e.g., ViT where we only care
185
+ about the CLS token in the last layer.
186
+ """
187
+ if self.prenorm:
188
+ if not self.fused_dropout_add_ln:
189
+ dropped = self.drop_path1(self.dropout1(hidden_states))
190
+ residual = (dropped + residual) if residual is not None else dropped
191
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
192
+ if self.residual_in_fp32:
193
+ residual = residual.to(torch.float32)
194
+ else:
195
+ if self.drop_path1.p == 0 or not self.training:
196
+ rowscale1 = None
197
+ else:
198
+ rowscale1 = self.drop_path1(
199
+ torch.ones(
200
+ hidden_states.shape[:-1],
201
+ device=hidden_states.device,
202
+ dtype=hidden_states.dtype,
203
+ )
204
+ )
205
+ hidden_states, residual = layer_norm_fn(
206
+ hidden_states,
207
+ self.norm1.weight,
208
+ self.norm1.bias,
209
+ residual=residual,
210
+ eps=self.norm1.eps,
211
+ dropout_p=self.dropout1.p if self.training else 0.0,
212
+ rowscale=rowscale1,
213
+ prenorm=True,
214
+ residual_in_fp32=self.residual_in_fp32,
215
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
216
+ )
217
+ if mixer_kwargs is None:
218
+ mixer_kwargs = {}
219
+ if mixer_subset is not None:
220
+ mixer_kwargs["mixer_subset"] = mixer_subset
221
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
222
+ if mixer_subset is not None:
223
+ residual = residual[:, mixer_subset]
224
+ if not isinstance(self.mlp, nn.Identity):
225
+ if not self.fused_dropout_add_ln:
226
+ dropped = self.drop_path2(self.dropout2(hidden_states))
227
+ residual = (dropped + residual) if residual is not None else dropped
228
+ hidden_states = self.norm2(
229
+ residual.to(dtype=self.norm2.weight.dtype)
230
+ )
231
+ if self.residual_in_fp32:
232
+ residual = residual.to(torch.float32)
233
+ else:
234
+ if self.drop_path2.p == 0 or not self.training:
235
+ rowscale2 = None
236
+ else:
237
+ rowscale2 = self.drop_path2(
238
+ torch.ones(
239
+ hidden_states.shape[:-1],
240
+ device=hidden_states.device,
241
+ dtype=hidden_states.dtype,
242
+ )
243
+ )
244
+ hidden_states, residual = layer_norm_fn(
245
+ hidden_states,
246
+ self.norm2.weight,
247
+ self.norm2.bias,
248
+ residual=residual,
249
+ eps=self.norm2.eps,
250
+ dropout_p=self.dropout2.p if self.training else 0.0,
251
+ rowscale=rowscale2,
252
+ prenorm=True,
253
+ residual_in_fp32=self.residual_in_fp32,
254
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
255
+ )
256
+ hidden_states = self.mlp(hidden_states)
257
+ return hidden_states, residual
258
+ else:
259
+ assert residual is None
260
+ mixer_out = self.mixer(
261
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
262
+ )
263
+ if self.return_residual: # mixer out is actually a pair here
264
+ mixer_out, hidden_states = mixer_out
265
+ if not self.fused_dropout_add_ln:
266
+ hidden_states = self.norm1(
267
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
268
+ dtype=self.norm1.weight.dtype
269
+ )
270
+ )
271
+ else:
272
+ if self.drop_path1.p == 0 or not self.training:
273
+ rowscale1 = None
274
+ else:
275
+ rowscale1 = self.drop_path1(
276
+ torch.ones(
277
+ mixer_out.shape[:-1],
278
+ device=mixer_out.device,
279
+ dtype=mixer_out.dtype,
280
+ )
281
+ )
282
+ hidden_states = layer_norm_fn(
283
+ mixer_out,
284
+ self.norm1.weight,
285
+ self.norm1.bias,
286
+ residual=hidden_states,
287
+ eps=self.norm1.eps,
288
+ dropout_p=self.dropout1.p if self.training else 0.0,
289
+ rowscale=rowscale1,
290
+ prenorm=False,
291
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
292
+ )
293
+ if not isinstance(self.mlp, nn.Identity):
294
+ mlp_out = self.mlp(hidden_states)
295
+ if self.return_residual: # mlp out is actually a pair here
296
+ mlp_out, hidden_states = mlp_out
297
+ if not self.fused_dropout_add_ln:
298
+ hidden_states = self.norm2(
299
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
300
+ dtype=self.norm2.weight.dtype
301
+ )
302
+ )
303
+ else:
304
+ if self.drop_path2.p == 0 or not self.training:
305
+ rowscale2 = None
306
+ else:
307
+ rowscale2 = self.drop_path2(
308
+ torch.ones(
309
+ mlp_out.shape[:-1],
310
+ device=mlp_out.device,
311
+ dtype=mlp_out.dtype,
312
+ )
313
+ )
314
+ hidden_states = layer_norm_fn(
315
+ mlp_out,
316
+ self.norm2.weight,
317
+ self.norm2.bias,
318
+ residual=hidden_states,
319
+ eps=self.norm2.eps,
320
+ dropout_p=self.dropout2.p if self.training else 0.0,
321
+ rowscale=rowscale2,
322
+ prenorm=False,
323
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
324
+ )
325
+ return hidden_states
326
+
327
+
328
+ class ParallelBlock(nn.Module):
329
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
330
+ and PaLM.
331
+ """
332
+
333
+ def __init__(
334
+ self,
335
+ dim,
336
+ mixer_cls=None,
337
+ mlp_cls=None,
338
+ norm_cls=nn.LayerNorm,
339
+ dropout_cls=nn.Dropout,
340
+ resid_dropout1=0.0,
341
+ resid_dropout2=0.0,
342
+ tied_norm=False,
343
+ fused_dropout_add_ln=False,
344
+ residual_in_fp32=False,
345
+ sequence_parallel=False,
346
+ mark_shared_params=False,
347
+ ):
348
+ """
349
+ This Block has a slightly different structure compared to a regular
350
+ prenorm Transformer block.
351
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
352
+ [Ref: https://arxiv.org/abs/2002.04745]
353
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
354
+ the hidden_states (output1 of the MHA / MLP) and the residual.
355
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
356
+ The residual needs to be provided (except for the very first block).
357
+ """
358
+ super().__init__()
359
+ self.tied_norm = tied_norm
360
+ self.fused_dropout_add_ln = fused_dropout_add_ln
361
+ self.residual_in_fp32 = residual_in_fp32
362
+ if mixer_cls is None:
363
+ mixer_cls = partial(MHA, num_heads=dim // 64)
364
+ if mlp_cls is None:
365
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
366
+ self.mixer = mixer_cls(dim)
367
+ self.dropout1 = dropout_cls(resid_dropout1)
368
+ self.norm1 = norm_cls(dim)
369
+ self.mlp = mlp_cls(dim)
370
+ self.dropout2 = dropout_cls(resid_dropout2)
371
+ if not self.tied_norm:
372
+ self.norm2 = norm_cls(dim)
373
+
374
+ if self.fused_dropout_add_ln:
375
+ assert layer_norm_fn is not None, "Triton is not installed"
376
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
377
+ self.dropout1, nn.Dropout
378
+ )
379
+
380
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
381
+ # then the input to each worker in the tensor parallel group will be different.
382
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
383
+ # For now this is not an issue because we always use sequence_parallel=True during training
384
+ # and only use sequence_parallel=False during inference.
385
+
386
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
387
+ if sequence_parallel:
388
+ for p in self.norm1.parameters():
389
+ p._sequence_parallel = True
390
+ if hasattr(self, "norm2"):
391
+ for p in self.norm2.parameters():
392
+ p._sequence_parallel = True
393
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
394
+ if mark_shared_params:
395
+ for p in self.norm1.parameters():
396
+ p._shared_params = True
397
+ if hasattr(self, "norm2"):
398
+ for p in self.norm2.parameters():
399
+ p._shared_params = True
400
+
401
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
402
+ return self.mixer.allocate_inference_cache(
403
+ batch_size, max_seqlen, dtype=dtype, **kwargs
404
+ )
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states1: Tensor,
409
+ hidden_states2: Optional[Tensor] = None,
410
+ residual: Optional[Tensor] = None,
411
+ mixer_kwargs=None,
412
+ ):
413
+ r"""Pass the input through the encoder layer.
414
+
415
+ Args:
416
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
417
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
418
+ residual.
419
+ """
420
+ # TODO: Ideally we should only do the allgather / allreduce once for
421
+ # the Linear to MLP & Attention
422
+ if not self.fused_dropout_add_ln:
423
+ dropped1 = self.dropout1(hidden_states1)
424
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
425
+ if hidden_states2 is not None:
426
+ dropped2 = self.dropout2(hidden_states2)
427
+ residual = (
428
+ (residual + dropped1 + dropped2)
429
+ if residual is not None
430
+ else dropped1 + dropped2
431
+ )
432
+ else:
433
+ residual = (residual + dropped1) if residual is not None else dropped1
434
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
435
+ hidden_states2 = (
436
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
437
+ if not self.tied_norm
438
+ else hidden_states1
439
+ )
440
+ if self.residual_in_fp32:
441
+ residual = residual.to(torch.float32)
442
+ else:
443
+ weight2, bias2 = (
444
+ (self.norm2.weight, self.norm2.bias)
445
+ if not self.tied_norm
446
+ else (None, None)
447
+ )
448
+ hidden_states1, *rest, residual = layer_norm_fn(
449
+ hidden_states1,
450
+ self.norm1.weight,
451
+ self.norm1.bias,
452
+ residual=residual,
453
+ x1=hidden_states2,
454
+ weight1=weight2,
455
+ bias1=bias2,
456
+ eps=self.norm1.eps,
457
+ dropout_p=self.dropout1.p if self.training else 0.0,
458
+ prenorm=True,
459
+ residual_in_fp32=self.residual_in_fp32,
460
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
461
+ )
462
+ if self.tied_norm:
463
+ hidden_states2 = hidden_states1
464
+ else:
465
+ (hidden_states2,) = rest
466
+ if mixer_kwargs is None:
467
+ mixer_kwargs = {}
468
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
469
+ hidden_states2 = self.mlp(hidden_states2)
470
+ return hidden_states1, hidden_states2, residual
config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "jinaai/jina-reranker-v2-base-multilingual",
3
+ "architectures": [
4
+ "XLMRobertaForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "jinaai/jina-reranker-v2-base-multilingual--configuration_xlm_roberta.XLMRobertaFlashConfig",
9
+ "AutoModel": "jinaai/jina-reranker-v2-base-multilingual--modeling_xlm_roberta.XLMRobertaModel",
10
+ "AutoModelForSequenceClassification": "jinaai/jina-reranker-v2-base-multilingual--modeling_xlm_roberta.XLMRobertaForSequenceClassification"
11
+ },
12
+ "bos_token_id": 0,
13
+ "classifier_dropout": null,
14
+ "emb_pooler": null,
15
+ "eos_token_id": 2,
16
+ "hidden_act": "gelu",
17
+ "hidden_dropout_prob": 0.1,
18
+ "hidden_size": 768,
19
+ "id2label": {
20
+ "0": "LABEL_0"
21
+ },
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 3072,
24
+ "label2id": {
25
+ "LABEL_0": 0
26
+ },
27
+ "layer_norm_eps": 1e-05,
28
+ "load_trained_adapters": false,
29
+ "lora_adaptations": null,
30
+ "lora_alpha": 1,
31
+ "lora_dropout_p": 0.0,
32
+ "lora_main_params_trainable": false,
33
+ "lora_rank": 4,
34
+ "matryoshka_dimensions": null,
35
+ "max_position_embeddings": 1026,
36
+ "num_attention_heads": 12,
37
+ "num_hidden_layers": 12,
38
+ "output_past": true,
39
+ "pad_token_id": 1,
40
+ "position_embedding_type": "absolute",
41
+ "torch_dtype": "bfloat16",
42
+ "transformers_version": "4.45.2",
43
+ "truncate_dim": null,
44
+ "type_vocab_size": 1,
45
+ "use_cache": false,
46
+ "use_flash_attn": true,
47
+ "vocab_size": 46528
48
+ }
configuration_xlm_roberta.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import torch
3
+
4
+ class XLMRobertaFlashConfig(PretrainedConfig):
5
+ def __init__(
6
+ self,
7
+ vocab_size=30522,
8
+ hidden_size=768,
9
+ num_hidden_layers=12,
10
+ num_attention_heads=12,
11
+ intermediate_size=3072,
12
+ hidden_act="gelu",
13
+ hidden_dropout_prob=0.1,
14
+ attention_probs_dropout_prob=0.1,
15
+ max_position_embeddings=512,
16
+ type_vocab_size=2,
17
+ initializer_range=0.02,
18
+ layer_norm_eps=1e-12,
19
+ pad_token_id=1,
20
+ bos_token_id=0,
21
+ eos_token_id=2,
22
+ position_embedding_type="absolute",
23
+ use_cache=True,
24
+ classifier_dropout=None,
25
+ lora_adaptations=None,
26
+ lora_rank=4,
27
+ lora_dropout_p=0.0,
28
+ lora_alpha=1,
29
+ lora_main_params_trainable=False,
30
+ load_trained_adapters=False,
31
+ use_flash_attn=True,
32
+ torch_dtype=None,
33
+ emb_pooler=None,
34
+ matryoshka_dimensions=None,
35
+ truncate_dim=None,
36
+ **kwargs,
37
+ ):
38
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
39
+
40
+
41
+ self.vocab_size = vocab_size
42
+ self.hidden_size = hidden_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.num_attention_heads = num_attention_heads
45
+ self.hidden_act = hidden_act
46
+ self.intermediate_size = intermediate_size
47
+ self.hidden_dropout_prob = hidden_dropout_prob
48
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.type_vocab_size = type_vocab_size
51
+ self.initializer_range = initializer_range
52
+ self.layer_norm_eps = layer_norm_eps
53
+ self.position_embedding_type = position_embedding_type
54
+ self.use_cache = use_cache
55
+ self.classifier_dropout = classifier_dropout
56
+ self.load_trained_adapters = load_trained_adapters
57
+ self.lora_adaptations = lora_adaptations
58
+ self.lora_rank = lora_rank
59
+ self.lora_dropout_p = lora_dropout_p
60
+ self.lora_alpha = lora_alpha
61
+ self.lora_main_params_trainable = lora_main_params_trainable
62
+ self.use_flash_attn = use_flash_attn
63
+ self.emb_pooler = emb_pooler
64
+ self.matryoshka_dimensions = matryoshka_dimensions
65
+ self.truncate_dim = truncate_dim
66
+ if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
67
+ self.torch_dtype = getattr(torch, torch_dtype)
68
+ else:
69
+ self.torch_dtype = torch_dtype
embedding.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
2
+ # Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
3
+
4
+ # Copyright (c) 2022, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
12
+
13
+
14
+ class XLMRobertaEmbeddings(nn.Module):
15
+ def __init__(
16
+ self,
17
+ embed_dim,
18
+ vocab_size,
19
+ max_position_embeddings,
20
+ type_vocab_size,
21
+ padding_idx=None,
22
+ device=None,
23
+ dtype=None,
24
+ ):
25
+ """
26
+ If max_position_embeddings <= 0, there's no position embeddings
27
+ If type_vocab_size <= 0, there's no token type embeddings
28
+ """
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ self.word_embeddings = nn.Embedding(
32
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
33
+ )
34
+ self.max_position_embeddings = max_position_embeddings
35
+ self.type_vocab_size = type_vocab_size
36
+ if self.max_position_embeddings > 0:
37
+ self.position_embeddings = nn.Embedding(
38
+ max_position_embeddings, embed_dim, **factory_kwargs
39
+ )
40
+ if self.type_vocab_size > 0:
41
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
+
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
44
+ """
45
+ input_ids: (batch, seqlen)
46
+ position_ids: (batch, seqlen)
47
+ token_type_ids: (batch, seqlen)
48
+ """
49
+ batch_size, seqlen = input_ids.shape
50
+ embeddings = self.word_embeddings(input_ids)
51
+ if self.max_position_embeddings > 0:
52
+ if position_ids is None:
53
+ position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
54
+ # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
55
+ position_embeddings = self.position_embeddings(position_ids)
56
+ embeddings = embeddings + position_embeddings
57
+ if self.type_vocab_size > 0:
58
+ if token_type_ids is None:
59
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
60
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
61
+ embeddings = embeddings + token_type_embeddings
62
+ return embeddings
mha.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
3
+
4
+ import math
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from flash_attn import (
13
+ flash_attn_kvpacked_func,
14
+ flash_attn_qkvpacked_func,
15
+ flash_attn_varlen_kvpacked_func,
16
+ flash_attn_varlen_qkvpacked_func,
17
+ flash_attn_with_kvcache,
18
+ )
19
+ except ImportError:
20
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
21
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
22
+ flash_attn_with_kvcache = None
23
+
24
+ try:
25
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
26
+ except ImportError:
27
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
28
+
29
+
30
+ class FlashSelfAttention(nn.Module):
31
+ """Implement the scaled dot product attention with softmax.
32
+ Arguments
33
+ ---------
34
+ softmax_scale: The temperature to use for the softmax attention.
35
+ (default: 1/sqrt(d_keys) where d_keys is computed at
36
+ runtime)
37
+ attention_dropout: The dropout rate to apply to the attention
38
+ (default: 0.0)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ causal=False,
44
+ softmax_scale=None,
45
+ attention_dropout=0.0,
46
+ window_size=(-1, -1),
47
+ deterministic=False,
48
+ ):
49
+ super().__init__()
50
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
51
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
52
+ self.causal = causal
53
+ self.softmax_scale = softmax_scale
54
+ self.drop = nn.Dropout(attention_dropout)
55
+ self.window_size = window_size
56
+ self.deterministic = deterministic
57
+
58
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
59
+ """Implements the multihead softmax attention.
60
+ Arguments
61
+ ---------
62
+ qkv: The tensor containing the query, key, and value.
63
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
64
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
65
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
66
+ causal: if passed, will override self.causal
67
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
68
+ of the sequences in the batch, used to index into qkv.
69
+ max_seqlen: int. Maximum sequence length in the batch.
70
+ Returns:
71
+ --------
72
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
73
+ else (B, S, H, D).
74
+ """
75
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
76
+ assert qkv.is_cuda
77
+ causal = self.causal if causal is None else causal
78
+ unpadded = cu_seqlens is not None
79
+
80
+ if unpadded:
81
+ assert cu_seqlens.dtype == torch.int32
82
+ assert max_seqlen is not None
83
+ assert isinstance(max_seqlen, int)
84
+ return flash_attn_varlen_qkvpacked_func(
85
+ qkv,
86
+ cu_seqlens,
87
+ max_seqlen,
88
+ self.drop.p if self.training else 0.0,
89
+ softmax_scale=self.softmax_scale,
90
+ causal=causal,
91
+ alibi_slopes=None,
92
+ window_size=self.window_size,
93
+ deterministic=self.deterministic,
94
+ )
95
+ else:
96
+ return flash_attn_qkvpacked_func(
97
+ qkv,
98
+ self.drop.p if self.training else 0.0,
99
+ softmax_scale=self.softmax_scale,
100
+ causal=causal,
101
+ alibi_slopes=None,
102
+ window_size=self.window_size,
103
+ deterministic=self.deterministic,
104
+ )
105
+
106
+
107
+ class FlashCrossAttention(nn.Module):
108
+ """Implement the scaled dot product attention with softmax.
109
+ Arguments
110
+ ---------
111
+ softmax_scale: The temperature to use for the softmax attention.
112
+ (default: 1/sqrt(d_keys) where d_keys is computed at
113
+ runtime)
114
+ attention_dropout: The dropout rate to apply to the attention
115
+ (default: 0.0)
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ causal=False,
121
+ softmax_scale=None,
122
+ attention_dropout=0.0,
123
+ window_size=(-1, -1),
124
+ deterministic=False,
125
+ ):
126
+ super().__init__()
127
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
128
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
129
+ self.causal = causal
130
+ self.softmax_scale = softmax_scale
131
+ self.drop = nn.Dropout(attention_dropout)
132
+ self.window_size = window_size
133
+ self.deterministic = deterministic
134
+
135
+ def forward(
136
+ self,
137
+ q,
138
+ kv,
139
+ causal=None,
140
+ cu_seqlens=None,
141
+ max_seqlen=None,
142
+ cu_seqlens_k=None,
143
+ max_seqlen_k=None,
144
+ ):
145
+ """Implements the multihead softmax attention.
146
+ Arguments
147
+ ---------
148
+ q: The tensor containing the query. (B, Sq, H, D)
149
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
150
+ causal: if passed, will override self.causal
151
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
152
+ of the sequences in the batch, used to index into q.
153
+ max_seqlen: int. Maximum sequence length in the batch of q.
154
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
155
+ of the sequences in the batch, used to index into kv.
156
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
157
+ """
158
+ assert q.dtype in [torch.float16, torch.bfloat16]
159
+ assert q.is_cuda and kv.is_cuda
160
+ causal = self.causal if causal is None else causal
161
+ unpadded = cu_seqlens is not None
162
+
163
+ if unpadded:
164
+ assert cu_seqlens.dtype == torch.int32
165
+ assert max_seqlen is not None
166
+ assert isinstance(max_seqlen, int)
167
+ assert cu_seqlens_k is not None
168
+ assert cu_seqlens_k.dtype == torch.int32
169
+ assert max_seqlen_k is not None
170
+ assert isinstance(max_seqlen, int)
171
+ return flash_attn_varlen_kvpacked_func(
172
+ q,
173
+ kv,
174
+ cu_seqlens,
175
+ cu_seqlens_k,
176
+ max_seqlen,
177
+ max_seqlen_k,
178
+ self.drop.p if self.training else 0.0,
179
+ softmax_scale=self.softmax_scale,
180
+ causal=causal,
181
+ alibi_slopes=None,
182
+ window_size=self.window_size,
183
+ deterministic=self.deterministic,
184
+ )
185
+ else:
186
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
187
+ seqlen_k = kv.shape[1]
188
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
189
+ return flash_attn_kvpacked_func(
190
+ q,
191
+ kv,
192
+ self.drop.p if self.training else 0.0,
193
+ causal=causal,
194
+ softmax_scale=self.softmax_scale,
195
+ alibi_slopes=None,
196
+ window_size=self.window_size,
197
+ deterministic=self.deterministic,
198
+ )
199
+
200
+
201
+ class SelfAttention(nn.Module):
202
+ """Implement the scaled dot product attention with softmax.
203
+ Arguments
204
+ ---------
205
+ softmax_scale: The temperature to use for the softmax attention.
206
+ (default: 1/sqrt(d_keys) where d_keys is computed at
207
+ runtime)
208
+ attention_dropout: The dropout rate to apply to the attention
209
+ (default: 0.0)
210
+ """
211
+
212
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
213
+ super().__init__()
214
+ self.causal = causal
215
+ self.softmax_scale = softmax_scale
216
+ self.drop = nn.Dropout(attention_dropout)
217
+
218
+ def forward(self, qkv, causal=None, key_padding_mask=None):
219
+ """Implements the multihead softmax attention.
220
+ Arguments
221
+ ---------
222
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
223
+ causal: if passed, will override self.causal
224
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
225
+ False means to mask out. (B, S)
226
+ """
227
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
228
+ causal = self.causal if causal is None else causal
229
+ q, k, v = qkv.unbind(dim=2)
230
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
231
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
232
+ if key_padding_mask is not None:
233
+ padding_mask = torch.full(
234
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
235
+ )
236
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
237
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
238
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
239
+ if causal:
240
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
241
+ # So we have to construct the mask in float
242
+ causal_mask = torch.triu(
243
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
244
+ )
245
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
246
+ scores = scores + causal_mask.to(dtype=scores.dtype)
247
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
248
+ attention_drop = self.drop(attention)
249
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
250
+ return output
251
+
252
+
253
+ class CrossAttention(nn.Module):
254
+ """Implement the scaled dot product attention with softmax.
255
+ Arguments
256
+ ---------
257
+ softmax_scale: The temperature to use for the softmax attention.
258
+ (default: 1/sqrt(d_keys) where d_keys is computed at
259
+ runtime)
260
+ attention_dropout: The dropout rate to apply to the attention
261
+ (default: 0.0)
262
+ """
263
+
264
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
265
+ super().__init__()
266
+ self.causal = causal
267
+ self.softmax_scale = softmax_scale
268
+ self.drop = nn.Dropout(attention_dropout)
269
+
270
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
271
+ """Implements the multihead softmax attention.
272
+ Arguments
273
+ ---------
274
+ q: The tensor containing the query. (B, Sq, H, D)
275
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
276
+ causal: if passed, will override self.causal
277
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
278
+ False means to mask out. (B, Sk)
279
+ """
280
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
281
+ causal = self.causal if causal is None else causal
282
+ seqlen_k = kv.shape[1]
283
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
284
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
285
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
286
+ k, v = kv.unbind(dim=2)
287
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
288
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
289
+ if key_padding_mask is not None:
290
+ padding_mask = torch.full(
291
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
292
+ )
293
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
294
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
295
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
296
+ if causal:
297
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
298
+ row_idx = rearrange(
299
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
300
+ )
301
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
302
+ sk = (
303
+ seqlen_k
304
+ if key_padding_mask is None
305
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
306
+ )
307
+ causal_mask = col_idx > row_idx + sk - seqlen_q
308
+ scores = scores.masked_fill(causal_mask, -10000.0)
309
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
310
+ attention_drop = self.drop(attention)
311
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
312
+ return output
313
+
314
+
315
+ class LinearResidual(nn.Linear):
316
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
317
+
318
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
319
+ return super().forward(input), input
320
+
321
+
322
+ def _update_kv_cache(kv, inference_params, layer_idx):
323
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
324
+ # Pre-allocate memory for key-values for inference.
325
+ num_heads, head_dim = kv.shape[-2:]
326
+ if layer_idx not in inference_params.key_value_memory_dict:
327
+ kv_cache = torch.empty(
328
+ inference_params.max_batch_size,
329
+ inference_params.max_seqlen,
330
+ 2,
331
+ num_heads,
332
+ head_dim,
333
+ dtype=kv.dtype,
334
+ device=kv.device,
335
+ )
336
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
337
+ else:
338
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
339
+ # Adjust key and value for inference
340
+ batch_start = inference_params.batch_size_offset
341
+ batch_end = batch_start + kv.shape[0]
342
+ sequence_start = inference_params.seqlen_offset
343
+ sequence_end = sequence_start + kv.shape[1]
344
+ assert batch_end <= kv_cache.shape[0]
345
+ assert sequence_end <= kv_cache.shape[1]
346
+ assert kv_cache is not None
347
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
348
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
349
+
350
+
351
+ class MHA(nn.Module):
352
+ """Multi-head self-attention and cross-attention"""
353
+
354
+ def __init__(
355
+ self,
356
+ embed_dim,
357
+ num_heads,
358
+ num_heads_kv=None,
359
+ cross_attn=False,
360
+ qkv_proj_bias=True,
361
+ out_proj_bias=True,
362
+ dropout=0.0,
363
+ softmax_scale=None,
364
+ causal=False,
365
+ layer_idx=None,
366
+ dwconv=False,
367
+ window_size=(-1, -1),
368
+ fused_bias_fc=False,
369
+ use_flash_attn=False,
370
+ return_residual=False,
371
+ checkpointing=False,
372
+ device=None,
373
+ dtype=None,
374
+ ) -> None:
375
+ """
376
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
377
+ return_residual: whether to return the input x along with the output. This is for
378
+ performance reason: for post-norm architecture, returning the input allows us
379
+ to fuse the backward of nn.Linear with the residual connection.
380
+ """
381
+ factory_kwargs = {"device": device, "dtype": dtype}
382
+ super().__init__()
383
+ self.embed_dim = embed_dim
384
+ self.cross_attn = cross_attn
385
+ self.causal = causal
386
+ self.layer_idx = layer_idx
387
+ self.dwconv = dwconv
388
+ self.use_flash_attn = use_flash_attn
389
+ self.return_residual = return_residual
390
+ self.checkpointing = checkpointing
391
+
392
+ if window_size != (-1, -1):
393
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
394
+
395
+ self.num_heads = num_heads
396
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
397
+ assert (
398
+ self.num_heads % self.num_heads_kv == 0
399
+ ), "num_heads must be divisible by num_heads_kv"
400
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
401
+ self.head_dim = self.embed_dim // num_heads
402
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
403
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
404
+
405
+ if fused_bias_fc and FusedDense is None:
406
+ raise ImportError("fused_dense is not installed")
407
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
408
+ linear_resid_cls = (
409
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
410
+ )
411
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
412
+ inner_attn_cls = (
413
+ partial(FlashSelfAttention, window_size=window_size)
414
+ if use_flash_attn
415
+ else SelfAttention
416
+ )
417
+ inner_cross_attn_cls = (
418
+ partial(FlashCrossAttention, window_size=window_size)
419
+ if use_flash_attn
420
+ else CrossAttention
421
+ )
422
+ if not self.cross_attn:
423
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
424
+ else:
425
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
426
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
427
+ if self.dwconv:
428
+ if self.num_heads_kv == self.num_heads:
429
+ self.dwconv_qkv = nn.Conv1d(
430
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
431
+ )
432
+ else:
433
+ self.dwconv_q = nn.Conv1d(
434
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
435
+ )
436
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
437
+ self.inner_attn = inner_attn_cls(
438
+ causal=causal,
439
+ softmax_scale=softmax_scale,
440
+ attention_dropout=dropout,
441
+ )
442
+ self.inner_cross_attn = inner_cross_attn_cls(
443
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
444
+ )
445
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
446
+
447
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
448
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
449
+ device = self.out_proj.weight.device
450
+ return torch.empty(
451
+ batch_size,
452
+ max_seqlen,
453
+ 2,
454
+ self.num_heads_kv,
455
+ self.head_dim,
456
+ dtype=dtype,
457
+ device=device,
458
+ )
459
+
460
+ def _update_kv_cache(self, kv, inference_params):
461
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
462
+ assert not self.dwconv, "Generation does not support dwconv yet"
463
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
464
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
465
+
466
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
467
+ """
468
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
469
+ q: (batch_size, seqlen_q, nheads, head_dim)
470
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
471
+ """
472
+ assert inference_params is not None and inference_params.seqlen_offset > 0
473
+ assert self.use_flash_attn
474
+ batch = q.shape[0]
475
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
476
+ cache_seqlens = (
477
+ inference_params.lengths_per_sample[:batch]
478
+ if inference_params.lengths_per_sample is not None
479
+ else inference_params.seqlen_offset
480
+ )
481
+ context = flash_attn_with_kvcache(
482
+ q,
483
+ kv_cache[:, :, 0],
484
+ kv_cache[:, :, 1],
485
+ kv[:, :, 0],
486
+ kv[:, :, 1],
487
+ cache_seqlens=cache_seqlens,
488
+ softmax_scale=self.inner_cross_attn.softmax_scale,
489
+ causal=self.inner_cross_attn.causal,
490
+ rotary_interleaved=False,
491
+ alibi_slopes=None,
492
+ )
493
+ return context
494
+
495
+ def _update_kvcache_attention(self, q, kv, inference_params):
496
+ """Write kv to inference_params, then do attention"""
497
+ if (
498
+ inference_params.seqlen_offset == 0
499
+ or flash_attn_with_kvcache is None
500
+ or not self.use_flash_attn
501
+ ):
502
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
503
+ kv = self._update_kv_cache(kv, inference_params)
504
+ return self.inner_cross_attn(q, kv)
505
+ else:
506
+ batch = q.shape[0]
507
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
508
+ cache_seqlens = (
509
+ inference_params.lengths_per_sample[:batch]
510
+ if inference_params.lengths_per_sample is not None
511
+ else inference_params.seqlen_offset
512
+ )
513
+ return flash_attn_with_kvcache(
514
+ q,
515
+ kv_cache[:, :, 0],
516
+ kv_cache[:, :, 1],
517
+ kv[:, :, 0],
518
+ kv[:, :, 1],
519
+ cache_seqlens=cache_seqlens,
520
+ softmax_scale=self.inner_cross_attn.softmax_scale,
521
+ causal=self.inner_cross_attn.causal,
522
+ alibi_slopes=None,
523
+ )
524
+
525
+ def forward(
526
+ self,
527
+ x,
528
+ x_kv=None,
529
+ key_padding_mask=None,
530
+ cu_seqlens=None,
531
+ max_seqlen=None,
532
+ mixer_subset=None,
533
+ inference_params=None,
534
+ **kwargs,
535
+ ):
536
+ """
537
+ Arguments:
538
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
539
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
540
+ is the is the sum of the sequence lengths in the batch.
541
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
542
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
543
+ of the sequences in the batch, used to index into x. Only applicable when using
544
+ FlashAttention.
545
+ max_seqlen: int. Maximum sequence length in the batch.
546
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
547
+ (batch, seqlen). Only applicable when not using FlashAttention.
548
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
549
+ before applying the query projection. Useful for e.g., ViT where we only care
550
+ about the CLS token in the last layer.
551
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
552
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
553
+ """
554
+ if cu_seqlens is not None:
555
+ assert max_seqlen is not None
556
+ assert key_padding_mask is None
557
+ assert self.use_flash_attn
558
+ assert not self.dwconv
559
+ if key_padding_mask is not None:
560
+ assert cu_seqlens is None
561
+ assert max_seqlen is None
562
+ assert not self.use_flash_attn
563
+ if inference_params is not None:
564
+ assert key_padding_mask is None
565
+ assert cu_seqlens is None and max_seqlen is None
566
+ assert not self.dwconv
567
+
568
+ kwargs = (
569
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
570
+ if self.use_flash_attn
571
+ else {"key_padding_mask": key_padding_mask, **kwargs}
572
+ )
573
+ seqlen_offset = (
574
+ 0
575
+ if inference_params is None
576
+ else (
577
+ inference_params.lengths_per_sample
578
+ if inference_params.lengths_per_sample is not None
579
+ else inference_params.seqlen_offset
580
+ )
581
+ )
582
+ rotary_max_seqlen = (
583
+ inference_params.max_sequence_len if inference_params is not None else max_seqlen
584
+ )
585
+ batch, seqlen = x.shape[:2]
586
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
587
+ assert x_kv is None and mixer_subset is None
588
+ if not self.return_residual:
589
+ qkv = self.Wqkv(x)
590
+ else:
591
+ qkv, x = self.Wqkv(x)
592
+ if self.dwconv:
593
+ qkv = rearrange(
594
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
595
+ ).contiguous()
596
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
597
+ if (
598
+ inference_params is None
599
+ or inference_params.seqlen_offset == 0
600
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
601
+ or not self.use_flash_attn
602
+ ):
603
+ if inference_params is None:
604
+ if not self.checkpointing:
605
+ context = self.inner_attn(qkv, **kwargs)
606
+ else:
607
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
608
+ else:
609
+ context = self._update_kvcache_attention(
610
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
611
+ )
612
+ else:
613
+ context = self._apply_rotary_update_kvcache_attention(
614
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
615
+ )
616
+ else:
617
+ if self.cross_attn:
618
+ if not self.return_residual:
619
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
620
+ kv = self.Wkv(x_kv if x_kv is not None else x)
621
+ else:
622
+ if x_kv is not None:
623
+ kv, x_kv = self.Wkv(x_kv)
624
+ else:
625
+ kv, x = self.Wkv(x)
626
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
627
+ else:
628
+ assert self.num_heads_kv != self.num_heads
629
+ if not self.return_residual:
630
+ qkv = self.Wqkv(x)
631
+ else:
632
+ qkv, x = self.Wqkv(x)
633
+ q = qkv[..., : self.num_heads * self.head_dim]
634
+ kv = qkv[..., self.num_heads * self.head_dim :]
635
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
636
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
637
+ if self.dwconv:
638
+ q = rearrange(
639
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
640
+ ).contiguous()
641
+ kv = rearrange(
642
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
643
+ ).contiguous()
644
+ if (
645
+ inference_params is None
646
+ or inference_params.seqlen_offset == 0
647
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
648
+ or not self.use_flash_attn
649
+ ):
650
+ if inference_params is None:
651
+ if not self.checkpointing:
652
+ context = self.inner_cross_attn(q, kv, **kwargs)
653
+ else:
654
+ context = torch.utils.checkpoint.checkpoint(
655
+ self.inner_cross_attn, q, kv, **kwargs
656
+ )
657
+ else:
658
+ context = self._update_kvcache_attention(q, kv, inference_params)
659
+ else:
660
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
661
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
662
+ return out if not self.return_residual else (out, x)
mlp.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
2
+ # Commit id: c3b219665292c61a51153d0ded4473c494296382
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.distributed import ProcessGroup
10
+
11
+
12
+ try:
13
+ from flash_attn.ops.activations import swiglu
14
+ except ImportError:
15
+ swiglu = None
16
+
17
+ try:
18
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
19
+ except ImportError:
20
+ ColumnParallelLinear, RowParallelLinear = None, None
21
+
22
+ try:
23
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
24
+ except ImportError:
25
+ FusedMLP, ParallelFusedMLP = None, None
26
+
27
+
28
+ class Mlp(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_features,
32
+ hidden_features=None,
33
+ out_features=None,
34
+ activation=F.gelu,
35
+ bias1=True,
36
+ bias2=True,
37
+ return_residual=False,
38
+ device=None,
39
+ dtype=None,
40
+ ):
41
+ factory_kwargs = {"device": device, "dtype": dtype}
42
+ super().__init__()
43
+ out_features = out_features if out_features is not None else in_features
44
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
45
+ self.return_residual = return_residual
46
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
47
+ self.activation = activation
48
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
+
50
+ def forward(self, x):
51
+ y = self.fc1(x)
52
+ y = self.activation(y)
53
+ y = self.fc2(y)
54
+ return y if not self.return_residual else (y, x)
55
+
56
+
57
+ class ParallelMLP(nn.Module):
58
+ def __init__(
59
+ self,
60
+ in_features,
61
+ hidden_features=None,
62
+ out_features=None,
63
+ activation=F.gelu,
64
+ process_group: ProcessGroup = None,
65
+ sequence_parallel=True,
66
+ bias1=True,
67
+ bias2=True,
68
+ device=None,
69
+ dtype=None,
70
+ ):
71
+ factory_kwargs = {"device": device, "dtype": dtype}
72
+ super().__init__()
73
+ assert ColumnParallelLinear is not None, "Need to install fused_dense"
74
+ assert RowParallelLinear is not None, "Need to install fused_dense"
75
+ out_features = out_features if out_features is not None else in_features
76
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
77
+ self.fc1 = ColumnParallelLinear(
78
+ in_features,
79
+ hidden_features,
80
+ process_group,
81
+ bias=bias1,
82
+ sequence_parallel=sequence_parallel,
83
+ **factory_kwargs,
84
+ )
85
+ self.activation = activation
86
+ self.fc2 = RowParallelLinear(
87
+ hidden_features,
88
+ out_features,
89
+ process_group,
90
+ bias=bias2,
91
+ sequence_parallel=sequence_parallel,
92
+ **factory_kwargs,
93
+ )
94
+
95
+ def forward(self, x):
96
+ y = self.fc1(x)
97
+ y = self.activation(y)
98
+ y = self.fc2(y)
99
+ return y
100
+
101
+
102
+ class GatedMlp(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_features,
106
+ hidden_features=None,
107
+ out_features=None,
108
+ activation=F.sigmoid,
109
+ bias1=True,
110
+ bias2=True,
111
+ multiple_of=128,
112
+ return_residual=False,
113
+ device=None,
114
+ dtype=None,
115
+ ):
116
+ factory_kwargs = {"device": device, "dtype": dtype}
117
+ super().__init__()
118
+ out_features = out_features if out_features is not None else in_features
119
+ hidden_features = (
120
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
121
+ )
122
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
123
+ self.return_residual = return_residual
124
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
125
+ self.activation = activation
126
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
127
+
128
+ def forward(self, x):
129
+ y = self.fc1(x)
130
+ if self.activation == F.sigmoid: # Special case for GLU
131
+ y = F.glu(y, dim=-1)
132
+ elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
133
+ y, gate = y.chunk(2, dim=-1)
134
+ y = swiglu(gate, y)
135
+ else:
136
+ y, gate = y.chunk(2, dim=-1)
137
+ y = y * self.activation(gate)
138
+ y = self.fc2(y)
139
+ return y if not self.return_residual else (y, x)
140
+
141
+
142
+ class ParallelGatedMlp(nn.Module):
143
+ """Parallel GatedMlp"""
144
+
145
+ def __init__(
146
+ self,
147
+ in_features,
148
+ process_group,
149
+ hidden_features=None,
150
+ out_features=None,
151
+ activation=F.sigmoid,
152
+ bias1=True,
153
+ bias2=True,
154
+ multiple_of=128,
155
+ sequence_parallel=True,
156
+ device=None,
157
+ dtype=None,
158
+ ):
159
+ factory_kwargs = {"device": device, "dtype": dtype}
160
+ super().__init__()
161
+ out_features = out_features if out_features is not None else in_features
162
+ hidden_features = (
163
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
164
+ )
165
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
166
+ if ColumnParallelLinear is None or RowParallelLinear is None:
167
+ raise ImportError("fused_dense is not installed")
168
+ self.fc1 = ColumnParallelLinear(
169
+ in_features,
170
+ 2 * hidden_features,
171
+ process_group,
172
+ bias=bias1,
173
+ sequence_parallel=sequence_parallel,
174
+ **factory_kwargs,
175
+ )
176
+ self.activation = activation
177
+ self.fc2 = RowParallelLinear(
178
+ hidden_features,
179
+ out_features,
180
+ process_group,
181
+ bias=bias2,
182
+ sequence_parallel=sequence_parallel,
183
+ **factory_kwargs,
184
+ )
185
+
186
+ def forward(self, x):
187
+ y = self.fc1(x)
188
+ if self.activation == F.sigmoid: # Special case for GLU
189
+ y = F.glu(y, dim=-1)
190
+ else:
191
+ y, gate = y.chunk(2, dim=-1)
192
+ y = y * self.activation(gate)
193
+ y = self.fc2(y)
194
+ return y
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8490bd5c0e2e58480ad1279ff2a7b2605c6c12a5f0fa9345141532bf166920b
3
+ size 244356202
modeling_xlm_roberta.py ADDED
@@ -0,0 +1,1119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+
10
+ import importlib.util
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ import numpy as np
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from einops import rearrange
24
+ from transformers import PretrainedConfig
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
27
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
+
29
+ from transformers.models.bert.modeling_bert import (
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ BertForPreTrainingOutput,
32
+ )
33
+
34
+ from typing import List, Optional, Tuple, Union
35
+
36
+ from .xlm_padding import (
37
+ index_first_axis,
38
+ index_first_axis_residual,
39
+ pad_input,
40
+ unpad_input,
41
+ )
42
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
+ from .block import Block
44
+ from .embedding import XLMRobertaEmbeddings
45
+ from .mha import MHA
46
+ from .mlp import FusedMLP, Mlp
47
+
48
+ try:
49
+ from flash_attn.ops.fused_dense import FusedDense
50
+ except ImportError:
51
+ FusedDense = None
52
+
53
+ try:
54
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
55
+ except ImportError:
56
+ layer_norm_fn = None
57
+
58
+
59
+ try:
60
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
61
+ except ImportError:
62
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
63
+
64
+ try:
65
+ from tqdm.autonotebook import trange
66
+ except ImportError:
67
+ trange = None
68
+
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
74
+ if not getattr(config, "use_flash_attn", False):
75
+ return False
76
+ if not torch.cuda.is_available():
77
+ return False
78
+ if importlib.util.find_spec("flash_attn") is None:
79
+ logger.warning(
80
+ 'flash_attn is not installed. Using PyTorch native attention implementation.'
81
+ )
82
+ return False
83
+ return True
84
+
85
+
86
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
87
+ use_flash_attn = get_use_flash_attn(config)
88
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
89
+
90
+ mixer_cls = partial(
91
+ MHA,
92
+ num_heads=config.num_attention_heads,
93
+ cross_attn=cross_attn,
94
+ dropout=config.attention_probs_dropout_prob,
95
+ causal=False,
96
+ fused_bias_fc=fused_bias_fc,
97
+ use_flash_attn=use_flash_attn,
98
+ return_residual=return_residual,
99
+ )
100
+ return mixer_cls
101
+
102
+
103
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
104
+ inner_dim = config.intermediate_size
105
+ fused_mlp = getattr(config, "fused_mlp", False)
106
+ if fused_mlp:
107
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
108
+ "fused_mlp only " "supports approximate gelu"
109
+ )
110
+ if not fused_mlp:
111
+ approximate = (
112
+ "tanh"
113
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
114
+ else "none"
115
+ )
116
+ mlp_cls = partial(
117
+ Mlp,
118
+ hidden_features=inner_dim,
119
+ activation=partial(F.gelu, approximate=approximate),
120
+ return_residual=return_residual,
121
+ )
122
+ else:
123
+ if FusedMLP is None:
124
+ raise ImportError("fused_dense is not installed")
125
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
126
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
127
+ if isinstance(mlp_checkpoint_lvl, Sequence):
128
+ assert layer_idx is not None
129
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
130
+ mlp_cls = partial(
131
+ FusedMLP,
132
+ hidden_features=inner_dim,
133
+ checkpoint_lvl=mlp_checkpoint_lvl,
134
+ return_residual=return_residual,
135
+ )
136
+ return mlp_cls
137
+
138
+
139
+ def create_block(config, layer_idx=None):
140
+ last_layer_subset = getattr(config, "last_layer_subset", False)
141
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
142
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
143
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
144
+ # one layer) so we just choose not to return residual in this case.
145
+ return_residual = not cross_attn
146
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
147
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
148
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
149
+ block = Block(
150
+ config.hidden_size,
151
+ mixer_cls,
152
+ mlp_cls,
153
+ norm_cls=norm_cls,
154
+ prenorm=False,
155
+ resid_dropout1=config.hidden_dropout_prob,
156
+ resid_dropout2=config.hidden_dropout_prob,
157
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
158
+ return_residual=return_residual,
159
+ )
160
+ return block
161
+
162
+
163
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
164
+ def _init_weights(module, initializer_range=0.02):
165
+ if isinstance(module, nn.Linear):
166
+ nn.init.normal_(module.weight, std=initializer_range)
167
+ if module.bias is not None:
168
+ nn.init.zeros_(module.bias)
169
+ elif isinstance(module, nn.Embedding):
170
+ nn.init.normal_(module.weight, std=initializer_range)
171
+ if module.padding_idx is not None:
172
+ nn.init.zeros_(module.weight[module.padding_idx])
173
+
174
+
175
+ class XLMRobertaEncoder(nn.Module):
176
+ def __init__(self, config: XLMRobertaFlashConfig):
177
+ super().__init__()
178
+ self.use_flash_attn = get_use_flash_attn(config)
179
+ self.layers = nn.ModuleList(
180
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
181
+ )
182
+ self._grad_checkpointing = False
183
+
184
+ @property
185
+ def gradient_checkpointing(self):
186
+ return self._grad_checkpointing
187
+
188
+ @gradient_checkpointing.setter
189
+ def gradient_checkpointing(self, value):
190
+ self._grad_checkpointing = value
191
+
192
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
193
+ """If subset_mask is not None, we only want output for the subset of the sequence.
194
+ This means that we only compute the last layer output for these tokens.
195
+ subset_mask: (batch, seqlen), dtype=torch.bool
196
+ """
197
+ if key_padding_mask is None or not self.use_flash_attn:
198
+ mixer_kwargs = (
199
+ {"key_padding_mask": key_padding_mask.bool()}
200
+ if key_padding_mask is not None
201
+ else None
202
+ )
203
+ for layer in self.layers:
204
+ if self._grad_checkpointing:
205
+ hidden_states = torch.utils.checkpoint.checkpoint(
206
+ layer,
207
+ hidden_states,
208
+ use_reentrant=False,
209
+ mixer_kwargs=mixer_kwargs,
210
+ )
211
+ else:
212
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
213
+ if subset_mask is not None:
214
+ hidden_states = hidden_states[subset_mask]
215
+ else:
216
+ batch, seqlen = hidden_states.shape[:2]
217
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
218
+ hidden_states, key_padding_mask
219
+ )
220
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
221
+ if subset_mask is None:
222
+ for layer in self.layers:
223
+ if self._grad_checkpointing:
224
+ hidden_states = torch.utils.checkpoint.checkpoint(
225
+ layer,
226
+ hidden_states,
227
+ use_reentrant=False,
228
+ mixer_kwargs=mixer_kwargs,
229
+ )
230
+ else:
231
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
232
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
233
+ else:
234
+ for layer in self.layers[:-1]:
235
+ if self._grad_checkpointing:
236
+ hidden_states = torch.utils.checkpoint.checkpoint(
237
+ layer,
238
+ hidden_states,
239
+ use_reentrant=False,
240
+ mixer_kwargs=mixer_kwargs,
241
+ )
242
+ else:
243
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
244
+ if key_padding_mask is not None:
245
+ subset_idx = torch.nonzero(
246
+ subset_mask[key_padding_mask], as_tuple=False
247
+ ).flatten()
248
+ subset_seqlens = (subset_mask & key_padding_mask).sum(
249
+ dim=-1, dtype=torch.int32
250
+ )
251
+ subset_cu_seqlens = F.pad(
252
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
253
+ (1, 0),
254
+ )
255
+ else:
256
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
257
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
258
+ subset_cu_seqlens = F.pad(
259
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
260
+ (1, 0),
261
+ )
262
+ hidden_states_subset, hidden_states = index_first_axis_residual(
263
+ hidden_states, subset_idx
264
+ )
265
+ # It's ok to set max_seqlen_q to be much larger
266
+ mixer_kwargs = {
267
+ "x_kv": hidden_states,
268
+ "cu_seqlens": subset_cu_seqlens,
269
+ "max_seqlen": max_seqlen_in_batch,
270
+ "cu_seqlens_k": cu_seqlens,
271
+ "max_seqlen_k": max_seqlen_in_batch,
272
+ }
273
+ if self._grad_checkpointing:
274
+ torch.utils.checkpoint.checkpoint(
275
+ self.layers[-1],
276
+ hidden_states_subset,
277
+ use_reentrant=False,
278
+ mixer_kwargs=mixer_kwargs,
279
+ )
280
+ else:
281
+ hidden_states = self.layers[-1](
282
+ hidden_states_subset, mixer_kwargs=mixer_kwargs
283
+ )
284
+ return hidden_states
285
+
286
+
287
+ class XLMRobertaPooler(nn.Module):
288
+ def __init__(self, config):
289
+ super().__init__()
290
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
291
+ if fused_bias_fc and FusedDense is None:
292
+ raise ImportError("fused_dense is not installed")
293
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
294
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
295
+ self.activation = nn.Tanh()
296
+
297
+ def forward(self, hidden_states, pool=True):
298
+ # We "pool" the model by simply taking the hidden state corresponding
299
+ # to the first token.
300
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
301
+ pooled_output = self.dense(first_token_tensor)
302
+ pooled_output = self.activation(pooled_output)
303
+ return pooled_output
304
+
305
+
306
+ class XLMRobertaPredictionHeadTransform(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
310
+ if fused_bias_fc and FusedDense is None:
311
+ raise ImportError("fused_dense is not installed")
312
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
313
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
314
+ raise ImportError("Triton is not installed")
315
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
316
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
317
+ approximate = (
318
+ "tanh"
319
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
320
+ else "none"
321
+ )
322
+ self.transform_act_fn = nn.GELU(approximate=approximate)
323
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
324
+
325
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
326
+ hidden_states = self.dense(hidden_states)
327
+ hidden_states = self.transform_act_fn(hidden_states)
328
+ if not self.fused_dropout_add_ln:
329
+ hidden_states = self.layer_norm(hidden_states)
330
+ else:
331
+ hidden_states = layer_norm_fn(
332
+ hidden_states,
333
+ self.layer_norm.weight,
334
+ self.layer_norm.bias,
335
+ eps=self.layer_norm.eps,
336
+ )
337
+ return hidden_states
338
+
339
+
340
+ class XLMRobertaLMPredictionHead(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
344
+ if fused_bias_fc and FusedDense is None:
345
+ raise ImportError("fused_dense is not installed")
346
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
347
+
348
+ self.transform = XLMRobertaPredictionHeadTransform(config)
349
+
350
+ # The output weights are the same as the input embeddings, but there is
351
+ # an output-only bias for each token.
352
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
353
+
354
+ def forward(self, hidden_states):
355
+ hidden_states = self.transform(hidden_states)
356
+ hidden_states = self.decoder(hidden_states)
357
+ return hidden_states
358
+
359
+
360
+ class XLMRobertaPreTrainingHeads(nn.Module):
361
+ def __init__(self, config):
362
+ super().__init__()
363
+ self.predictions = XLMRobertaLMPredictionHead(config)
364
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
365
+
366
+ def forward(self, sequence_output, pooled_output):
367
+ prediction_scores = self.predictions(sequence_output)
368
+ seq_relationship_score = self.seq_relationship(pooled_output)
369
+ return prediction_scores, seq_relationship_score
370
+
371
+
372
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
373
+ """An abstract class to handle weights initialization and
374
+ a simple interface for dowloading and loading pretrained models.
375
+ """
376
+
377
+ config_class = XLMRobertaFlashConfig
378
+ base_model_prefix = "roberta"
379
+ supports_gradient_checkpointing = True
380
+
381
+ def _set_gradient_checkpointing(self, module, value=False):
382
+ if isinstance(module, XLMRobertaEncoder):
383
+ module.gradient_checkpointing = value
384
+
385
+ @classmethod
386
+ def from_pretrained(
387
+ cls,
388
+ *args,
389
+ **kwargs,
390
+ ):
391
+ if not 'torch_dtype' in kwargs:
392
+ kwargs['torch_dtype'] = 'auto'
393
+ return super().from_pretrained(*args, **kwargs)
394
+
395
+
396
+
397
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
398
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
399
+ super().__init__(config)
400
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
401
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
402
+ config.vocab_size += self.pad_vocab_size_multiple - (
403
+ config.vocab_size % self.pad_vocab_size_multiple
404
+ )
405
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
406
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
407
+ raise ImportError("Triton is not installed")
408
+ assert config.hidden_act in [
409
+ "gelu",
410
+ "gelu_new",
411
+ "gelu_fast",
412
+ "gelu_pytorch_tanh",
413
+ ]
414
+
415
+ self.embeddings = XLMRobertaEmbeddings(
416
+ config.hidden_size,
417
+ config.vocab_size,
418
+ config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
419
+ config.type_vocab_size,
420
+ padding_idx=config.pad_token_id,
421
+ )
422
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
423
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+ self.encoder = XLMRobertaEncoder(config)
425
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
426
+
427
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
428
+
429
+
430
+ @torch.inference_mode()
431
+ def encode(
432
+ self: 'XLMRobertaModel',
433
+ sentences: Union[str, List[str]],
434
+ batch_size: int = 32,
435
+ show_progress_bar: Optional[bool] = None,
436
+ output_value: str = 'sentence_embedding',
437
+ convert_to_numpy: bool = True,
438
+ convert_to_tensor: bool = False,
439
+ device: Optional[torch.device] = None,
440
+ normalize_embeddings: bool = False,
441
+ truncate_dim: Optional[int] = None,
442
+ **tokenizer_kwargs,
443
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
444
+ """
445
+ Computes sentence embeddings
446
+ Args:
447
+ sentences(`str` or `List[str]`):
448
+ Sentence or sentences to be encoded
449
+ batch_size(`int`, *optional*, defaults to 32):
450
+ Batch size for the computation
451
+ show_progress_bar(`bool`, *optional*, defaults to None):
452
+ Show a progress bar when encoding sentences.
453
+ If set to None, progress bar is only shown when
454
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
455
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
456
+ Default sentence_embedding, to get sentence embeddings.
457
+ Can be set to token_embeddings to get wordpiece token embeddings.
458
+ Set to None, to get all output values
459
+ convert_to_numpy(`bool`, *optional*, defaults to True):
460
+ If true, the output is a list of numpy vectors.
461
+ Else, it is a list of pytorch tensors.
462
+ convert_to_tensor(`bool`, *optional*, defaults to False):
463
+ If true, you get one large tensor as return.
464
+ Overwrites any setting from convert_to_numpy
465
+ device(`torch.device`, *optional*, defaults to None):
466
+ Which torch.device to use for the computation
467
+ normalize_embeddings(`bool`, *optional*, defaults to False):
468
+ If set to true, returned vectors will have length 1. In that case, the
469
+ faster dot-product (util.dot_score) instead of cosine similarity can
470
+ be used.
471
+ truncate_dim(`int`, *optional*, defaults to None):
472
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
473
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
474
+ Keyword arguments for the tokenizer
475
+ Returns:
476
+ By default, a list of tensors is returned.
477
+ If convert_to_tensor, a stacked tensor is returned.
478
+ If convert_to_numpy, a numpy matrix is returned.
479
+ """
480
+ from transformers import AutoTokenizer
481
+
482
+ self.tokenizer = AutoTokenizer.from_pretrained(
483
+ self.name_or_path, trust_remote_code=True
484
+ )
485
+
486
+ is_training = self.training
487
+ self.eval()
488
+
489
+ if show_progress_bar is None:
490
+ show_progress_bar = (
491
+ logger.getEffectiveLevel() == logging.INFO
492
+ or logger.getEffectiveLevel() == logging.DEBUG
493
+ )
494
+
495
+ if convert_to_tensor:
496
+ convert_to_numpy = False
497
+
498
+ if output_value != 'sentence_embedding':
499
+ convert_to_tensor = False
500
+ convert_to_numpy = False
501
+
502
+ input_was_string = False
503
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
504
+ sentences = [sentences]
505
+ input_was_string = True
506
+
507
+ if device is not None:
508
+ self.to(device)
509
+
510
+ permutation = np.argsort([-len(i) for i in sentences])
511
+ inverse_permutation = np.argsort(permutation)
512
+ sentences = [sentences[idx] for idx in permutation]
513
+
514
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
515
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
516
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
517
+ )
518
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
519
+
520
+ all_embeddings = []
521
+
522
+ if trange is not None:
523
+ range_iter = trange(
524
+ 0,
525
+ len(sentences),
526
+ batch_size,
527
+ desc="Encoding",
528
+ disable=not show_progress_bar,
529
+ )
530
+ else:
531
+ range_iter = range(0, len(sentences), batch_size)
532
+
533
+ for i in range_iter:
534
+ encoded_input = self.tokenizer(
535
+ sentences[i : i + batch_size],
536
+ return_tensors='pt',
537
+ **tokenizer_kwargs,
538
+ ).to(self.device)
539
+ token_embs = self.forward(**encoded_input)[0]
540
+
541
+ # Accumulate in fp32 to avoid overflow
542
+ token_embs = token_embs.float()
543
+
544
+ if output_value == 'token_embeddings':
545
+ raise NotImplementedError
546
+ elif output_value is None:
547
+ raise NotImplementedError
548
+ else:
549
+ if self.config.emb_pooler == 'cls':
550
+ embeddings = self.cls_pooling(
551
+ token_embs, encoded_input['attention_mask']
552
+ )
553
+ else:
554
+ embeddings = self.mean_pooling(
555
+ token_embs, encoded_input['attention_mask']
556
+ )
557
+
558
+ if normalize_embeddings:
559
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
560
+
561
+ if convert_to_numpy:
562
+ embeddings = embeddings.cpu()
563
+ all_embeddings.extend(embeddings)
564
+
565
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
566
+
567
+ truncate_dim = truncate_dim or self.config.truncate_dim
568
+ if truncate_dim:
569
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
570
+
571
+ if convert_to_tensor:
572
+ all_embeddings = torch.stack(all_embeddings)
573
+ elif convert_to_numpy:
574
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
575
+
576
+ if input_was_string:
577
+ all_embeddings = all_embeddings[0]
578
+
579
+ self.train(is_training)
580
+ return all_embeddings
581
+
582
+
583
+ def truncate_embeddings(self, embeddings, truncate_dim):
584
+ if not self.config.matryoshka_dimensions:
585
+ logger.warning(
586
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
587
+ )
588
+ return embeddings
589
+ elif truncate_dim in self.config.matryoshka_dimensions:
590
+ return [tensor[:truncate_dim] for tensor in embeddings]
591
+ else:
592
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
593
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
594
+
595
+ def mean_pooling(
596
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
597
+ ):
598
+ input_mask_expanded = (
599
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
600
+ )
601
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
602
+ input_mask_expanded.sum(1), min=1e-9
603
+ )
604
+
605
+
606
+ def cls_pooling(
607
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
608
+ ):
609
+ return token_embeddings[:,0]
610
+
611
+
612
+ def forward(
613
+ self,
614
+ input_ids,
615
+ position_ids=None,
616
+ token_type_ids=None,
617
+ attention_mask=None,
618
+ masked_tokens_mask=None,
619
+ return_dict=None,
620
+ **kwargs,
621
+ ):
622
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
623
+ we only want the output for the masked tokens. This means that we only compute the last
624
+ layer output for these tokens.
625
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
626
+ """
627
+
628
+ if kwargs:
629
+ for key, value in kwargs.items():
630
+ if value is not None:
631
+ logger.warning(
632
+ 'Flash attention implementation does not support kwargs: %s',
633
+ key,
634
+ )
635
+
636
+ return_dict = (
637
+ return_dict if return_dict is not None else self.config.use_return_dict
638
+ )
639
+
640
+ hidden_states = self.embeddings(
641
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
642
+ )
643
+ # TD [2022-12:18]: Don't need to force residual in fp32
644
+ # BERT puts embedding LayerNorm before embedding dropout.
645
+ if not self.fused_dropout_add_ln:
646
+ hidden_states = self.emb_ln(hidden_states)
647
+ else:
648
+ hidden_states = layer_norm_fn(
649
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
650
+ )
651
+ hidden_states = self.emb_drop(hidden_states)
652
+
653
+ if masked_tokens_mask is not None:
654
+ batch_size, seqlen = input_ids.shape[:2]
655
+ # We also need the first column for the CLS token
656
+ first_col_mask = torch.zeros(
657
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
658
+ )
659
+ first_col_mask[:, 0] = True
660
+ subset_mask = masked_tokens_mask | first_col_mask
661
+ else:
662
+ subset_mask = None
663
+
664
+ sequence_output = self.encoder(
665
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
666
+ )
667
+
668
+ if masked_tokens_mask is None:
669
+ pooled_output = (
670
+ self.pooler(sequence_output) if self.pooler is not None else None
671
+ )
672
+ else:
673
+ # TD [2022-03-01]: the indexing here is very tricky.
674
+ if attention_mask is not None:
675
+ subset_idx = subset_mask[attention_mask]
676
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
677
+ sequence_output = sequence_output[
678
+ masked_tokens_mask[attention_mask][subset_idx]
679
+ ]
680
+ else:
681
+ pool_input = sequence_output[first_col_mask[subset_mask]]
682
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
683
+ pooled_output = (
684
+ self.pooler(pool_input, pool=False) if self.pooler is not None else None
685
+ )
686
+
687
+ if not return_dict:
688
+ return sequence_output, pooled_output
689
+
690
+ return BaseModelOutputWithPoolingAndCrossAttentions(
691
+ last_hidden_state=sequence_output,
692
+ pooler_output=pooled_output,
693
+ )
694
+
695
+
696
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
697
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
698
+
699
+ def __init__(self, config):
700
+ super().__init__(config)
701
+
702
+ if config.is_decoder:
703
+ logger.warning(
704
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
705
+ "bi-directional self-attention."
706
+ )
707
+
708
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
709
+ self.lm_head = XLMRobertaLMHead(config)
710
+
711
+ # Initialize weights and apply final processing
712
+ self.post_init()
713
+
714
+ def get_input_embeddings(self):
715
+ return self.roberta.embeddings.word_embeddings
716
+
717
+ def get_output_embeddings(self):
718
+ return self.lm_head.decoder
719
+
720
+ def set_output_embeddings(self, new_embeddings):
721
+ self.lm_head.decoder = new_embeddings
722
+
723
+ def forward(
724
+ self,
725
+ input_ids: Optional[torch.LongTensor] = None,
726
+ attention_mask: Optional[torch.FloatTensor] = None,
727
+ token_type_ids: Optional[torch.LongTensor] = None,
728
+ position_ids: Optional[torch.LongTensor] = None,
729
+ head_mask: Optional[torch.FloatTensor] = None,
730
+ inputs_embeds: Optional[torch.FloatTensor] = None,
731
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
732
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
733
+ labels: Optional[torch.LongTensor] = None,
734
+ output_attentions: Optional[bool] = None,
735
+ output_hidden_states: Optional[bool] = None,
736
+ return_dict: Optional[bool] = None,
737
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
738
+ r"""
739
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
740
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
741
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
742
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
743
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
744
+ Used to hide legacy arguments that have been deprecated.
745
+ """
746
+ return_dict = (
747
+ return_dict if return_dict is not None else self.config.use_return_dict
748
+ )
749
+
750
+ outputs = self.roberta(
751
+ input_ids,
752
+ attention_mask=attention_mask,
753
+ token_type_ids=token_type_ids,
754
+ position_ids=position_ids,
755
+ head_mask=head_mask,
756
+ inputs_embeds=inputs_embeds,
757
+ encoder_hidden_states=encoder_hidden_states,
758
+ encoder_attention_mask=encoder_attention_mask,
759
+ output_attentions=output_attentions,
760
+ output_hidden_states=output_hidden_states,
761
+ return_dict=return_dict,
762
+ )
763
+ sequence_output = outputs[0]
764
+ prediction_scores = self.lm_head(sequence_output)
765
+
766
+ masked_lm_loss = None
767
+ if labels is not None:
768
+ # move labels to correct device to enable model parallelism
769
+ labels = labels.to(prediction_scores.device)
770
+ loss_fct = CrossEntropyLoss()
771
+ masked_lm_loss = loss_fct(
772
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
773
+ )
774
+
775
+ if not return_dict:
776
+ output = (prediction_scores,) + outputs[2:]
777
+ return (
778
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
779
+ )
780
+
781
+ return MaskedLMOutput(
782
+ loss=masked_lm_loss,
783
+ logits=prediction_scores,
784
+ hidden_states=outputs.hidden_states,
785
+ attentions=outputs.attentions,
786
+ )
787
+
788
+
789
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
790
+ class XLMRobertaClassificationHead(nn.Module):
791
+ """Head for sentence-level classification tasks."""
792
+
793
+ def __init__(self, config):
794
+ super().__init__()
795
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
796
+ if fused_bias_fc and FusedDense is None:
797
+ raise ImportError("fused_dense is not installed")
798
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
799
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
800
+ classifier_dropout = (
801
+ config.classifier_dropout
802
+ if config.classifier_dropout is not None
803
+ else config.hidden_dropout_prob
804
+ )
805
+ self.dropout = nn.Dropout(classifier_dropout)
806
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
807
+
808
+ def forward(self, features, **kwargs):
809
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
810
+ x = self.dropout(x)
811
+ x = self.dense(x)
812
+ x = torch.tanh(x)
813
+ x = self.dropout(x)
814
+ x = self.out_proj(x)
815
+ return x
816
+
817
+
818
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
819
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
820
+ def __init__(self, config):
821
+ super().__init__(config)
822
+ self.num_labels = config.num_labels
823
+ self.config = config
824
+
825
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
826
+ self.classifier = XLMRobertaClassificationHead(config)
827
+
828
+ # Initialize weights and apply final processing
829
+ self.post_init()
830
+
831
+ def forward(
832
+ self,
833
+ input_ids: Optional[torch.LongTensor] = None,
834
+ attention_mask: Optional[torch.FloatTensor] = None,
835
+ token_type_ids: Optional[torch.LongTensor] = None,
836
+ position_ids: Optional[torch.LongTensor] = None,
837
+ head_mask: Optional[torch.FloatTensor] = None,
838
+ inputs_embeds: Optional[torch.FloatTensor] = None,
839
+ labels: Optional[torch.LongTensor] = None,
840
+ output_attentions: Optional[bool] = None,
841
+ output_hidden_states: Optional[bool] = None,
842
+ return_dict: Optional[bool] = None,
843
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
844
+ r"""
845
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
846
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
847
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
+ """
850
+ return_dict = (
851
+ return_dict if return_dict is not None else self.config.use_return_dict
852
+ )
853
+
854
+ outputs = self.roberta(
855
+ input_ids,
856
+ attention_mask=attention_mask,
857
+ token_type_ids=token_type_ids,
858
+ position_ids=position_ids,
859
+ head_mask=head_mask,
860
+ inputs_embeds=inputs_embeds,
861
+ output_attentions=output_attentions,
862
+ output_hidden_states=output_hidden_states,
863
+ return_dict=return_dict,
864
+ )
865
+ sequence_output = outputs[0]
866
+ logits = self.classifier(sequence_output)
867
+
868
+ loss = None
869
+ if labels is not None:
870
+ # move labels to correct device to enable model parallelism
871
+ labels = labels.to(logits.device)
872
+ if self.config.problem_type is None:
873
+ if self.num_labels == 1:
874
+ self.config.problem_type = "regression"
875
+ elif self.num_labels > 1 and (
876
+ labels.dtype == torch.long or labels.dtype == torch.int
877
+ ):
878
+ self.config.problem_type = "single_label_classification"
879
+ else:
880
+ self.config.problem_type = "multi_label_classification"
881
+
882
+ if self.config.problem_type == "regression":
883
+ loss_fct = MSELoss()
884
+ if self.num_labels == 1:
885
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
886
+ else:
887
+ loss = loss_fct(logits, labels)
888
+ elif self.config.problem_type == "single_label_classification":
889
+ loss_fct = CrossEntropyLoss()
890
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
891
+ elif self.config.problem_type == "multi_label_classification":
892
+ loss_fct = BCEWithLogitsLoss()
893
+ loss = loss_fct(logits, labels)
894
+
895
+ if not return_dict:
896
+ output = (logits,) + outputs[2:]
897
+ return ((loss,) + output) if loss is not None else output
898
+
899
+ return SequenceClassifierOutput(
900
+ loss=loss,
901
+ logits=logits,
902
+ hidden_states=outputs.hidden_states,
903
+ attentions=outputs.attentions,
904
+ )
905
+
906
+
907
+ @torch.inference_mode()
908
+ def compute_score(
909
+ self,
910
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
911
+ batch_size: int = 32,
912
+ max_length: Optional[int] = None,
913
+ ) -> List[float]:
914
+
915
+ if not hasattr(self, "_tokenizer"):
916
+ from transformers import AutoTokenizer
917
+
918
+ self._tokenizer = AutoTokenizer.from_pretrained(
919
+ self.name_or_path, trust_remote_code=True
920
+ )
921
+
922
+ assert isinstance(sentence_pairs, list)
923
+ if isinstance(sentence_pairs[0], str):
924
+ sentence_pairs = [sentence_pairs]
925
+
926
+ all_scores = []
927
+ for start_index in range(
928
+ 0, len(sentence_pairs), batch_size
929
+ ):
930
+ sentences_batch = sentence_pairs[
931
+ start_index : start_index + batch_size
932
+ ]
933
+ inputs = self._tokenizer(
934
+ sentences_batch,
935
+ padding=True,
936
+ truncation=True,
937
+ return_tensors='pt',
938
+ max_length=max_length,
939
+ ).to(self.device)
940
+ scores = (
941
+ self.forward(**inputs, return_dict=True)
942
+ .logits.view(
943
+ -1,
944
+ )
945
+ .float()
946
+ )
947
+ scores = torch.sigmoid(scores)
948
+ all_scores.extend(scores.cpu().numpy().tolist())
949
+
950
+ if len(all_scores) == 1:
951
+ return all_scores[0]
952
+ return all_scores
953
+
954
+ def predict(
955
+ self,
956
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
957
+ batch_size: int = 32,
958
+ max_length: Optional[int] = None,
959
+ ) -> List[float]:
960
+ # used for beir evaluation
961
+ return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
962
+
963
+ def rerank(
964
+ self,
965
+ query: str,
966
+ documents: List[str],
967
+ batch_size: int = 32,
968
+ max_length: int = 1024,
969
+ max_query_length: int = 512,
970
+ overlap_tokens: int = 80,
971
+ top_n: Optional[int] = None,
972
+ **kwargs,
973
+ ):
974
+ assert max_length >= max_query_length * 2, (
975
+ f'max_length ({max_length}) must be greater than or equal to '
976
+ f'max_query_length ({max_query_length}) * 2'
977
+ )
978
+
979
+ if not hasattr(self, "_tokenizer"):
980
+ from transformers import AutoTokenizer
981
+
982
+ self._tokenizer = AutoTokenizer.from_pretrained(
983
+ self.name_or_path, trust_remote_code=True
984
+ )
985
+
986
+ # preproc of tokenization
987
+ sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
988
+ query,
989
+ documents,
990
+ tokenizer=self._tokenizer,
991
+ max_length=max_length,
992
+ max_query_length=max_query_length,
993
+ overlap_tokens=overlap_tokens,
994
+ )
995
+
996
+ tot_scores = []
997
+ with torch.no_grad():
998
+ for k in range(0, len(sentence_pairs), batch_size):
999
+ batch = self._tokenizer.pad(
1000
+ sentence_pairs[k : k + batch_size],
1001
+ padding=True,
1002
+ max_length=max_length,
1003
+ pad_to_multiple_of=None,
1004
+ return_tensors="pt",
1005
+ )
1006
+ batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
1007
+ scores = (
1008
+ self.forward(**batch_on_device, return_dict=True)
1009
+ .logits.view(
1010
+ -1,
1011
+ )
1012
+ .float()
1013
+ )
1014
+ scores = torch.sigmoid(scores)
1015
+ tot_scores.extend(scores.cpu().numpy().tolist())
1016
+
1017
+ # ranking
1018
+ merge_scores = [0 for _ in range(len(documents))]
1019
+ for pid, score in zip(sentence_pairs_pids, tot_scores):
1020
+ merge_scores[pid] = max(merge_scores[pid], score)
1021
+
1022
+ merge_scores_argsort = np.argsort(merge_scores)[::-1]
1023
+ sorted_documents = []
1024
+ sorted_scores = []
1025
+ for mid in merge_scores_argsort:
1026
+ sorted_scores.append(merge_scores[mid])
1027
+ sorted_documents.append(documents[mid])
1028
+
1029
+ top_n = min(top_n or len(sorted_documents), len(sorted_documents))
1030
+
1031
+ return [
1032
+ {
1033
+ 'document': sorted_documents[i],
1034
+ 'relevance_score': sorted_scores[i],
1035
+ 'index': merge_scores_argsort[i],
1036
+ }
1037
+ for i in range(top_n)
1038
+ ]
1039
+
1040
+
1041
+ def reranker_tokenize_preproc(
1042
+ query: str,
1043
+ passages: List[str],
1044
+ tokenizer=None,
1045
+ max_length: int = 1024,
1046
+ max_query_length: int = 512,
1047
+ overlap_tokens: int = 80,
1048
+ ):
1049
+ from copy import deepcopy
1050
+
1051
+ assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
1052
+ sep_id = tokenizer.sep_token_id
1053
+
1054
+ def _merge_inputs(chunk1_raw, chunk2):
1055
+ chunk1 = deepcopy(chunk1_raw)
1056
+ chunk1['input_ids'].append(sep_id)
1057
+ chunk1['input_ids'].extend(chunk2['input_ids'])
1058
+ chunk1['input_ids'].append(sep_id)
1059
+ chunk1['attention_mask'].append(chunk2['attention_mask'][0])
1060
+ chunk1['attention_mask'].extend(chunk2['attention_mask'])
1061
+ chunk1['attention_mask'].append(chunk2['attention_mask'][-1])
1062
+ if 'token_type_ids' in chunk1:
1063
+ token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
1064
+ chunk1['token_type_ids'].extend(token_type_ids)
1065
+ return chunk1
1066
+
1067
+ # Note: the long query will be truncated to 256 tokens by default
1068
+ query_inputs = tokenizer.encode_plus(
1069
+ query, truncation=True, padding=False, max_length=max_query_length
1070
+ )
1071
+
1072
+ max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
1073
+ # assert (
1074
+ # max_passage_inputs_length > 100
1075
+ # ), "Your query is too long! Please make sure your query less than 500 tokens!"
1076
+
1077
+ overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
1078
+
1079
+ res_merge_inputs = []
1080
+ res_merge_inputs_pids = []
1081
+ for pid, passage in enumerate(passages):
1082
+ passage_inputs = tokenizer.encode_plus(
1083
+ passage,
1084
+ truncation=False,
1085
+ padding=False,
1086
+ add_special_tokens=False,
1087
+ max_length=0,
1088
+ )
1089
+ passage_inputs_length = len(passage_inputs['input_ids'])
1090
+
1091
+ if passage_inputs_length <= max_passage_inputs_length:
1092
+ qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
1093
+ res_merge_inputs.append(qp_merge_inputs)
1094
+ res_merge_inputs_pids.append(pid)
1095
+ else:
1096
+ start_id = 0
1097
+ while start_id < passage_inputs_length:
1098
+ end_id = start_id + max_passage_inputs_length
1099
+ # make sure the length of the last chunk is `max_passage_inputs_length`
1100
+ if end_id >= passage_inputs_length:
1101
+ sub_passage_inputs = {
1102
+ k: v[-max_passage_inputs_length:]
1103
+ for k, v in passage_inputs.items()
1104
+ }
1105
+ else:
1106
+ sub_passage_inputs = {
1107
+ k: v[start_id:end_id] for k, v in passage_inputs.items()
1108
+ }
1109
+ start_id = (
1110
+ end_id - overlap_tokens_implt
1111
+ if end_id < passage_inputs_length
1112
+ else end_id
1113
+ )
1114
+
1115
+ qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
1116
+ res_merge_inputs.append(qp_merge_inputs)
1117
+ res_merge_inputs_pids.append(pid)
1118
+
1119
+ return res_merge_inputs, res_merge_inputs_pids
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
stochastic_depth.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation modified from torchvision:
2
+ # https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
3
+ #
4
+ # License:
5
+ # BSD 3-Clause License
6
+ #
7
+ # Copyright (c) Soumith Chintala 2016,
8
+ # All rights reserved.
9
+ #
10
+ # Redistribution and use in source and binary forms, with or without
11
+ # modification, are permitted provided that the following conditions are met:
12
+ #
13
+ # * Redistributions of source code must retain the above copyright notice, this
14
+ # list of conditions and the following disclaimer.
15
+ #
16
+ # * Redistributions in binary form must reproduce the above copyright notice,
17
+ # this list of conditions and the following disclaimer in the documentation
18
+ # and/or other materials provided with the distribution.
19
+ #
20
+ # * Neither the name of the copyright holder nor the names of its
21
+ # contributors may be used to endorse or promote products derived from
22
+ # this software without specific prior written permission.
23
+ #
24
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
+
35
+ import torch
36
+ import torch.fx
37
+ from torch import nn, Tensor
38
+
39
+
40
+ def stochastic_depth(
41
+ input: Tensor, p: float, mode: str, training: bool = True
42
+ ) -> Tensor:
43
+ """
44
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
45
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
46
+ branches of residual architectures.
47
+
48
+ Args:
49
+ input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
50
+ being its batch i.e. a batch with ``N`` rows.
51
+ p (float): probability of the input to be zeroed.
52
+ mode (str): ``"batch"`` or ``"row"``.
53
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
54
+ randomly selected rows from the batch.
55
+ training: apply stochastic depth if is ``True``. Default: ``True``
56
+
57
+ Returns:
58
+ Tensor[N, ...]: The randomly zeroed tensor.
59
+ """
60
+ if p < 0.0 or p > 1.0:
61
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
62
+ if mode not in ["batch", "row"]:
63
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
64
+ if not training or p == 0.0:
65
+ return input
66
+
67
+ survival_rate = 1.0 - p
68
+ if mode == "row":
69
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
70
+ else:
71
+ size = [1] * input.ndim
72
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
73
+ noise = noise.bernoulli_(survival_rate)
74
+ if survival_rate > 0.0:
75
+ noise.div_(survival_rate)
76
+ return input * noise
77
+
78
+
79
+ torch.fx.wrap("stochastic_depth")
80
+
81
+
82
+ class StochasticDepth(nn.Module):
83
+ """
84
+ See :func:`stochastic_depth`.
85
+ """
86
+
87
+ def __init__(self, p: float, mode: str) -> None:
88
+ super().__init__()
89
+ self.p = p
90
+ self.mode = mode
91
+
92
+ def forward(self, input: Tensor) -> Tensor:
93
+ return stochastic_depth(input, self.p, self.mode, self.training)
94
+
95
+ def __repr__(self) -> str:
96
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
97
+ return s
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "46466": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "mask_token": "<mask>",
49
+ "model_max_length": 1024,
50
+ "pad_token": "<pad>",
51
+ "sep_token": "</s>",
52
+ "tokenizer_class": "XLMRobertaTokenizer",
53
+ "unk_token": "<unk>"
54
+ }
xlm_padding.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
3
+
4
+ # Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
+ class IndexFirstAxis(torch.autograd.Function):
12
+ @staticmethod
13
+ def forward(ctx, input, indices):
14
+ ctx.save_for_backward(indices)
15
+ assert input.ndim >= 2
16
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
17
+ second_dim = other_shape.numel()
18
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
+ # return input[indices]
20
+ return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
22
+ ).reshape(-1, *other_shape)
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad_output):
26
+ (indices,) = ctx.saved_tensors
27
+ assert grad_output.ndim >= 2
28
+ other_shape = grad_output.shape[1:]
29
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
30
+ grad_input = torch.zeros(
31
+ [ctx.first_axis_dim, grad_output.shape[1]],
32
+ device=grad_output.device,
33
+ dtype=grad_output.dtype,
34
+ )
35
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
36
+ # grad_input[indices] = grad_output
37
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
38
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
39
+
40
+
41
+ index_first_axis = IndexFirstAxis.apply
42
+
43
+
44
+ class IndexPutFirstAxis(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(ctx, values, indices, first_axis_dim):
47
+ ctx.save_for_backward(indices)
48
+ assert indices.ndim == 1
49
+ assert values.ndim >= 2
50
+ output = torch.zeros(
51
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
52
+ )
53
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
54
+ output[indices] = values
55
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ (indices,) = ctx.saved_tensors
61
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
62
+ grad_values = grad_output[indices]
63
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
64
+ return grad_values, None, None
65
+
66
+
67
+ index_put_first_axis = IndexPutFirstAxis.apply
68
+
69
+
70
+ class IndexFirstAxisResidual(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, input, indices):
73
+ ctx.save_for_backward(indices)
74
+ assert input.ndim >= 2
75
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
76
+ second_dim = other_shape.numel()
77
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
78
+ output = input[indices]
79
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
80
+ # memory format to channel_first. In other words, input might not be contiguous.
81
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
82
+ return output, input.detach()
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output, grad_residual):
86
+ (indices,) = ctx.saved_tensors
87
+ assert grad_output.ndim >= 2
88
+ other_shape = grad_output.shape[1:]
89
+ assert grad_residual.shape[1:] == other_shape
90
+ grad_input = grad_residual
91
+ # grad_input[indices] += grad_output
92
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
93
+ indices = indices.expand_as(grad_output)
94
+ grad_input.scatter_add_(0, indices, grad_output)
95
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
96
+
97
+
98
+ index_first_axis_residual = IndexFirstAxisResidual.apply
99
+
100
+
101
+ def unpad_input(hidden_states, attention_mask):
102
+ """
103
+ Arguments:
104
+ hidden_states: (batch, seqlen, ...)
105
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
106
+ Return:
107
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
108
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
109
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
110
+ max_seqlen_in_batch: int
111
+ """
112
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
113
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
115
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
118
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
119
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
120
+ # so we write custom forward and backward to make it a bit faster.
121
+ return (
122
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
123
+ indices,
124
+ cu_seqlens,
125
+ max_seqlen_in_batch,
126
+ )
127
+
128
+
129
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
130
+ """
131
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
132
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
133
+
134
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
135
+ ```
136
+ [
137
+ [2, 3, 0, 0, 0, 0],
138
+ [3, 2, 0, 0, 0, 0],
139
+ [6, 0, 0, 0, 0, 0]
140
+ ]
141
+ ```
142
+ , which refers to the 3D-attention mask:
143
+ ```
144
+ [
145
+ [
146
+ [1, 0, 0, 0, 0, 0],
147
+ [1, 1, 0, 0, 0, 0],
148
+ [0, 0, 1, 0, 0, 0],
149
+ [0, 0, 1, 1, 0, 0],
150
+ [0, 0, 1, 1, 1, 0],
151
+ [0, 0, 0, 0, 0, 1]
152
+ ],
153
+ [
154
+ [1, 0, 0, 0, 0, 0],
155
+ [1, 1, 0, 0, 0, 0],
156
+ [1, 1, 1, 0, 0, 0],
157
+ [0, 0, 0, 1, 0, 0],
158
+ [0, 0, 0, 1, 1, 0],
159
+ [0, 0, 0, 0, 0, 1]
160
+ ],
161
+ [
162
+ [1, 0, 0, 0, 0, 0],
163
+ [1, 1, 0, 0, 0, 0],
164
+ [1, 1, 1, 0, 0, 0],
165
+ [1, 1, 1, 1, 0, 0],
166
+ [1, 1, 1, 1, 1, 0],
167
+ [1, 1, 1, 1, 1, 1]
168
+ ]
169
+ ]
170
+ ```.
171
+
172
+ Arguments:
173
+ hidden_states: (batch, seqlen, ...)
174
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
175
+ Return:
176
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
177
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
178
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
179
+ max_seqlen_in_batch: int
180
+ """
181
+ length = attention_mask_in_length.sum(dim=-1)
182
+ seqlen = attention_mask_in_length.size(-1)
183
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
184
+ seqlen) < length.unsqueeze(
185
+ 1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
187
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
191
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
193
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
194
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
195
+ # so we write custom forward and backward to make it a bit faster.
196
+ return (
197
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
198
+ indices,
199
+ cu_seqlens,
200
+ max_seqlen_in_batch,
201
+ )
202
+
203
+
204
+ def pad_input(hidden_states, indices, batch, seqlen):
205
+ """
206
+ Arguments:
207
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
208
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
209
+ batch: int, batch size for the padded sequence.
210
+ seqlen: int, maximum sequence length for the padded sequence.
211
+ Return:
212
+ hidden_states: (batch, seqlen, ...)
213
+ """
214
+ dim = hidden_states.shape[-1]
215
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
+ # output[indices] = hidden_states
217
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)