krotima1 commited on
Commit
2b70e5b
1 Parent(s): 7005a40

feat: add summarizer

Browse files
Files changed (1) hide show
  1. MultilingualSummarizer.ipynb +42 -19
MultilingualSummarizer.ipynb CHANGED
@@ -13,11 +13,7 @@
13
  {
14
  "cell_type": "code",
15
  "execution_count": null,
16
- "metadata": {
17
- "vscode": {
18
- "languageId": "python"
19
- }
20
- },
21
  "outputs": [],
22
  "source": [
23
  "import torch as pt\n",
@@ -30,6 +26,7 @@
30
  "from transformers import AutoTokenizer\n",
31
  "import datasets\n",
32
  "\n",
 
33
  "import logging\n",
34
  "logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(name)s | %(levelname)s | %(message)s')\n",
35
  "\n",
@@ -56,10 +53,12 @@
56
  " #\n",
57
  " def __init__(self, model_name, language, inference_cfg=None, **kwargs):\n",
58
  " logging.info(f\"Initializing multilingual summarizer {model_name}\")\n",
 
59
  " self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
60
- " self.dstTokenizer = DatasetTokenizer(model_name.split('/')[-1], model_name, language)\n",
61
  " self.tokenizer = self.dstTokenizer.get_tokenizer()\n",
62
  " self.langid = self.dstTokenizer.get_langid()\n",
 
63
  " self.inference_cfg = inference_cfg\n",
64
  " self.enc_max_len = 512\n",
65
  " self.language = language\n",
@@ -114,7 +113,8 @@
114
  " summarizer = Summarizer(model = self.model, tokenizer = self.tokenizer,lcode=self.langid, batch_size = 8)\n",
115
  " \n",
116
  " #Summarize texts\n",
117
- " summarizer.summarize_dst(tok_dst,**self.inference_cfg)\n",
 
118
  " \n",
119
  " \n",
120
  " scores = {}\n",
@@ -125,17 +125,16 @@
125
  " \n",
126
  " \n",
127
  " return (summarizer.summarized_dst['summary'], scores)\n",
128
- " \n"
 
 
 
129
  ]
130
  },
131
  {
132
  "cell_type": "code",
133
  "execution_count": null,
134
- "metadata": {
135
- "vscode": {
136
- "languageId": "python"
137
- }
138
- },
139
  "outputs": [],
140
  "source": [
141
  "## Configuration of summarization pipeline\n",
@@ -185,24 +184,36 @@
185
  " ])\n",
186
  " return cfg\n",
187
  "\n",
 
 
 
 
 
 
 
 
 
188
  "cfg = summ_config()\n",
189
  "msummarizer = MultiSummarizer(**cfg)\n",
190
- "ret = msummarizer(**cfg)\n"
191
  ]
192
  },
193
  {
194
  "cell_type": "code",
195
  "execution_count": null,
196
- "metadata": {
197
- "vscode": {
198
- "languageId": "python"
199
- }
200
- },
201
  "outputs": [],
202
  "source": [
203
  "ret = msummarizer(**cfg)\n",
204
  "print(ret)"
205
  ]
 
 
 
 
 
 
 
206
  }
207
  ],
208
  "metadata": {
@@ -211,6 +222,18 @@
211
  "language": "python",
212
  "name": "python3"
213
  },
 
 
 
 
 
 
 
 
 
 
 
 
214
  "orig_nbformat": 4
215
  },
216
  "nbformat": 4,
 
13
  {
14
  "cell_type": "code",
15
  "execution_count": null,
16
+ "metadata": {},
 
 
 
 
17
  "outputs": [],
18
  "source": [
19
  "import torch as pt\n",
 
26
  "from transformers import AutoTokenizer\n",
27
  "import datasets\n",
28
  "\n",
29
+ "import re\n",
30
  "import logging\n",
31
  "logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(name)s | %(levelname)s | %(message)s')\n",
32
  "\n",
 
53
  " #\n",
54
  " def __init__(self, model_name, language, inference_cfg=None, **kwargs):\n",
55
  " logging.info(f\"Initializing multilingual summarizer {model_name}\")\n",
56
+ " self.name = model_name.split('/')[-1]\n",
57
  " self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
58
+ " self.dstTokenizer = DatasetTokenizer(self.name, model_name, language)\n",
59
  " self.tokenizer = self.dstTokenizer.get_tokenizer()\n",
60
  " self.langid = self.dstTokenizer.get_langid()\n",
61
+ " self.lang_token = self.dstTokenizer.get_lang_token()\n",
62
  " self.inference_cfg = inference_cfg\n",
63
  " self.enc_max_len = 512\n",
64
  " self.language = language\n",
 
113
  " summarizer = Summarizer(model = self.model, tokenizer = self.tokenizer,lcode=self.langid, batch_size = 8)\n",
114
  " \n",
115
  " #Summarize texts\n",
116
+ " filter_fc = self._filter_final_summaries if self.name.startswith('mt5') else None\n",
117
+ " summarizer.summarize_dst(tok_dst, filter_fc_batch = filter_fc,**self.inference_cfg)\n",
118
  " \n",
119
  " \n",
120
  " scores = {}\n",
 
125
  " \n",
126
  " \n",
127
  " return (summarizer.summarized_dst['summary'], scores)\n",
128
+ " \n",
129
+ " def _filter_final_summaries(self, batch, **kwargs):\n",
130
+ " batch[\"summary\"] = [ re.sub(self.lang_token, '', tmp) for tmp in batch[\"summary\"]]\n",
131
+ " return batch"
132
  ]
133
  },
134
  {
135
  "cell_type": "code",
136
  "execution_count": null,
137
+ "metadata": {},
 
 
 
 
138
  "outputs": [],
139
  "source": [
140
  "## Configuration of summarization pipeline\n",
 
184
  " ])\n",
185
  " return cfg\n",
186
  "\n",
187
+ "\n"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": null,
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": [
196
  "cfg = summ_config()\n",
197
  "msummarizer = MultiSummarizer(**cfg)\n",
198
+ "ret = msummarizer(**cfg)"
199
  ]
200
  },
201
  {
202
  "cell_type": "code",
203
  "execution_count": null,
204
+ "metadata": {},
 
 
 
 
205
  "outputs": [],
206
  "source": [
207
  "ret = msummarizer(**cfg)\n",
208
  "print(ret)"
209
  ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": []
217
  }
218
  ],
219
  "metadata": {
 
222
  "language": "python",
223
  "name": "python3"
224
  },
225
+ "language_info": {
226
+ "codemirror_mode": {
227
+ "name": "ipython",
228
+ "version": 3
229
+ },
230
+ "file_extension": ".py",
231
+ "mimetype": "text/x-python",
232
+ "name": "python",
233
+ "nbconvert_exporter": "python",
234
+ "pygments_lexer": "ipython3",
235
+ "version": "3.6.8"
236
+ },
237
  "orig_nbformat": 4
238
  },
239
  "nbformat": 4,