Spaces:
Running
Running
# Copyright 2017 Google Inc. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Input readers and document/token generators for datasets.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from collections import namedtuple | |
import csv | |
import os | |
import random | |
# Dependency imports | |
import tensorflow as tf | |
from data import data_utils | |
# Command-line flag registry. Every option defined below is read back through
# the FLAGS object by the generators in this module.
flags = tf.app.flags
FLAGS = flags.FLAGS

# Name of the dataset to read; `documents` dispatches on this value
# ('imdb', 'dbpedia', 'rcv1' or 'rt').
flags.DEFINE_string('dataset', '', 'Which dataset to generate data for')

# Preprocessing config: which kinds of tokens `tokens` emits and whether the
# text is lowercased first.
flags.DEFINE_boolean('output_unigrams', True, 'Whether to output unigrams.')
flags.DEFINE_boolean('output_bigrams', False, 'Whether to output bigrams.')
flags.DEFINE_boolean('output_char', False, 'Whether to output characters.')
flags.DEFINE_boolean(
    'lowercase', True, 'Whether to lowercase document terms.')

# IMDB: input location and the file ids at which the validation split starts
# for each sentiment class.
flags.DEFINE_string(
    'imdb_input_dir', '',
    'The input directory containing the IMDB sentiment dataset.')
flags.DEFINE_integer(
    'imdb_validation_pos_start_id', 10621,
    'File id of the first file in the pos sentiment validation set.')
flags.DEFINE_integer(
    'imdb_validation_neg_start_id', 10625,
    'File id of the first file in the neg sentiment validation set.')

# DBpedia
flags.DEFINE_string(
    'dbpedia_input_dir', '',
    'Path to DBpedia directory containing train.csv and test.csv.')

# Reuters Corpus (rcv1)
flags.DEFINE_string(
    'rcv1_input_dir', '',
    'Path to rcv1 directory containing train.csv, unlab.csv, and test.csv.')

# Rotten Tomatoes
flags.DEFINE_string(
    'rt_input_dir', '', 'The Rotten Tomatoes dataset input directory.')

# The amazon reviews input file to use in either the RT or IMDB datasets.
flags.DEFINE_string(
    'amazon_unlabeled_input_file', '',
    'The unlabeled Amazon Reviews dataset input file. If set, the input '
    'file is used to augment RT and IMDB vocab.')
# Record type yielded by every *_documents generator in this module.
#   content: the raw text of one example (a whole file, a CSV cell, or one
#     line of an input file, depending on the dataset).
#   is_validation: whether the example belongs to the held-out validation
#     split.
#   is_test: whether the example belongs to the held-out test split.
#   label: the class label, or None for unlabeled examples.
#   add_tokens: True for examples read from the dataset itself, False for the
#     Amazon unlabeled reviews. NOTE(review): presumably tells consumers
#     whether the example's tokens should be added to the vocabulary --
#     confirm against the callers.
Document = namedtuple(
    typename='Document',
    field_names='content is_validation is_test label add_tokens')
def documents(dataset='train',
              include_unlabeled=False,
              include_validation=False):
    """Yields Documents for the dataset selected by FLAGS.dataset.

    Picks the dataset-specific generator (imdb, dbpedia, rcv1 or rt) and
    forwards all three arguments to it unchanged.

    Args:
      dataset: str, which split to read, test or train.
      include_unlabeled: bool, whether to include unlabeled data. Only valid
        when dataset=train.
      include_validation: bool, whether to include validation data.

    Yields:
      Document

    Raises:
      ValueError: if include_unlabeled is true but dataset is not 'train', or
        if FLAGS.dataset is not one of the supported dataset names.
    """
    if include_unlabeled and dataset != 'train':
        raise ValueError('If include_unlabeled=True, must use train dataset')

    # Reseed on every call so that the randomly drawn validation set is the
    # same when running gen_data and gen_vocab.
    random.seed(302)

    generators_by_name = {
        'imdb': imdb_documents,
        'dbpedia': dbpedia_documents,
        'rcv1': rcv1_documents,
        'rt': rt_documents,
    }
    selected = FLAGS.dataset
    if selected not in generators_by_name:
        raise ValueError('Unrecognized dataset %s' % FLAGS.dataset)

    make_documents = generators_by_name[selected]
    for document in make_documents(dataset, include_unlabeled,
                                   include_validation):
        yield document
def tokens(doc):
    """Yields the character or word-level tokens of a Document.

    The document content is stripped of surrounding whitespace and, when
    FLAGS.lowercase is set, lowercased. If FLAGS.output_char is set, every
    character is yielded and the unigram/bigram flags are ignored. Otherwise
    the content is split with data_utils.split_by_punct and, per token:
      * the token itself is yielded when FLAGS.output_unigrams is set;
      * when FLAGS.output_bigrams is set, the bigram '<previous>_<token>' is
        yielded (data_utils.EOS_TOKEN stands in for the missing predecessor
        of the first token), and the last token is additionally followed by
        '<token>_<EOS_TOKEN>'.

    Args:
      doc: Document to produce tokens from.

    Yields:
      token

    Raises:
      ValueError: if all FLAGS.{output_unigrams, output_bigrams, output_char}
        are False.
    """
    if (not FLAGS.output_unigrams and not FLAGS.output_bigrams and
            not FLAGS.output_char):
        raise ValueError(
            'At least one of {FLAGS.output_unigrams, FLAGS.output_bigrams, '
            'FLAGS.output_char} must be true')

    text = doc.content.strip()
    if FLAGS.lowercase:
        text = text.lower()

    # Character mode takes precedence over the word-level modes.
    if FLAGS.output_char:
        for character in text:
            yield character
        return

    words = data_utils.split_by_punct(text)
    final_index = len(words) - 1
    for index, word in enumerate(words):
        if FLAGS.output_unigrams:
            yield word

        if not FLAGS.output_bigrams:
            continue

        # Pair the word with its left neighbour; the first word has none, so
        # the end-of-sequence marker is used as the boundary instead.
        if index > 0:
            left = words[index - 1]
        else:
            left = data_utils.EOS_TOKEN
        yield '_'.join([left, word])

        # Close the sequence with a bigram running into the boundary marker.
        if index == final_index:
            yield '_'.join([word, data_utils.EOS_TOKEN])
def imdb_documents(dataset='train',
                   include_unlabeled=False,
                   include_validation=False):
    """Generates Documents for IMDB dataset.

    Data from http://ai.stanford.edu/~amaas/data/sentiment/

    Each file under <dataset>/pos and <dataset>/neg (and train/unsup when
    include_unlabeled is set) becomes one Document. A labeled file belongs to
    the validation set when the integer prefix of its filename (the part
    before the first '_') is at least the per-class start id given by
    FLAGS.imdb_validation_{pos,neg}_start_id. If
    FLAGS.amazon_unlabeled_input_file is set and include_unlabeled is true,
    each line of that file is yielded afterwards as an unlabeled Document.

    Args:
      dataset: str, identifies folder within IMDB data directory, test or
        train.
      include_unlabeled: bool, whether to include the unsup directory. Only
        valid when dataset=train.
      include_validation: bool, whether to include validation data.

    Yields:
      Document

    Raises:
      ValueError: if FLAGS.imdb_input_dir is empty.
    """
    if not FLAGS.imdb_input_dir:
        raise ValueError('Must provide FLAGS.imdb_input_dir')

    tf.logging.info('Generating IMDB documents...')

    def in_validation_split(name, sentiment):
        """Returns whether a file falls in the validation id range."""
        # Unlabeled files are never part of the validation set.
        if sentiment is None:
            return False
        file_id = int(name.split('_')[0])
        if sentiment:
            first_validation_id = FLAGS.imdb_validation_pos_start_id
        else:
            first_validation_id = FLAGS.imdb_validation_neg_start_id
        return file_id >= first_validation_id

    labeled_dirs = [(dataset + '/pos', True), (dataset + '/neg', False)]
    unlabeled_dirs = [('train/unsup', None)] if include_unlabeled else []

    for subdir, class_label in labeled_dirs + unlabeled_dirs:
        subdir_path = os.path.join(FLAGS.imdb_input_dir, subdir)
        for filename in os.listdir(subdir_path):
            is_validation = in_validation_split(filename, class_label)
            if is_validation and not include_validation:
                continue

            review_path = os.path.join(FLAGS.imdb_input_dir, subdir, filename)
            with open(review_path, encoding='utf-8') as review_file:
                content = review_file.read()
            yield Document(
                content=content,
                is_validation=is_validation,
                is_test=False,
                label=class_label,
                add_tokens=True)

    if FLAGS.amazon_unlabeled_input_file and include_unlabeled:
        # NOTE(review): every line of the Amazon file is yielded here, whereas
        # rt_documents keeps only lines starting with 'review/text' -- confirm
        # the difference is intended.
        amazon_path = FLAGS.amazon_unlabeled_input_file
        with open(amazon_path, encoding='utf-8') as amazon_file:
            for line in amazon_file:
                yield Document(
                    content=line,
                    is_validation=False,
                    is_test=False,
                    label=None,
                    add_tokens=False)
def dbpedia_documents(dataset='train',
                      include_unlabeled=False,
                      include_validation=False):
    """Generates Documents for DBpedia dataset.

    Dataset linked to at https://github.com/zhangxiangxiao/Crepe.

    Reads <FLAGS.dbpedia_input_dir>/<dataset>.csv and yields one Document per
    row. The content is columns 1 and 2 joined by a single space and the label
    is column 0 shifted to start from 0.

    Args:
      dataset: str, identifies the csv file within the DBpedia data directory,
        test or train.
      include_unlabeled: bool, unused.
      include_validation: bool, whether to include validation data, which is a
        randomly selected 10% of the data.

    Yields:
      Document

    Raises:
      ValueError: if FLAGS.dbpedia_input_dir is empty.
    """
    del include_unlabeled  # Accepted only for signature parity; unused.

    if not FLAGS.dbpedia_input_dir:
        raise ValueError('Must provide FLAGS.dbpedia_input_dir')

    tf.logging.info('Generating DBpedia documents...')

    csv_path = os.path.join(FLAGS.dbpedia_input_dir, dataset + '.csv')
    # newline='' lets the csv module do its own newline handling, which is
    # required for quoted fields that contain line breaks. The explicit
    # encoding keeps decoding independent of the platform default, matching
    # imdb_documents.
    with open(csv_path, encoding='utf-8', newline='') as db_f:
        for row in csv.reader(db_f):
            # 10% of the data is randomly held out. The draw happens for every
            # row, even skipped ones, so the split only depends on the seed
            # and the row order.
            is_validation = random.randint(1, 10) == 1
            if is_validation and not include_validation:
                continue

            yield Document(
                content=row[1] + ' ' + row[2],
                is_validation=is_validation,
                is_test=False,
                label=int(row[0]) - 1,  # Labels should start from 0
                add_tokens=True)
def rcv1_documents(dataset='train',
                   include_unlabeled=True,
                   include_validation=False):
    # pylint:disable=line-too-long
    """Generates Documents for Reuters Corpus (rcv1) dataset.

    Dataset described at
    http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm

    Reads <FLAGS.rcv1_input_dir>/<dataset>.csv, followed by unlab.csv when
    include_unlabeled is true and dataset is 'train'. Each row yields one
    Document whose label is int(column 0) and whose content is column 1.
    Rows of unlab.csv go through the same label parsing and random
    validation hold-out as the labeled rows.

    Args:
      dataset: str, identifies the csv file within the rcv1 data directory.
      include_unlabeled: bool, whether to include the unlab file. Only valid
        when dataset=train.
      include_validation: bool, whether to include validation data, which is a
        randomly selected 10% of the data.

    Yields:
      Document

    Raises:
      ValueError: if FLAGS.rcv1_input_dir is empty.
    """
    # pylint:enable=line-too-long
    if not FLAGS.rcv1_input_dir:
        raise ValueError('Must provide FLAGS.rcv1_input_dir')

    tf.logging.info('Generating rcv1 documents...')

    datasets = [dataset]
    if include_unlabeled and dataset == 'train':
        datasets.append('unlab')

    for dset in datasets:
        csv_path = os.path.join(FLAGS.rcv1_input_dir, dset + '.csv')
        # newline='' lets the csv module do its own newline handling, which is
        # required for quoted fields that contain line breaks. The explicit
        # encoding keeps decoding independent of the platform default,
        # matching imdb_documents.
        with open(csv_path, encoding='utf-8', newline='') as csv_f:
            for row in csv.reader(csv_f):
                # 10% of the data is randomly held out. The draw happens for
                # every row, even skipped ones, so the split only depends on
                # the seed and the row order.
                is_validation = random.randint(1, 10) == 1
                if is_validation and not include_validation:
                    continue

                yield Document(
                    content=row[1],
                    is_validation=is_validation,
                    is_test=False,
                    label=int(row[0]),
                    add_tokens=True)
def rt_documents(dataset='train',
                 include_unlabeled=True,
                 include_validation=False):
    # pylint:disable=line-too-long
    """Generates Documents for the Rotten Tomatoes dataset.

    Dataset available at http://www.cs.cornell.edu/people/pabo/movie-review-data/
    In this dataset, amazon reviews are used for the unlabeled data.

    Every file in FLAGS.rt_input_dir whose name ends in '.pos' (label True) or
    '.neg' (label False) is read line by line, one Document per line. Each
    labeled line is randomly assigned to validation (10%), test (10%) or
    neither. Lines assigned to the test split are yielded only when
    dataset='test'; all other lines are yielded for any value of dataset,
    subject to include_validation. When include_unlabeled is true and
    FLAGS.amazon_unlabeled_input_file is set, the lines of that file starting
    with 'review/text' are yielded as unlabeled Documents.

    The input files are visited in sorted filename order so that, given the
    same random seed, the validation/test assignment does not depend on the
    directory listing order of the filesystem.

    Args:
      dataset: str, identifies the data subdirectory.
      include_unlabeled: bool, whether to include the unlabeled data. Only
        valid when dataset=train.
      include_validation: bool, whether to include validation data, which is a
        randomly selected 10% of the data.

    Yields:
      Document

    Raises:
      ValueError: if FLAGS.rt_input_dir is empty.
    """
    # pylint:enable=line-too-long
    if not FLAGS.rt_input_dir:
        raise ValueError('Must provide FLAGS.rt_input_dir')

    tf.logging.info('Generating rt documents...')

    data_files = []
    # os.listdir returns entries in arbitrary order. The split below consumes
    # one random draw per labeled line, so the files must be visited in a
    # fixed order for the split to be reproducible.
    for inp_fname in sorted(os.listdir(FLAGS.rt_input_dir)):
        if inp_fname.endswith('.pos'):
            data_files.append(
                (os.path.join(FLAGS.rt_input_dir, inp_fname), True))
        elif inp_fname.endswith('.neg'):
            data_files.append(
                (os.path.join(FLAGS.rt_input_dir, inp_fname), False))

    # The unlabeled file goes last; its lines consume no random draws.
    if include_unlabeled and FLAGS.amazon_unlabeled_input_file:
        data_files.append((FLAGS.amazon_unlabeled_input_file, None))

    for filename, class_label in data_files:
        # NOTE(review): no explicit encoding is given here, unlike
        # imdb_documents; the right codec for the .pos/.neg files is not
        # evident from this module -- confirm before pinning one.
        with open(filename) as rt_f:
            for content in rt_f:
                if class_label is None:
                    # Process Amazon Review data for unlabeled dataset
                    if content.startswith('review/text'):
                        yield Document(
                            content=content,
                            is_validation=False,
                            is_test=False,
                            label=None,
                            add_tokens=False)
                    continue

                # 10% of the data is randomly held out for the validation set
                # and another 10% of it is randomly held out for the test set
                random_int = random.randint(1, 10)
                is_validation = random_int == 1
                is_test = random_int == 2
                if is_test and dataset != 'test':
                    continue
                if is_validation and not include_validation:
                    continue

                yield Document(
                    content=content,
                    is_validation=is_validation,
                    is_test=is_test,
                    label=class_label,
                    add_tokens=True)