Model Details
Model Description
This model is fine-tuned from stable-diffusion-v1-5 on 110,000 image-text pairs from the MIMIC dataset using SVDiff [1], a parameter-efficient fine-tuning (PEFT) method. With this strategy, only the singular values of the U-Net's weight matrices are adapted (SVDiff learns a shift for each singular value), while all other parameters stay frozen.
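To make the idea concrete, here is a minimal NumPy sketch of the spectral-shift reparameterization described in the SVDiff paper [1]. A weight matrix is decomposed once as W = U·diag(σ)·Vᵀ, the factors U, σ and V stay frozen, and only a shift vector δ is trained, giving W_δ = U·diag(ReLU(σ + δ))·Vᵀ. The function name `spectral_shift_weight` and the 320×320 shape are illustrative assumptions; the layers actually used by this model come from the SVDiff code, not from this snippet.

```python
import numpy as np


def spectral_shift_weight(u, sigma, vh, delta):
    """Rebuild a weight matrix from frozen SVD factors and a trainable shift.

    Implements W_delta = U @ diag(relu(sigma + delta)) @ Vh.
    """
    shifted = np.maximum(sigma + delta, 0.0)  # ReLU keeps singular values non-negative
    return (u * shifted) @ vh  # scaling U's columns is the same as U @ diag(shifted)


rng = np.random.default_rng(0)
weight = rng.standard_normal((320, 320))  # illustrative layer shape (assumption)

# Frozen factors, computed once from the pretrained weight
u, sigma, vh = np.linalg.svd(weight, full_matrices=False)

# The only trainable parameters: one shift per singular value
delta = np.zeros_like(sigma)

# With delta = 0 the pretrained weight is recovered (up to float round-off)
w0 = spectral_shift_weight(u, sigma, vh, delta)
print(np.allclose(w0, weight))  # True
print(weight.size, delta.size)  # 102400 320
```

Initializing δ at zero means training starts exactly from the pretrained model, and the ReLU prevents a shift from producing negative singular values.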
- Developed by: Raman Dutt
- Shared by: Raman Dutt
- Model type: Stable Diffusion fine-tuned using Parameter-Efficient Fine-Tuning (PEFT)
- Finetuned from model: stable-diffusion-v1-5
Model Sources
- Paper: Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity
- Demo: MIMIC-SD-PEFT-Demo
Direct Use
This model can be used directly to generate realistic synthetic medical images, such as chest X-rays, from radiology-style text prompts.
How to Get Started with the Model
```python
import inspect
import os

import accelerate
import huggingface_hub
import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from safetensors import safe_open
from safetensors.torch import load_file

# Assumption: the SVDiff U-Net class is provided by the svdiff-pytorch package;
# adjust this import to match where UNet2DConditionModelForSVDiff lives in your setup.
from svdiff_pytorch import UNet2DConditionModelForSVDiff


# Define the loading function
def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
    """Build an SVDiff U-Net from pretrained SD weights and optionally load spectral shifts."""
    config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
    original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    state_dict = original_model.state_dict()
    # Instantiate the SVDiff U-Net on the meta device (no memory is allocated yet)
    with accelerate.init_empty_weights():
        model = UNet2DConditionModelForSVDiff.from_config(config)

    # Load pre-trained weights
    param_device = "cpu"
    torch_dtype = kwargs.get("torch_dtype")
    # Spectral shifts ("delta" parameters) start at zero, i.e. the unmodified pretrained weights
    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
    state_dict.update(spectral_shifts_weights)

    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    if missing_keys:
        raise ValueError(
            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
            " those weights or else make sure your checkpoint file is correct."
        )

    # Older accelerate versions do not accept a `dtype` argument
    accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters

    def _set_tensor(name, value):
        if accepts_dtype:
            set_module_tensor_to_device(model, name, param_device, value=value, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, name, param_device, value=value)

    # Move the parameters from the meta device to the CPU
    for param_name, param in state_dict.items():
        _set_tensor(param_name, param)

    if spectral_shifts_ckpt:
        if os.path.isdir(spectral_shifts_ckpt):
            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
        elif not os.path.exists(spectral_shifts_ckpt):
            # Treat the argument as a Hugging Face Hub repo id and download the file
            hf_hub_kwargs = hf_hub_kwargs or {}
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(
                spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs
            )
        assert os.path.exists(spectral_shifts_ckpt)
        with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                _set_tensor(key, f.get_tensor(key))
        print(f"Resumed from {spectral_shifts_ckpt}")

    if torch_dtype is not None:
        model = model.to(torch_dtype)
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    # Set model in evaluation mode to deactivate Dropout modules by default
    model.eval()
    del original_model
    torch.cuda.empty_cache()
    return model


SD_MODEL = "runwayml/stable-diffusion-v1-5"
# Assumes spectral_shifts.safetensors from this repository is stored in a local "unet/" folder
SPECTRAL_SHIFTS_PATH = os.path.join("unet", "spectral_shifts.safetensors")

# Base Stable Diffusion v1.5 pipeline; its U-Net is replaced with the SVDiff U-Net
pipe = StableDiffusionPipeline.from_pretrained(SD_MODEL)
pipe.unet = load_unet_for_svdiff(SD_MODEL, spectral_shifts_ckpt=SPECTRAL_SHIFTS_PATH, subfolder="unet")

# Compute the SVD of each adapted layer's (now loaded) weight; perform_svd comes from the SVDiff layers
for module in pipe.unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()

# Load the adapted U-Net weights (the fine-tuned spectral shifts)
state_dict = load_file(SPECTRAL_SHIFTS_PATH)
pipe.unet.load_state_dict(state_dict, strict=False)
pipe.to("cuda:0")

# Generate images with text prompts
TEXT_PROMPT = "No acute cardiopulmonary abnormality."
GUIDANCE_SCALE = 4
INFERENCE_STEPS = 75
result_image = pipe(
    prompt=TEXT_PROMPT,
    height=224,
    width=224,
    guidance_scale=GUIDANCE_SCALE,
    num_inference_steps=INFERENCE_STEPS,
)
result_pil_image = result_image["images"][0]
```
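`GUIDANCE_SCALE` controls classifier-free guidance. When the scale is above 1, the diffusers Stable Diffusion pipeline predicts the noise at each denoising step both without the prompt and with it, then combines the two as ε̂ = ε_uncond + s·(ε_text - ε_uncond), where s is the guidance scale. The NumPy sketch below shows only that combination step, with random arrays standing in for the two U-Net outputs; it is an illustration, not the pipeline's code. The latent shape assumes SD v1.5's 4-channel latents and the VAE's 8× spatial downsampling, so a 224×224 image corresponds to a 4×28×28 latent.

```python
import numpy as np


def classifier_free_guidance(noise_uncond, noise_text, guidance_scale):
    """Combine unconditional and text-conditioned noise predictions."""
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


HEIGHT = WIDTH = 224
VAE_SCALE_FACTOR = 8  # the SD v1.5 VAE downsamples each spatial side by 8
latent_shape = (4, HEIGHT // VAE_SCALE_FACTOR, WIDTH // VAE_SCALE_FACTOR)

rng = np.random.default_rng(0)
# Stand-ins for the two U-Net outputs at one denoising step (random, for illustration)
noise_uncond = rng.standard_normal(latent_shape)
noise_text = rng.standard_normal(latent_shape)

guided = classifier_free_guidance(noise_uncond, noise_text, guidance_scale=4)
print(latent_shape)  # (4, 28, 28)
```

A scale of 1 reduces to the text-conditioned prediction alone; larger values, such as the 4 used above, typically push samples more strongly toward the prompt. Height and width must be multiples of 8 for the latent grid to line up.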
Training Details
Training Data
This model has been fine-tuned on 110K image-text pairs from the MIMIC dataset.
Training Procedure
The training procedure is described in detail in Section 4.3 of the paper listed under Model Sources.
Metrics
This model has been evaluated using the Fréchet Inception Distance (FID) on the MIMIC dataset; a lower FID is better.
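FID fits a Gaussian to Inception-v3 features of real images (μ_r, Σ_r) and of generated images (μ_g, Σ_g) and measures FID = ‖μ_r - μ_g‖² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^½). The NumPy sketch below computes this distance from precomputed feature statistics; feature extraction is omitted, the function name `frechet_distance` is hypothetical, and this is not the evaluation code used in the paper.

```python
import numpy as np


def _sqrtm_psd(matrix):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    eigvals, eigvecs = np.linalg.eigh(matrix)
    eigvals = np.clip(eigvals, 0.0, None)  # remove tiny negative round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((sigma1 @ sigma2)^(1/2)) equals the sum of square roots of the eigenvalues
    # of sqrt(sigma1) @ sigma2 @ sqrt(sigma1), which is symmetric PSD
    root1 = _sqrtm_psd(sigma1)
    middle = root1 @ sigma2 @ root1
    middle = (middle + middle.T) / 2  # symmetrize against round-off
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


# Toy 1-D example: means 0 and 3, variances 1 and 4 -> 9 + (1 + 4 - 2 * 2)
fid = frechet_distance(np.array([0.0]), np.array([[1.0]]), np.array([3.0]), np.array([[4.0]]))
print(round(fid, 6))  # 10.0
```

In practice the statistics come from 2048-dimensional Inception pool features over thousands of images, so the covariance estimates, and therefore the FID values, depend on sample size.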
Results
Fine-Tuning Strategy | FID Score |
---|---|
Full FT | 58.74 |
Attention | 52.41 |
Bias | 20.81 |
Norm | 29.84 |
Bias+Norm+Attention | 35.93 |
LoRA | 439.65 |
SV-Diff | 23.59 |
DiffFit | 42.50 |
Environmental Impact
Parameter-Efficient Fine-Tuning may have a smaller environmental footprint than full fine-tuning, because only a small fraction of the model's parameters is updated. This reduces the compute and hardware required for training.
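To put the parameter savings in perspective: for a weight with m × n entries, full fine-tuning updates all m·n values, whereas SVDiff trains one shift per singular value, i.e. min(m, n) values. The short Python sketch below does this arithmetic for a few illustrative layer shapes. The shapes are assumptions chosen for illustration, and convolution kernels are assumed to be flattened to (out, in·k·k) before the SVD; this is not an exact census of the Stable Diffusion U-Net.

```python
def trainable_params(shape):
    """Return (full fine-tuning count, SVDiff count) for one weight tensor.

    Weights with more than two dims (e.g. conv kernels) are assumed to be
    flattened to (out_features, everything_else) before the SVD.
    """
    rows = shape[0]
    cols = 1
    for dim in shape[1:]:
        cols *= dim
    return rows * cols, min(rows, cols)


# Illustrative shapes only (assumption): attention projections and a 3x3 conv
example_shapes = [(320, 320), (640, 640), (1280, 1280), (1280, 1280, 3, 3)]

full_total = 0
svdiff_total = 0
for shape in example_shapes:
    full, svd = trainable_params(shape)
    full_total += full
    svdiff_total += svd

print(full_total, svdiff_total)  # 16896000 3520
print(f"SVDiff trains {svdiff_total / full_total:.4%} of these weights")
```

Because the trainable count grows with min(m, n) rather than m·n, the savings are largest for the big convolution and projection weights.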
Citation
BibTeX:
@article{dutt2023parameter, title={Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity}, author={Dutt, Raman and Ericsson, Linus and Sanchez, Pedro and Tsaftaris, Sotirios A and Hospedales, Timothy}, journal={arXiv preprint arXiv:2305.08252}, year={2023} }
APA:
Dutt, R., Ericsson, L., Sanchez, P., Tsaftaris, S. A., & Hospedales, T. (2023). Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity. arXiv preprint arXiv:2305.08252.
Model Card Authors
Raman Dutt
References
- [1] Han, Ligong, et al. "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning." arXiv preprint arXiv:2303.11305 (2023).