RP3D-DiagModel
About Checkpoint
The parameters used to train this checkpoint are as follows:
start_class: 0
end_class: 5569
backbone: 'resnet'
level: 'articles' # represents the disorder level
depth: 32
ltype: 'MultiLabel' # represents the Binary Cross Entropy Loss
augment: True # represents the medical data augmentation
split: 'late' # represents the late fusion strategy
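As a rough illustration of what the 'MultiLabel' setting implies, the sketch below computes a binary cross-entropy loss over independent labels from raw logits, using only the standard library. It is not the project's training code: the function name `bce_with_logits` and the toy logits and targets are assumptions made for this example, and the actual implementation may differ (for instance in reduction or weighting).

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over independent labels, from raw logits.

    Hypothetical helper for illustration; each label is scored on its own,
    so several labels can be positive for the same sample.
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for x, y in zip(logits, targets):
        # Numerically stable form of
        # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# Toy values (assumed): three labels for one sample, two of them positive.
logits = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 1.0]
loss = bce_with_logits(logits, targets)
print(f"multi-label BCE loss: {loss:.6f}")
```

Because every label contributes its own sigmoid term, the loss does not force the predicted probabilities to sum to one, which is what distinguishes a multi-label setup from a softmax-based single-label one.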
Load Model
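import torch

# Load backbone
# RadNet comes from the project's code; num_classes, backbone, depth, ltype,
# augment, fuse, ke, encoded and adapter must be set to match your configuration.
model = RadNet(num_cls=num_classes, backbone=backbone, depth=depth, ltype=ltype, augment=augment, fuse=fuse, ke=ke, encoded=encoded, adapter=adapter)
pretrained_weights = torch.load("path/to/pytorch_model_32_late.bin", map_location="cpu")
# strict=False reports mismatched keys instead of raising an error
missing, unexpected = model.load_state_dict(pretrained_weights, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpected)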
# If ke is set to True, also load the text encoder
medcpt = MedCPT_clinical(bert_model_name='ncbi/MedCPT-Query-Encoder')
checkpoint = torch.load('path/to/epoch_state.pt', map_location='cpu')['state_dict']
# Remove the 'module.' prefix from the checkpoint keys
load_checkpoint = {key.replace('module.', ''): value for key, value in checkpoint.items()}
missing, unexpected = medcpt.load_state_dict(load_checkpoint, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpected)
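To make the two lists printed above easier to interpret, here is a minimal sketch in plain Python of why the 'module.' prefix is removed before loading. It only mimics the key comparison with sets and is not PyTorch's implementation; the helper names `strip_prefix` and `compare_keys` and the parameter names are hypothetical.

```python
def strip_prefix(state_dict, prefix="module."):
    """Drop a leading prefix from each key; keys without it are kept unchanged."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, in the sense of non-strict loading."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected


# Hypothetical parameter names, for illustration only.
model_keys = ["encoder.weight", "encoder.bias", "head.weight"]
checkpoint = {"module.encoder.weight": 1, "module.encoder.bias": 2}

# Without stripping, no checkpoint key matches any model key.
raw_missing, raw_unexpected = compare_keys(model_keys, checkpoint)
# After stripping, only the parameter absent from the checkpoint is missing.
clean_missing, clean_unexpected = compare_keys(model_keys, strip_prefix(checkpoint))

print("before stripping:", raw_missing, raw_unexpected)
print("after stripping:", clean_missing, clean_unexpected)
```

Unlike `str.replace`, this helper removes the prefix only at the start of a key, which avoids touching a key that happens to contain 'module.' elsewhere. With non-strict loading, a long list of missing keys is therefore worth checking before using the model, since it can mean that few or no weights were actually loaded.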
Why do we provide this checkpoint?
All early-fusion checkpoints can be further fine-tuned from this checkpoint. If you need checkpoints trained with different parameter settings, there are two options:
Fine-tune from this checkpoint
checkpoint: "None"
safetensor: path to this checkpoint (pytorch_model.bin)
Contact Us
Email the author: three-world@sjtu.edu.cn
About Dataset
Please refer to RP3D-DiagDS
For more information, see our instructions on GitHub for download and usage.