import os

from .collators import DataCollatorForMultitaskCellClassification
from .imports import *  # expected to provide pickle, torch, load_from_disk, and DataLoader


def validate_columns(dataset, required_columns, dataset_type):
"""Ensures required columns are present in the dataset."""
missing_columns = [col for col in required_columns if col not in dataset.column_names]
if missing_columns:
raise KeyError(
f"Missing columns in {dataset_type} dataset: {missing_columns}. "
f"Available columns: {dataset.column_names}"
        )


def create_label_mappings(dataset, task_to_column):
"""Creates label mappings for the dataset."""
task_label_mappings = {}
num_labels_list = []
for task, column in task_to_column.items():
unique_values = sorted(set(dataset[column]))
mapping = {label: idx for idx, label in enumerate(unique_values)}
task_label_mappings[task] = mapping
num_labels_list.append(len(unique_values))
return task_label_mappings, num_labels_list
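

# Example (hypothetical data): if dataset["cell_type"] yields
# ["B cell", "T cell", "B cell"], then
# create_label_mappings(dataset, {"task1": "cell_type"}) returns
# ({"task1": {"B cell": 0, "T cell": 1}}, [2]).
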
def save_label_mappings(mappings, path):
    """Saves label mappings to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(mappings, f)


def load_label_mappings(path):
    """Loads label mappings from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)


def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
"""Transforms the dataset to the required format."""
transformed_dataset = []
cell_id_mapping = {}
for idx, record in enumerate(dataset):
transformed_record = {
"input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
"cell_id": idx, # Index-based cell ID
}
if not is_test:
label_dict = {
task: task_label_mappings[task][record[column]]
for task, column in task_to_column.items()
}
        else:
            # Test data has no ground-truth labels; use -1 as a placeholder.
            label_dict = {task: -1 for task in config["task_names"]}
transformed_record["label"] = label_dict
transformed_dataset.append(transformed_record)
cell_id_mapping[idx] = record.get("unique_cell_id", idx)
return transformed_dataset, cell_id_mapping
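

# A transformed record has the shape (hypothetical two-task example):
#   {"input_ids": tensor([...]), "cell_id": 0, "label": {"task1": 2, "task2": 0}}
# and cell_id_mapping maps each index to the record's "unique_cell_id" when
# present, falling back to the index itself.
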
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    """Loads a dataset from disk and returns (transformed_dataset, cell_id_mapping, num_labels_list)."""
try:
dataset = load_from_disk(dataset_path)
# Setup task and column mappings
task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
task_to_column = dict(zip(task_names, config["task_columns"]))
config["task_names"] = task_names
label_mappings_path = os.path.join(
config["results_dir"],
f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
)
if not is_test:
validate_columns(dataset, task_to_column.values(), dataset_type)
# Create and save label mappings
task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
save_label_mappings(task_label_mappings, label_mappings_path)
else:
# Load existing mappings for test data
task_label_mappings = load_label_mappings(label_mappings_path)
num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
# Transform dataset
transformed_dataset, cell_id_mapping = transform_dataset(
dataset, task_to_column, task_label_mappings, config, is_test
)
return transformed_dataset, cell_id_mapping, num_labels_list
    except KeyError as e:
        raise ValueError(f"Configuration error or dataset key missing: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Error during data loading or preprocessing: {e}") from e


def preload_and_process_data(config):
"""Preloads and preprocesses train and validation datasets."""
# Process train data and save mappings
train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
# Process validation data and save mappings
val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
# Validate that the mappings match
validate_label_mappings(config)
    # Five-tuple: (train_dataset, train_cell_id_mapping, num_labels_list,
    # val_dataset, val_cell_id_mapping).
    return (*train_data, *val_data[:2])


def validate_label_mappings(config):
"""Ensures train and validation label mappings are consistent."""
train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
train_mappings = load_label_mappings(train_mappings_path)
val_mappings = load_label_mappings(val_mappings_path)
for task_name in config["task_names"]:
if train_mappings[task_name] != val_mappings[task_name]:
raise ValueError(
f"Mismatch in label mappings for task '{task_name}'.\n"
f"Train Mapping: {train_mappings[task_name]}\n"
f"Validation Mapping: {val_mappings[task_name]}"
            )


def get_data_loader(preprocessed_dataset, batch_size):
    """Creates a shuffled DataLoader with the multitask collator, pinned memory, and one worker per CPU core."""
return DataLoader(
preprocessed_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=DataCollatorForMultitaskCellClassification(),
num_workers=os.cpu_count(),
pin_memory=True,
    )


def preload_data(config):
    """Preprocesses train and validation data for trials and wraps each split in a DataLoader."""
    # preload_and_process_data returns (train_dataset, train_cell_id_mapping,
    # num_labels_list, val_dataset, val_cell_id_mapping); unpack it once rather
    # than preprocessing both splits twice.
    train_dataset, _, _, val_dataset, _ = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])
    return train_loader, val_loader


def load_and_preprocess_test_data(config):
    """Loads and preprocesses test data."""
    return load_and_preprocess_data(config["test_path"], config, is_test=True)


def prepare_test_loader(config):
"""Prepares DataLoader for test data."""
test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
test_loader = get_data_loader(test_dataset, config["batch_size"])
return test_loader, cell_id_mapping, num_labels_list
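

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module API). All paths,
# column names, and the batch size below are hypothetical. Each dataset is
# assumed to have been saved with datasets.Dataset.save_to_disk and to contain
# an "input_ids" column plus the label columns listed in "task_columns".
# Because this module uses relative imports, run it as part of its package,
# e.g. `python -m <package>.<module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "train_path": "data/train_dataset",  # hypothetical path
        "val_path": "data/val_dataset",  # hypothetical path
        "test_path": "data/test_dataset",  # hypothetical path
        "results_dir": "results",  # label mappings are pickled here
        "task_columns": ["cell_type", "disease_state"],  # hypothetical label columns
        "batch_size": 32,
    }
    os.makedirs(example_config["results_dir"], exist_ok=True)

    train_loader, val_loader = preload_data(example_config)
    test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(example_config)
    # "task_names" is populated on the config during preprocessing.
    print(f"Tasks: {example_config['task_names']}, labels per task: {num_labels_list}")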