# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for processing crawled content and generating tfrecords."""
import collections
import json
import multiprocessing.pool
import os
import urllib.parse

import tensorflow as tf

from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib


class RawDataProcessor(object):
  """Data converter for story examples."""

  def __init__(self,
               vocab: str,
               do_lower_case: bool,
               len_title: int = 15,
               len_passage: int = 200,
               max_num_articles: int = 5,
               include_article_title_in_passage: bool = False,
               include_text_snippet_in_example: bool = False):
    """Constructs a RawDataProcessor.

    Args:
      vocab: Filepath of the BERT vocabulary.
      do_lower_case: Whether the vocabulary is uncased or not.
      len_title: Maximum number of tokens in story headline.
      len_passage: Maximum number of tokens in article passage.
      max_num_articles: Maximum number of articles in a story.
      include_article_title_in_passage: Whether to include article title in
        article passage.
      include_text_snippet_in_example: Whether to include text snippet
        (headline and article content) in generated tensorflow Examples, for
        debug usage. If include_article_title_in_passage=True, title and body
        will be separated by [SEP].
    """
    self.articles = dict()
    self.tokenizer = tokenization.FullTokenizer(
        vocab, do_lower_case=do_lower_case, split_on_punc=False)
    self.len_title = len_title
    self.len_passage = len_passage
    self.max_num_articles = max_num_articles
    self.include_article_title_in_passage = include_article_title_in_passage
    self.include_text_snippet_in_example = include_text_snippet_in_example
    # ex_index=5 deactivates printing inside convert_single_example.
    self.ex_index = 5
    # Parameters used in InputExample, not used in NHNet.
    self.label = 0
    self.guid = 0
    self.num_generated_examples = 0

  def read_crawled_articles(self, folder_path):
    """Reads crawled articles under folder_path."""
    for path, _, files in os.walk(folder_path):
      for name in files:
        if not name.endswith(".json"):
          continue
        url, article = self._get_article_content_from_json(
            os.path.join(path, name))
        if not article.text_a:
          continue
        self.articles[RawDataProcessor.normalize_url(url)] = article
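        # Keyed by normalized URL so that story URLs that differ only in
        # scheme or query string still match the crawled article.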
        if len(self.articles) % 5000 == 0:
          print("Number of articles loaded: %d\r" % len(self.articles), end="")
    print()
    return len(self.articles)

  def generate_examples(self, input_file, output_files):
    """Loads stories from input_file and exports examples to output_files."""
    writers = []
    story_partition = []
    for output_file in output_files:
      writers.append(tf.io.TFRecordWriter(output_file))
      story_partition.append(list())
    with tf.io.gfile.GFile(input_file, "r") as story_json_file:
      stories = json.load(story_json_file)
      writer_index = 0
      for story in stories:
        articles = []
        for url in story["urls"]:
          normalized_url = RawDataProcessor.normalize_url(url)
          if normalized_url in self.articles:
            articles.append(self.articles[normalized_url])
        if not articles:
          continue
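        # Stories are distributed across the output files round-robin.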
        story_partition[writer_index].append((story["label"], articles))
        writer_index = (writer_index + 1) % len(writers)
    lock = multiprocessing.Lock()
    pool = multiprocessing.pool.ThreadPool(len(writers))
    data = [(story_partition[i], writers[i], lock)
            for i in range(len(writers))]
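    # Each (stories, writer, lock) tuple is consumed by _write_story_partition
    # on its own thread; the lock only guards the shared example counter.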
    pool.map(self._write_story_partition, data)
    return len(stories), self.num_generated_examples

  @classmethod
  def normalize_url(cls, url):
    """Normalizes a url for better matching."""
    url = urllib.parse.unquote(
        urllib.parse.urlsplit(url)._replace(query=None).geturl())
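    # The query string is dropped above and the scheme below, so, e.g., a
    # hypothetical "https://www.example.com/story?id=1" becomes
    # "www.example.com/story".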
    output, part = [], None
    for part in url.split("//"):
      if part == "http:" or part == "https:":
        continue
      else:
        output.append(part)
    return "//".join(output)

  def _get_article_content_from_json(self, file_path):
    """Returns (url, InputExample) with content extracted from file_path."""
    with tf.io.gfile.GFile(file_path, "r") as article_json_file:
      article = json.load(article_json_file)
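      # Each crawled article json is expected to provide "url" and "maintext",
      # plus "title" when include_article_title_in_passage is set; the title
      # goes to text_a and the body to text_b in that case.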
      if self.include_article_title_in_passage:
        return article["url"], classifier_data_lib.InputExample(
            guid=self.guid,
            text_a=article["title"],
            text_b=article["maintext"],
            label=self.label)
      else:
        return article["url"], classifier_data_lib.InputExample(
            guid=self.guid, text_a=article["maintext"], label=self.label)

  def _write_story_partition(self, data):
    """Writes stories in a partition into file."""
    for (story_headline, articles) in data[0]:
      story_example = tf.train.Example(
          features=tf.train.Features(
              feature=self._get_single_story_features(
                  story_headline, articles)))
      data[1].write(story_example.SerializeToString())
      data[2].acquire()
      try:
        self.num_generated_examples += 1
        if self.num_generated_examples % 1000 == 0:
          print(
              "Number of stories written: %d\r" % self.num_generated_examples,
              end="")
      finally:
        data[2].release()

  def _get_single_story_features(self, story_headline, articles):
    """Converts a story headline and its articles into Example features."""

    def get_text_snippet(article):
      if article.text_b:
        return " [SEP] ".join([article.text_a, article.text_b])
      else:
        return article.text_a

    story_features = collections.OrderedDict()
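    # Feature names carry a one-letter suffix: "a" for the headline and "b",
    # "c", ... for up to max_num_articles article passages (padded with empty
    # passages); see _add_feature_with_suffix below.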
    story_headline_feature = classifier_data_lib.convert_single_example(
        ex_index=self.ex_index,
        example=classifier_data_lib.InputExample(
            guid=self.guid, text_a=story_headline, label=self.label),
        label_list=[self.label],
        max_seq_length=self.len_title,
        tokenizer=self.tokenizer)
    if self.include_text_snippet_in_example:
      story_headline_feature.label_id = story_headline
    self._add_feature_with_suffix(
        feature=story_headline_feature,
        suffix="a",
        story_features=story_features)
    for (article_index, article) in enumerate(articles):
      if article_index == self.max_num_articles:
        break
      article_feature = classifier_data_lib.convert_single_example(
          ex_index=self.ex_index,
          example=article,
          label_list=[self.label],
          max_seq_length=self.len_passage,
          tokenizer=self.tokenizer)
      if self.include_text_snippet_in_example:
        article_feature.label_id = get_text_snippet(article)
      suffix = chr(ord("b") + article_index)
      self._add_feature_with_suffix(
          feature=article_feature, suffix=suffix, story_features=story_features)
    # Adds empty features as placeholder.
    for article_index in range(len(articles), self.max_num_articles):
      suffix = chr(ord("b") + article_index)
      empty_article = classifier_data_lib.InputExample(
          guid=self.guid, text_a="", label=self.label)
      empty_feature = classifier_data_lib.convert_single_example(
          ex_index=self.ex_index,
          example=empty_article,
          label_list=[self.label],
          max_seq_length=self.len_passage,
          tokenizer=self.tokenizer)
      if self.include_text_snippet_in_example:
        empty_feature.label_id = ""
      self._add_feature_with_suffix(
          feature=empty_feature, suffix=suffix, story_features=story_features)
    return story_features

  def _add_feature_with_suffix(self, feature, suffix, story_features):
    """Appends suffix to feature names and fills in the corresponding values."""

    def _create_int_feature(values):
      return tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(values)))

    def _create_string_feature(value):
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
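    # suffix is a single character ("a", "b", ...), hence the "%c" format.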
    story_features["input_ids_%c" % suffix] = _create_int_feature(
        feature.input_ids)
    story_features["input_mask_%c" % suffix] = _create_int_feature(
        feature.input_mask)
    story_features["segment_ids_%c" % suffix] = _create_int_feature(
        feature.segment_ids)
    if self.include_text_snippet_in_example:
      story_features["text_snippet_%c" % suffix] = _create_string_feature(
          bytes(feature.label_id.encode()))
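

# A minimal usage sketch; the vocabulary path, article folder, story json and
# output filenames below are hypothetical placeholders, not part of this
# library:
#
#   processor = RawDataProcessor(vocab="vocab.txt", do_lower_case=True)
#   num_articles = processor.read_crawled_articles("/tmp/crawled_articles")
#   num_stories, num_examples = processor.generate_examples(
#       "/tmp/stories.json",
#       ["/tmp/train-00000-of-00002.tfrecord",
#        "/tmp/train-00001-of-00002.tfrecord"])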