Help for model interpretation

#295
by MinieRosie - opened

Hello, I have been running Geneformer on a difficult problem.

The model achieves good accuracy (~0.85–0.88) when asked to classify between Disease A and Disease B. I then split Disease B into two classes (subtype 1 and subtype 2), and here the model only reaches ~0.5 accuracy. One caveat: we expect roughly a 20% error rate in the labels themselves, meaning that some cells labeled Disease B subtype 1 are actually subtype 2 and vice versa. Since we cannot know which are mislabeled, we do the best we can with the data as it is presented to us. So essentially the model has difficulty distinguishing Disease B subtype 1 from subtype 2.

However, when I look at the embeddings, they still cluster fairly well. (For the plot I labeled Disease B subtype 2 as 2a and 2b, but the model was trained only on Disease A, B subtype 1, and B subtype 2.)

I am wondering whether it would be a mistake to go ahead and simply average these cell embeddings and run the in silico perturbation analysis anyway, on the assumption that the average is still a decent representation of each class state. Or is the model just overfitting the data, so that any further interpretation would not be useful?
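For concreteness, the averaging step I have in mind is roughly the following (a rough sketch with numpy; the array names and shapes are placeholders for my extracted cell embeddings and class labels):

```python
import numpy as np

def class_average_embeddings(embeddings, labels):
    """Average the cell embeddings within each class to get one
    representative vector per class state."""
    labels = np.asarray(labels)
    return {
        cls: embeddings[labels == cls].mean(axis=0)
        for cls in np.unique(labels)
    }

# Illustrative shapes: 6 cells, 4-dim embeddings, three classes.
embeddings = np.random.rand(6, 4)
labels = ["A", "B1", "B2", "B1", "B2", "A"]
print(class_average_embeddings(embeddings, labels))
```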

I really appreciate any insight, thank you very much.

[Attached image: Screen Shot 2024-02-03 at 10.02.14 AM.png — UMAP of the cell embeddings]

Thank you for your question. Assuming the UMAP you included is test data that was not used for training or for the validation set used to select hyperparameters, the classes Disease B subtype 1 vs. 2 are reasonably separated, as you said, so the in silico perturbation analysis should still indicate which genes, when modulated, are predicted to shift cells between the states. Given that they are fairly well separated, if this is truly held-out test data, I would expect the classification accuracy to be better, at least for the classes that appear distinct here (Disease A, Disease B subtype 1, and Disease B subtype 2). It can be informative to determine where the misclassifications are occurring. For example, if the accuracy is poor because most of the misclassifications are between Disease B subtypes 2a and 2b, but that distinction is less critical to make, that would be useful to know. It is also a bit tricky here because you expect some misclassifications at baseline due to the estimated 20% error rate in the labels: even a classifier that perfectly recovered the true subtypes would show only ~80% agreement with the observed labels.
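As a minimal sketch of how to check where the misclassifications fall, using scikit-learn (the label arrays below are illustrative placeholders for the true and predicted labels of your held-out test set):

```python
from sklearn.metrics import confusion_matrix, classification_report

# Placeholders for the held-out test set's true and predicted labels.
class_names = ["Disease A", "B subtype 1", "B subtype 2"]
y_true = ["Disease A", "B subtype 1", "B subtype 2",
          "B subtype 1", "B subtype 2", "Disease A"]
y_pred = ["Disease A", "B subtype 2", "B subtype 2",
          "B subtype 1", "B subtype 1", "Disease A"]

# Rows are the true class, columns the predicted class, so off-diagonal
# entries show exactly which pairs of classes are being confused.
print(confusion_matrix(y_true, y_pred, labels=class_names))
print(classification_report(y_true, y_pred, labels=class_names))
```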

In general, as with all modeling approaches, the best way to avoid overfitting is to include as many patients as possible in the training data (and ideally multiple completely separate datasets, if available) so that the model learns the generalizable features of the classes rather than overfitting to a particular individual or study. It is also best practice to match the other biological characteristics between classes to ensure the model is not learning a confounding feature to separate them; for example, if all of Disease A came from children and all of Disease B from adults, the model could learn to distinguish them by age rather than by disease state. We do recommend hyperparameter tuning, and we always separate the training, validation, and test sets by individual, which gives a better estimate of the model's generalizability to unseen individuals than including a subsample of cells from every individual in the training or validation data as well as the test data. One other thing to keep in mind is that these models can very easily memorize the data, so we usually restrict fine-tuning to 1 epoch.
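For illustration, splitting by individual rather than by cell can be done with scikit-learn's GroupShuffleSplit (a sketch with randomly generated placeholder data; in practice the groups would be your patient IDs):

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Placeholder data: one row per cell, plus the patient each cell came from.
rng = np.random.default_rng(0)
X = rng.random((1000, 256))              # cell features/embeddings
y = rng.integers(0, 3, 1000)             # class labels
patient_ids = rng.integers(0, 40, 1000)  # patient of origin per cell

# Hold out ~20% of patients (not cells) as the test set...
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
trainval_idx, test_idx = next(gss.split(X, y, groups=patient_ids))

# ...then split the remaining patients into training and validation.
gss_val = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_idx, val_idx = next(gss_val.split(
    X[trainval_idx], y[trainval_idx], groups=patient_ids[trainval_idx]))

# Sanity check: no patient appears in more than one split.
train_patients = set(patient_ids[trainval_idx][train_idx])
val_patients = set(patient_ids[trainval_idx][val_idx])
test_patients = set(patient_ids[test_idx])
assert not (train_patients & val_patients)
assert not ((train_patients | val_patients) & test_patients)
```

For the single-epoch fine-tuning, if you are using the Hugging Face Trainer, this corresponds to setting num_train_epochs=1 in TrainingArguments.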

ctheodoris changed discussion status to closed
