Text2Text Generation
Transformers
PyTorch
French
mt5
Inference Endpoints
nreimers commited on
Commit
f77064e
1 Parent(s): f59449e
README.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: fr
3
+ datasets:
4
+ - unicamp-dl/mmarco
5
+ widget:
6
+ - text: "Python (prononcé /pi.tɔ̃/) est un langage de programmation interprété, multi-paradigme et multiplateformes. Il favorise la programmation impérative structurée, fonctionnelle et orientée objet. Il est doté d'un typage dynamique fort, d'une gestion automatique de la mémoire par ramasse-miettes et d'un système de gestion d'exceptions ; il est ainsi similaire à Perl, Ruby, Scheme, Smalltalk et Tcl."
7
+
8
+ license: apache-2.0
9
+ ---
10
+
11
+ # doc2query/msmarco-french-mt5-base-v1
12
+
13
+ This is a [doc2query](https://arxiv.org/abs/1904.08375) model based on mT5 (also known as [docT5query](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf)).
14
+
15
+ It can be used for:
16
+ - **Document expansion**: You generate for your paragraphs 20-40 queries and index the paragraphs and the generates queries in a standard BM25 index like Elasticsearch, OpenSearch, or Lucene. The generated queries help to close the lexical gap of lexical search, as the generate queries contain synonyms. Further, it re-weights words giving important words a higher weight even if they appear seldomn in a paragraph. In our [BEIR](https://arxiv.org/abs/2104.08663) paper we showed that BM25+docT5query is a powerful search engine. In the [BEIR repository](https://github.com/beir-cellar/beir) we have an example how to use docT5query with Pyserini.
17
+ - **Domain Specific Training Data Generation**: It can be used to generate training data to learn an embedding model. In our [GPL-Paper](https://arxiv.org/abs/2112.07577) / [GPL Example on SBERT.net](https://www.sbert.net/examples/domain_adaptation/README.html#gpl-generative-pseudo-labeling) we have an example how to use the model to generate (query, text) pairs for a given collection of unlabeled texts. These pairs can then be used to train powerful dense embedding models.
18
+
19
+ ## Usage
20
+ ```python
21
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
22
+ import torch
23
+
24
+ model_name = 'doc2query/msmarco-french-mt5-base-v1'
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
27
+
28
+ text = "Python (prononcé /pi.tɔ̃/) est un langage de programmation interprété, multi-paradigme et multiplateformes. Il favorise la programmation impérative structurée, fonctionnelle et orientée objet. Il est doté d'un typage dynamique fort, d'une gestion automatique de la mémoire par ramasse-miettes et d'un système de gestion d'exceptions ; il est ainsi similaire à Perl, Ruby, Scheme, Smalltalk et Tcl."
29
+
30
+
31
+ def create_queries(para):
32
+ input_ids = tokenizer.encode(para, return_tensors='pt')
33
+ with torch.no_grad():
34
+ # Here we use top_k / top_k random sampling. It generates more diverse queries, but of lower quality
35
+ sampling_outputs = model.generate(
36
+ input_ids=input_ids,
37
+ max_length=64,
38
+ do_sample=True,
39
+ top_p=0.95,
40
+ top_k=10,
41
+ num_return_sequences=5
42
+ )
43
+
44
+ # Here we use Beam-search. It generates better quality queries, but with less diversity
45
+ beam_outputs = model.generate(
46
+ input_ids=input_ids,
47
+ max_length=64,
48
+ num_beams=5,
49
+ no_repeat_ngram_size=2,
50
+ num_return_sequences=5,
51
+ early_stopping=True
52
+ )
53
+
54
+
55
+ print("Paragraph:")
56
+ print(para)
57
+
58
+ print("\nBeam Outputs:")
59
+ for i in range(len(beam_outputs)):
60
+ query = tokenizer.decode(beam_outputs[i], skip_special_tokens=True)
61
+ print(f'{i + 1}: {query}')
62
+
63
+ print("\nSampling Outputs:")
64
+ for i in range(len(sampling_outputs)):
65
+ query = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
66
+ print(f'{i + 1}: {query}')
67
+
68
+ create_queries(text)
69
+
70
+ ```
71
+
72
+ **Note:** `model.generate()` is non-deterministic for top_k/top_n sampling. It produces different queries each time you run it.
73
+
74
+ ## Training
75
+ This model fine-tuned [google/mt5-base](https://huggingface.co/google/mt5-base) for 66k training steps (4 epochs on the 500k training pairs from MS MARCO). For the training script, see the `train_script.py` in this repository.
76
+
77
+ The input-text was truncated to 320 word pieces. Output text was generated up to 64 word pieces.
78
+
79
+ This model was trained on a (query, passage) from the [mMARCO dataset](https://github.com/unicamp-dl/mMARCO).
80
+
81
+
82
+
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/mt5-base",
3
+ "architectures": [
4
+ "MT5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 2048,
7
+ "d_kv": 64,
8
+ "d_model": 768,
9
+ "decoder_start_token_id": 0,
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "initializer_factor": 1.0,
14
+ "is_encoder_decoder": true,
15
+ "layer_norm_epsilon": 1e-06,
16
+ "model_type": "mt5",
17
+ "num_decoder_layers": 12,
18
+ "num_heads": 12,
19
+ "num_layers": 12,
20
+ "output_past": true,
21
+ "pad_token_id": 0,
22
+ "relative_attention_max_distance": 128,
23
+ "relative_attention_num_buckets": 32,
24
+ "tie_word_embeddings": false,
25
+ "tokenizer_class": "T5Tokenizer",
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.18.0",
28
+ "use_cache": true,
29
+ "vocab_size": 250112
30
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:665884df373a99a1269420c437a8d34c645724c4ac0471fabae45e3ff254aa28
3
+ size 2329700301
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
3
+ size 4309802
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d3fca0dbb3a53bc1eddfc2e47ef441d7a94a70879e6750baddab04441a78305
3
+ size 16330621
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "/home/patrick/.cache/torch/transformers/685ac0ca8568ec593a48b61b0a3c272beee9bc194a3c7241d15dcadb5f875e53.f76030f3ec1b96a8199b2593390c610e76ca8028ef3d24680000619ffb646276", "name_or_path": "google/mt5-base", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}
train_script.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ from torch.utils.data import Dataset, IterableDataset
4
+ import gzip
5
+ import json
6
+ from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
7
+ import sys
8
+ from datetime import datetime
9
+ import torch
10
+ import random
11
+ from shutil import copyfile
12
+ import os
13
+ import wandb
14
+ import random
15
+ import re
16
+ from datasets import load_dataset
17
+ import tqdm
18
+
19
+
20
+ logging.basicConfig(
21
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
22
+ datefmt="%Y-%m-%d %H:%M:%S",
23
+ handlers=[logging.StreamHandler(sys.stdout)],
24
+ )
25
+
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--lang", required=True)
28
+ parser.add_argument("--model_name", default="google/mt5-base")
29
+ parser.add_argument("--epochs", default=4, type=int)
30
+ parser.add_argument("--batch_size", default=32, type=int)
31
+ parser.add_argument("--max_source_length", default=320, type=int)
32
+ parser.add_argument("--max_target_length", default=64, type=int)
33
+ parser.add_argument("--eval_size", default=1000, type=int)
34
+ #parser.add_argument("--fp16", default=False, action='store_true')
35
+ args = parser.parse_args()
36
+
37
+ wandb.init(project="doc2query", name=f"{args.lang}-{args.model_name}")
38
+
39
+
40
+
41
+
42
+
43
+ def main():
44
+ ############ Load dataset
45
+ queries = {}
46
+ for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{args.lang}')['train']):
47
+ queries[row['id']] = row['text']
48
+
49
+ """
50
+ collection = {}
51
+ for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']):
52
+ collection[row['id']] = row['text']
53
+ """
54
+ collection = load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']
55
+
56
+ train_pairs = []
57
+ eval_pairs = []
58
+
59
+
60
+ with open('qrels.train.tsv') as fIn:
61
+ for line in fIn:
62
+ qid, _, did, _ = line.strip().split("\t")
63
+
64
+ qid = int(qid)
65
+ did = int(did)
66
+
67
+ assert did == collection[did]['id']
68
+ text = collection[did]['text']
69
+
70
+ pair = (queries[qid], text)
71
+ if len(eval_pairs) < args.eval_size:
72
+ eval_pairs.append(pair)
73
+ else:
74
+ train_pairs.append(pair)
75
+
76
+
77
+ print(f"Train pairs: {len(train_pairs)}")
78
+
79
+
80
+ ############ Model
81
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
82
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
83
+
84
+ save_steps = 1000
85
+
86
+ output_dir = 'output/'+args.lang+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
87
+ print("Output dir:", output_dir)
88
+
89
+ # Write self to path
90
+ os.makedirs(output_dir, exist_ok=True)
91
+
92
+ train_script_path = os.path.join(output_dir, 'train_script.py')
93
+ copyfile(__file__, train_script_path)
94
+ with open(train_script_path, 'a') as fOut:
95
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
96
+
97
+ ####
98
+
99
+ training_args = Seq2SeqTrainingArguments(
100
+ output_dir=output_dir,
101
+ bf16=True,
102
+ per_device_train_batch_size=args.batch_size,
103
+ evaluation_strategy="steps",
104
+ save_steps=save_steps,
105
+ logging_steps=100,
106
+ eval_steps=save_steps, #logging_steps,
107
+ warmup_steps=1000,
108
+ save_total_limit=1,
109
+ num_train_epochs=args.epochs,
110
+ report_to="wandb",
111
+ )
112
+
113
+ ############ Arguments
114
+
115
+ ############ Load datasets
116
+
117
+
118
+ print("Input:", train_pairs[0][1])
119
+ print("Target:", train_pairs[0][0])
120
+
121
+ print("Input:", eval_pairs[0][1])
122
+ print("Target:", eval_pairs[0][0])
123
+
124
+
125
+ def data_collator(examples):
126
+ targets = [row[0] for row in examples]
127
+ inputs = [row[1] for row in examples]
128
+ label_pad_token_id = -100
129
+
130
+ model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)
131
+
132
+ # Setup the tokenizer for targets
133
+ with tokenizer.as_target_tokenizer():
134
+ labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)
135
+
136
+ # replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss.
137
+ labels["input_ids"] = [
138
+ [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
139
+ ]
140
+
141
+
142
+ model_inputs["labels"] = torch.tensor(labels["input_ids"])
143
+ return model_inputs
144
+
145
+ ## Define the trainer
146
+ trainer = Seq2SeqTrainer(
147
+ model=model,
148
+ args=training_args,
149
+ train_dataset=train_pairs,
150
+ eval_dataset=eval_pairs,
151
+ tokenizer=tokenizer,
152
+ data_collator=data_collator
153
+ )
154
+
155
+ ### Save the model
156
+ train_result = trainer.train()
157
+ trainer.save_model()
158
+
159
+
160
+ if __name__ == "__main__":
161
+ main()
162
+
163
+ # Script was called via:
164
+ #python train_hf_trainer_multilingual.py --lang french
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73e8589ec7a543dadfa39d29109503ff65ec3e64a6d92cd4a47a16f35b8e6c81
3
+ size 3247