# Provenance (recovered from the file-viewer page this source was copied from):
# duplicated from arbml/Ashaar at revision 6faf7e7, uploaded by
# Ababababababbababa; original file size 7.68 kB.
"""
Loading the diacritization dataset
"""
import os
from diacritization_evaluation import util
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from .config_manager import ConfigManager
# Arabic diacritic marks (harakat), each keyed by its single Unicode code point.
# The keys are written as escapes so they survive any re-encoding of this file:
# the literal characters had been garbled into multi-byte mojibake, which also
# collapsed Damma, Kasra and Kasratan onto the same key.
BASIC_HARAQAT = {
    "\u064e": "Fatha ",
    "\u064b": "Fathatah ",
    "\u064f": "Damma ",
    "\u064c": "Dammatan ",
    "\u0650": "Kasra ",
    "\u064d": "Kasratan ",
    "\u0652": "Sukun ",
    "\u0651": "Shaddah ",
}
class DiacritizationDataset(Dataset):
    """
    The diacritization dataset.

    Works in one of two modes, selected by ``config["is_data_preprocessed"]``:
    rows of a preprocessed pandas DataFrame, or raw diacritized text lines that
    are cleaned and split into inputs/diacritics on the fly.
    """

    def __init__(self, config_manager: ConfigManager, list_ids, data):
        """
        Store the sample ids, the backing data and the encoder/config.

        Args:
            config_manager: supplies ``text_encoder`` and the ``config`` mapping.
            list_ids: sequence whose ``i``-th entry is the id used for
                ``dataset[i]``.
            data: a DataFrame (preprocessed mode) or an indexable sequence of
                raw lines (raw mode).
        """
        self.list_ids = list_ids
        self.data = data
        self.text_encoder = config_manager.text_encoder
        self.config = config_manager.config

    def __len__(self):
        """Return the number of samples, i.e. the length of ``list_ids``."""
        return len(self.list_ids)

    def preprocess(self, book):
        """
        Collapse runs of consecutive diacritics in ``book``.

        A mark from ``BASIC_HARAQAT`` that is immediately followed by another
        such mark is dropped, so only the last mark of each run is kept.
        Every other character is kept unchanged.
        """
        last = len(book) - 1
        # Collect the kept characters and join once, instead of the quadratic
        # repeated ``str +=``.
        kept = [
            char
            for pos, char in enumerate(book)
            if not (
                pos < last
                and char in BASIC_HARAQAT
                and book[pos + 1] in BASIC_HARAQAT
            )
        ]
        return "".join(kept)

    def __getitem__(self, index):
        """
        Return the ``index``-th sample as ``(inputs, targets, third)``.

        Preprocessed mode: the row is fetched with positional ``iloc``, so the
        ids in ``list_ids`` must be row positions. Column 1 is encoded as the
        input, column 2 is split on ``config["diacritics_separator"]`` and
        encoded as the target, and column 0 is returned unchanged.

        Raw mode: the line is cleaned with ``text_encoder.clean``, truncated to
        ``config["max_sen_len"]``, and split by ``util.extract_haraqat``. The
        text it returns is the third element.
        """
        sample_id = self.list_ids[index]

        if self.config["is_data_preprocessed"]:
            row = self.data.iloc[sample_id]
            inputs = torch.Tensor(self.text_encoder.input_to_sequence(row[1]))
            targets = torch.Tensor(
                self.text_encoder.target_to_sequence(
                    row[2].split(self.config["diacritics_separator"])
                )
            )
            return inputs, targets, row[0]

        line = self.text_encoder.clean(self.data[sample_id])
        line = line[: self.config["max_sen_len"]]
        text, inputs, diacritics = util.extract_haraqat(line)
        inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
        diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
        return inputs, diacritics, text
def collate_fn(data):
    """
    Pad a batch of ``(src, target, original)`` samples into 2-D tensors.

    Samples are ordered by decreasing source length. The sort is stable, so
    ties keep their incoming order. The caller's batch is not modified: a
    sorted copy is used, which also lets tuples or other iterables be passed.

    Args:
        data: iterable of triples whose first two items are 1-D sequences
            (typically tensors) and whose third item is passed through.

    Returns:
        dict with ``"original"`` (tuple of third items), ``"src"`` and
        ``"target"`` (zero-padded ``long`` tensors, one row per sample) and
        ``"lengths"`` (``LongTensor`` of the source lengths).
    """

    def merge(sequences):
        # Right-pad every sequence with zeros up to the longest in the batch.
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            padded_seqs[i, : lengths[i]] = seq[: lengths[i]]
        return padded_seqs, lengths

    # sorted() rather than data.sort(): do not mutate the caller's batch.
    ordered = sorted(data, key=lambda sample: len(sample[0]), reverse=True)

    # Separate source, target and pass-through items.
    src_seqs, trg_seqs, original = zip(*ordered)

    # Merge each tuple of 1-D sequences into a padded 2-D tensor.
    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, _ = merge(trg_seqs)

    return {
        "original": original,
        "src": src_seqs,
        "target": trg_seqs,
        # Only source lengths are reported; targets are assumed to have the
        # same lengths as their sources.
        "lengths": torch.LongTensor(src_lengths),
    }
def load_training_data(config_manager: ConfigManager, loader_parameters):
    """
    Build the training ``DataLoader`` from ``<data_dir>/train.csv``.

    Preprocessed mode reads the file with pandas: no header,
    ``config["data_separator"]`` as separator, and at most
    ``config["n_training_examples"]`` rows. Raw mode reads it line by line
    and keeps only lines of length 1..``config["max_len"]``.

    Args:
        config_manager: supplies ``config`` and ``data_dir``.
        loader_parameters: extra keyword arguments for ``DataLoader``.

    Returns:
        The ``DataLoader`` (batched with ``collate_fn``), or ``[]`` when
        ``config["load_training_data"]`` is false.
    """
    cfg = config_manager.config
    if not cfg["load_training_data"]:
        return []

    path = os.path.join(config_manager.data_dir, "train.csv")

    if cfg["is_data_preprocessed"]:
        frame = pd.read_csv(
            path,
            encoding="utf-8",
            sep=cfg["data_separator"],
            nrows=cfg["n_training_examples"],
            header=None,
        )
        training_set = DiacritizationDataset(config_manager, frame.index, frame)
    else:
        with open(path, encoding="utf8") as handle:
            lines = handle.readlines()
        limit = cfg["max_len"]
        # readlines() keeps each line's trailing newline, so it counts toward
        # the length, and a blank line ("\n") still passes the non-empty check.
        lines = [line for line in lines if 0 < len(line) <= limit]
        training_set = DiacritizationDataset(
            config_manager, list(range(len(lines))), lines
        )

    train_iterator = DataLoader(
        training_set, collate_fn=collate_fn, **loader_parameters
    )
    print(f"Length of training iterator = {len(train_iterator)}")
    return train_iterator
def load_test_data(config_manager: ConfigManager, loader_parameters):
    """
    Build the test ``DataLoader``.

    The file is ``<data_dir>/<config["test_file_name"]>``, defaulting to
    ``test.csv``. Preprocessed mode reads it with pandas: no header,
    ``config["data_separator"]`` as separator, and at most
    ``config["n_test_examples"]`` rows. Raw mode reads it line by line and
    truncates (rather than drops) lines longer than ``config["max_len"]``.

    Args:
        config_manager: supplies ``config`` and ``data_dir``.
        loader_parameters: extra keyword arguments for ``DataLoader``.

    Returns:
        The ``DataLoader`` (batched with ``collate_fn``), or ``[]`` when
        ``config["load_test_data"]`` is false.
    """
    cfg = config_manager.config
    if not cfg["load_test_data"]:
        return []

    file_name = cfg.get("test_file_name", "test.csv")
    path = os.path.join(config_manager.data_dir, file_name)

    if cfg["is_data_preprocessed"]:
        frame = pd.read_csv(
            path,
            encoding="utf-8",
            sep=cfg["data_separator"],
            nrows=cfg["n_test_examples"],
            header=None,
        )
        test_dataset = DiacritizationDataset(config_manager, frame.index, frame)
    else:
        with open(path, encoding="utf8") as handle:
            lines = handle.readlines()
        limit = cfg["max_len"]
        lines = [line[:limit] for line in lines]
        test_dataset = DiacritizationDataset(
            config_manager, list(range(len(lines))), lines
        )

    test_iterator = DataLoader(test_dataset, collate_fn=collate_fn, **loader_parameters)
    print(f"Length of test iterator = {len(test_iterator)}")
    return test_iterator
def load_validation_data(config_manager: ConfigManager, loader_parameters):
    """
    Build the validation ``DataLoader`` from ``<data_dir>/eval.csv``.

    Preprocessed mode reads the file with pandas: no header,
    ``config["data_separator"]`` as separator, and at most
    ``config["n_validation_examples"]`` rows. It then keeps only rows whose
    column 0 is ``<= config["max_len"]``. Raw mode reads the file line by line
    and truncates each line to ``config["max_len"]``.

    Args:
        config_manager: supplies ``config`` and ``data_dir``.
        loader_parameters: extra keyword arguments for ``DataLoader``.

    Returns:
        The ``DataLoader`` (batched with ``collate_fn``), or ``[]`` when
        ``config["load_validation_data"]`` is false.
    """
    config = config_manager.config
    if not config["load_validation_data"]:
        return []
    path = os.path.join(config_manager.data_dir, "eval.csv")
    if config["is_data_preprocessed"]:
        valid_data = pd.read_csv(
            path,
            encoding="utf-8",
            sep=config["data_separator"],
            nrows=config["n_validation_examples"],
            header=None,
        )
        # NOTE(review): column 0 is compared to an int here, while the
        # train/test loaders keep this filter commented out. Confirm that
        # column 0 of eval.csv is numeric.
        valid_data = valid_data[valid_data[0] <= config["max_len"]]
        # Filtering leaves gaps in the index, but the index is handed to
        # DiacritizationDataset as ``list_ids`` and rows are fetched with
        # positional ``iloc``. Renumber so that labels equal positions;
        # otherwise wrong rows (or IndexError) would be returned.
        valid_data = valid_data.reset_index(drop=True)
        valid_dataset = DiacritizationDataset(
            config_manager, valid_data.index, valid_data
        )
    else:
        with open(path, encoding="utf8") as file:
            valid_data = file.readlines()
        max_len = config["max_len"]
        valid_data = [text[:max_len] for text in valid_data]
        valid_dataset = DiacritizationDataset(
            config_manager, list(range(len(valid_data))), valid_data
        )
    valid_iterator = DataLoader(
        valid_dataset, collate_fn=collate_fn, **loader_parameters
    )
    print(f"Length of valid iterator = {len(valid_iterator)}")
    return valid_iterator
def load_iterators(config_manager: ConfigManager):
    """
    Load the training, validation and test data iterators.

    All three loaders share the same ``DataLoader`` settings:
    ``config["batch_size"]`` as batch size, shuffling enabled, and
    ``config["num_workers"]`` worker processes (default 2 when the key is
    absent).

    Args:
        config_manager: supplies ``config`` and ``data_dir`` to the loaders.

    Returns:
        ``(train_iterator, test_iterator, valid_iterator)``. Note that this
        order differs from the order in which they are loaded. A loader
        disabled in the config is returned as ``[]``.
    """
    config = config_manager.config
    params = {
        "batch_size": config["batch_size"],
        "shuffle": True,
        # Previously hard-coded to 2; now overridable from the config.
        "num_workers": config.get("num_workers", 2),
    }
    train_iterator = load_training_data(config_manager, loader_parameters=params)
    valid_iterator = load_validation_data(config_manager, loader_parameters=params)
    test_iterator = load_test_data(config_manager, loader_parameters=params)
    return train_iterator, test_iterator, valid_iterator