Geneformer / geneformer /mtl /eval_utils.py
ctheodoris's picture
precommit formatting
f07bfd7
raw
history blame
No virus
3.67 kB
import pandas as pd
from .data import prepare_test_loader
from .imports import *
from .model import GeneformerMultiTask
def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
    """Run ``model`` over ``test_loader`` and write per-cell predictions to CSV.

    For each task in ``config["task_names"]``, every sample gets two values:
    the predicted class (argmax of that task's logits) and the softmax
    probabilities. The results are written to
    ``<config["results_dir"]>/test_preds.csv`` with one row per cell and the
    columns ``Cell ID``, ``<task> Prediction`` and ``<task> Probabilities``.
    The probabilities are joined into a single comma-separated string.

    Args:
        model: Multi-task model called as ``model(input_ids, attention_mask)``.
            It must return a 3-tuple whose second element is indexed by task
            position. Each entry presumably holds per-sample logits of shape
            (batch, num_classes); confirm against the model definition.
        device: Device the input tensors are moved to.
        test_loader: Iterable of batches. Each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"cell_id"`` tensors.
        cell_id_mapping: Maps the integer ``cell_id`` values found in batches
            to the identifier written in the ``Cell ID`` column.
        config: Must provide ``"task_names"`` and ``"results_dir"``.
    """
    task_names = config["task_names"]
    task_pred_labels = {task_name: [] for task_name in task_names}
    task_pred_probs = {task_name: [] for task_name in task_names}
    cell_ids = []

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            _, logits, _ = model(input_ids, attention_mask)

            # Translate the whole batch of integer ids with a single transfer.
            cell_ids.extend(
                cell_id_mapping[cell_idx] for cell_idx in batch["cell_id"].tolist()
            )

            for i, task_name in enumerate(task_names):
                task_logits = logits[i]
                # Use batch-wide ops and one host transfer per task instead of
                # one GPU sync per sample. Softmax runs in float32 so that
                # half-precision logits keep their precision and .numpy() does
                # not fail (numpy has no bfloat16 dtype).
                preds = task_logits.argmax(dim=-1).tolist()
                probs = torch.softmax(task_logits.float(), dim=-1).cpu().numpy()
                task_pred_labels[task_name].extend(preds)
                # Extending with a 2-D array appends one 1-D row per sample.
                task_pred_probs[task_name].extend(probs)

    # Save test predictions with cell IDs and probabilities to CSV
    test_results_dir = config["results_dir"]
    os.makedirs(test_results_dir, exist_ok=True)
    test_preds_file = os.path.join(test_results_dir, "test_preds.csv")

    columns = ["Cell ID"]
    for task_name in task_names:
        columns += [f"{task_name} Prediction", f"{task_name} Probabilities"]

    rows = []
    for sample_idx, cell_id in enumerate(cell_ids):
        row = {"Cell ID": cell_id}
        for task_name in task_names:
            row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
            row[f"{task_name} Probabilities"] = ",".join(
                map(str, task_pred_probs[task_name][sample_idx])
            )
        rows.append(row)

    # Explicit columns keep a proper header even when the loader is empty.
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(test_preds_file, index=False)
    print(f"Test predictions saved to {test_preds_file}")
def load_and_evaluate_test_model(config):
    """Load the saved best multi-task model and evaluate it on the test set.

    Steps:
    1. Read ``hyperparameters.json`` and ``pytorch_model.bin`` from
       ``<config["model_save_path"]>/GeneformerMultiTask``.
    2. Rebuild a ``GeneformerMultiTask`` from those hyperparameters and the
       label counts returned by ``prepare_test_loader``.
    3. Pass the model to ``evaluate_test_dataset``, which writes the test
       predictions.

    Args:
        config: Must provide ``"model_save_path"``, ``"pretrained_path"`` and
            ``"use_task_weights"``, plus whatever keys ``prepare_test_loader``
            and ``evaluate_test_dataset`` read.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)

    model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
    hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
    best_model_path = os.path.join(model_directory, "pytorch_model.bin")

    # Load the saved best hyperparameters
    with open(hyperparams_path, "r", encoding="utf-8") as f:
        best_hyperparams = json.load(f)

    # A missing, null or empty "task_weights" entry is passed to the model as [].
    normalized_task_weights = best_hyperparams.get("task_weights") or []

    print("Loaded hyperparameters:")
    for param, value in best_hyperparams.items():
        # Stored task weights are reported under the name the model receives.
        label = "normalized_task_weights" if param == "task_weights" else param
        print(f"{label}: {value}")

    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_hyperparams["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=normalized_task_weights,
    )
    # Deserialize onto the CPU first. Without map_location, a checkpoint saved
    # from CUDA tensors cannot be loaded on a CPU-only machine, even though
    # `device` above deliberately falls back to "cpu" in that case.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch versions supporting it).
    state_dict = torch.load(best_model_path, map_location="cpu")
    best_model.load_state_dict(state_dict)
    best_model.to(device)

    evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
    print("Evaluation completed.")