Upload model
Browse files

- configuration_cxrmate_ed.py: +0 -164
- modelling_cxrmate_ed.py: +267 -445
configuration_cxrmate_ed.py
CHANGED
@@ -46,167 +46,3 @@ class CXRMateEDConfig(transformers.PretrainedConfig):
             text_config = CONFIG_MAPPING[text_config['model_type']](**text_config)
 
         self.text_config = text_config
-
-# class CXRMateEDConfig(transformers.PretrainedConfig):
-
-#     model_type = 'cxrmate-ed'
-
-#     # def __init__(
-#     #     self,
-#     #     index_value_encoder_intermediate_size: int = 2048,
-#     #     include_time_delta: bool = True,
-#     #     time_delta_monotonic_inversion: bool = True,
-#     #     add_time_deltas: bool = True,
-#     #     history: int = 0,
-#     #     tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
-#     #     prompt_report_sections_filter: list = ['indication', 'history'],
-#     #     pad_token_id: int = 4,
-#     #     **kwargs: Any,
-#     # ) -> None:
-#     #     super().__init__(**kwargs)
-#     #     self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
-#     #     self.include_time_delta = include_time_delta
-#     #     self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
-#     #     self.add_time_deltas = add_time_deltas
-#     #     self.history = history
-#     #     self.tables_filter = tables_filter
-#     #     self.prompt_report_sections_filter = prompt_report_sections_filter
-#     #     self.pad_token_id = pad_token_id
-
-#     #     self.hidden_size = self.text_config.hidden_size
-
-#     def __init__(
-#         self,
-#         vision_config=None,
-#         text_config=None,
-#         # ignore_index=-100,
-#         # image_token_index=32000,
-#         # projector_hidden_act="gelu",
-#         # vision_feature_select_strategy="default",
-#         # vision_feature_layer=-2,
-#         # image_seq_length=576,
-#         index_value_encoder_intermediate_size: int = 2048,
-#         include_time_delta: bool = True,
-#         time_delta_monotonic_inversion: bool = True,
-#         add_time_deltas: bool = True,
-#         history: int = 0,
-#         tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
-#         prompt_report_sections_filter: list = ['indication', 'history'],
-#         pad_token_id: int = 4,
-#         **kwargs,
-#     ):
-#         transformers.PretrainedConfig.__init__(self, **kwargs)
-
-#         self.vision_config = vision_config
-#         self.text_config = text_config
-#         self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
-#         self.include_time_delta = include_time_delta
-#         self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
-#         self.add_time_deltas = add_time_deltas
-#         self.history = history
-#         self.tables_filter = tables_filter
-#         self.prompt_report_sections_filter = prompt_report_sections_filter
-#         self.pad_token_id = pad_token_id
-
-#         self.ignore_index = ignore_index
-#         self.image_token_index = image_token_index
-#         self.projector_hidden_act = projector_hidden_act
-#         self.image_seq_length = image_seq_length
-
-#         if vision_feature_select_strategy not in ["default", "full"]:
-#             raise ValueError(
-#                 "vision_feature_select_strategy should be one of 'default', 'full'."
-#                 f"Got: {vision_feature_select_strategy}"
-#             )
-
-#         self.vision_feature_select_strategy = vision_feature_select_strategy
-#         self.vision_feature_layer = vision_feature_layer
-
-#         if isinstance(vision_config, dict):
-#             vision_config["model_type"] = (
-#                 vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
-#             )
-#             vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
-#         elif vision_config is None:
-#             vision_config = CONFIG_MAPPING["clip_vision_model"](
-#                 intermediate_size=4096,
-#                 hidden_size=1024,
-#                 patch_size=14,
-#                 image_size=336,
-#                 num_hidden_layers=24,
-#                 num_attention_heads=16,
-#                 vocab_size=32000,
-#                 projection_dim=768,
-#             )
-
-
-#         if isinstance(text_config, dict):
-#             text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
-#             text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
-#         elif text_config is None:
-#             text_config = CONFIG_MAPPING["llama"]()
-
-#         super().__init__(**kwargs)
-
-
-# import transformers
-# from transformers.configuration_utils import PretrainedConfig
-# from transformers.utils import logging
-
-# logger = logging.get_logger(__name__)
-
-
-# class CXRMateEDConfig(PretrainedConfig):
-
-#     model_type = "cxrmate-ed"
-
-#     def __init__(self, **kwargs):
-#         super().__init__(**kwargs)
-
-#         if 'decoder' not in kwargs:
-
-#             self.decoder = transformers.LlamaConfig(
-#                 vocab_size=30000,
-#                 hidden_size=768,
-#                 intermediate_size=3072,
-#                 num_attention_heads=12,
-#                 num_hidden_layers=6,
-#                 max_position_embeddings=2048,
-#             )
-#             self.decoder.is_decoder = True
-
-#             self.decoder.index_value_encoder_intermediate_size = 2048
-#             self.decoder.include_time_delta = True
-#             self.decoder.time_delta_monotonic_inversion = True
-#             self.decoder.add_time_deltas = True
-#             self.decoder.history = 0
-#             self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"]
-#             self.decoder.prompt_report_sections_filter = ["indication", "history"]
-#             self.decoder.pad_token_id = 4
-
-#         else:
-#             self.decoder = kwargs.pop("decoder")
-
-
-#         if 'encoder' not in kwargs:
-#             self.encoder = transformers.AutoConfig.from_pretrained(
-#                 'aehrc/uniformer_base_tl_384',
-#                 projection_size=768,
-#                 trust_remote_code=True,
-#             )
-#         else:
-#             self.encoder = kwargs.pop("encoder")
-
-
-#         self.is_encoder_decoder = True
-
-#     @classmethod
-#     def from_encoder_decoder_configs(
-#         cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
-#     ) -> PretrainedConfig:
-
-#         logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
-#         decoder_config.is_decoder = True
-#         decoder_config.add_cross_attention = True
-
-#         return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
modelling_cxrmate_ed.py
CHANGED
@@ -8,13 +8,13 @@ import datasets
 import torch
 import transformers
 from huggingface_hub import hf_hub_download
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import Subset
 from torchvision.io import decode_image
-from
-from transformers
+from torchvision.transforms import v2
+from transformers import PreTrainedTokenizerFast
 from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
-from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import check_min_version, logging
 
 from .configuration_cxrmate_ed import CXRMateEDConfig
@@ -187,162 +187,39 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
         self.inf_time_delta_value = self.time_delta_map(float('inf'))
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-#             decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
-#                 Information necessary to initiate the text decoder. Can be either:
-
-#                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-#                     - A path to a *directory* containing model weights saved using
-#                       [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
-#                     - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
-#                       this case, `from_tf` should be set to `True` and a configuration object should be provided as
-#                       `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
-#                       PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
-
-#             model_args (remaining positional arguments, *optional*):
-#                 All remaning positional arguments will be passed to the underlying model's `__init__` method.
-
-#             kwargs (remaining dictionary of keyword arguments, *optional*):
-#                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
-#                 `output_attentions=True`).
-
-#                 - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
-#                 - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
-#                 - To update the parent model configuration, do not use a prefix for each configuration parameter.
-
-#                 Behaves differently depending on whether a `config` is provided or automatically loaded.
-
-#         Example:
-
-#         ```python
-#         >>> from transformers import VisionEncoderDecoderModel
-
-#         >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
-#         >>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-#         ...     "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
-#         ... )
-#         >>> # saving model after fine-tuning
-#         >>> model.save_pretrained("./vit-bert")
-#         >>> # load fine-tuned model
-#         >>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
-#         ```"""
-
-#         kwargs_encoder = {
-#             argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
-#         }
-
-#         kwargs_decoder = {
-#             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
-#         }
-
-#         # remove encoder, decoder kwargs from kwargs
-#         for key in kwargs_encoder.keys():
-#             del kwargs["encoder_" + key]
-#         for key in kwargs_decoder.keys():
-#             del kwargs["decoder_" + key]
-
-#         # Load and initialize the encoder and decoder
-#         # The distinction between encoder and decoder at the model level is made
-#         # by the value of the flag `is_decoder` that we need to set correctly.
-#         encoder = kwargs_encoder.pop("model", None)
-#         if encoder is None:
-#             if encoder_pretrained_model_name_or_path is None:
-#                 raise ValueError(
-#                     "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
-#                     "to be defined."
-#                 )
-
-#             if "config" not in kwargs_encoder:
-#                 encoder_config, kwargs_encoder = transformers.AutoConfig.from_pretrained(
-#                     encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
-#                 )
-
-#                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
-#                     logger.info(
-#                         f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
-#                         "from a decoder model. Cross-attention and casual mask are disabled."
-#                     )
-#                     encoder_config.is_decoder = False
-#                     encoder_config.add_cross_attention = False
-
-#                 kwargs_encoder["config"] = encoder_config
-
-#             encoder = transformers.AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
-
-#         decoder = kwargs_decoder.pop("model", None)
-#         if decoder is None:
-#             if decoder_pretrained_model_name_or_path is None:
-#                 raise ValueError(
-#                     "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
-#                     "to be defined."
-#                 )
-
-#             if "config" not in kwargs_decoder:
-#                 decoder_config, kwargs_decoder = transformers.AutoConfig.from_pretrained(
-#                     decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
-#                 )
-
-#                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
-#                     logger.info(
-#                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
-#                         f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
-#                         f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
-#                     )
-#                     decoder_config.is_decoder = True
-#                     decoder_config.add_cross_attention = False
-
-#                 kwargs_decoder["config"] = decoder_config
-
-#             if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
-#                 logger.warning(
-#                     f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
-#                     f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
-#                     "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
-#                     "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
-#                     "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
-#                 )
-
-#             decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
-
-#         # instantiate config with corresponding kwargs
-#         config = CXRMateEDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
-
-#         # make sure input & output embeddings is not tied
-#         config.tie_word_embeddings = False
-
-
-#         return cls(encoder=encoder, decoder=decoder, config=config)
+        # Image transformations:
+        self.train_transforms = v2.Compose(
+            [
+                v2.Grayscale(num_output_channels=3),
+                v2.Resize(
+                    size=self.config.vision_config.image_size,
+                    antialias=True,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                ),
+                v2.RandomCrop(
+                    size=[self.config.vision_config.image_size, self.config.vision_config.image_size],
+                    pad_if_needed=True,
+                ),
+                v2.RandomRotation(degrees=5),
+                v2.ToDtype(torch.float32, scale=True),
+                v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+            ]
+        )
+        self.test_transforms = v2.Compose(
+            [
+                v2.Grayscale(num_output_channels=3),
+                v2.Resize(
+                    size=self.config.vision_config.image_size,
+                    antialias=True,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                ),
+                v2.CenterCrop(size=[self.config.vision_config.image_size, self.config.vision_config.image_size]),
+                v2.ToDtype(torch.float32, scale=True),
+                v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+            ]
+        )
+
+        self.post_init()
 
     def forward(
         self,
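The train and test image pipelines now live on the model itself, so they can be reproduced standalone. A minimal sketch of the new test-time pipeline follows; the image size of 384 and the dummy input are assumptions for illustration (in the model the size comes from config.vision_config.image_size).

# Standalone sketch mirroring the new in-model test transform (illustrative values).
import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import v2

image_size = 384  # assumed; the model reads this from its vision config
test_transforms = v2.Compose([
    v2.Grayscale(num_output_channels=3),
    v2.Resize(size=image_size, antialias=True, interpolation=v2.InterpolationMode.BICUBIC),
    v2.CenterCrop(size=[image_size, image_size]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])

dummy = torch.randint(0, 256, (1, 512, 448), dtype=torch.uint8)  # stand-in for a decoded CXR
out = test_transforms(dummy)
print(out.shape)  # torch.Size([3, 384, 384])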
@@ -712,80 +589,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
             sections[j].append(section_string)
 
         return tuple(sections.values())
-
-    def tokenize_text_prompt(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
-        """
-        Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
-        Time deltas for the input_ids are also prepared here.
-
-        Argument/s:
-            tokenizer - Hugging Face tokenizer.
-
-        Returns:
-            ed - dictionary containing the input_ids, token_type_ids, attention_mask and time_deltas for the ED module columns.
-            cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
-        """
-
-        batch_size = len(kwargs['study_id'])
-
-        tokenized = {
-            'input_ids': {i: [] for i in range(batch_size)},
-            'token_type_ids': {i: [] for i in range(batch_size)},
-            'time_delta': {i: [] for i in range(batch_size)},
-            'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
-        }
-
-        prompt_text_columns = [f'{k}_{j}' if k != 'mimic_cxr_sectioned' else j for k, v in self.tables.items() if 'text_columns' in v for j in (v['text_columns'] if isinstance(v['text_columns'], list) else [v['text_columns']])] + ['prior_findings', 'prior_impression']
-
-        for i in prompt_text_columns:
-            if i in kwargs:
-                if f'{i}_time_delta' not in kwargs:
-                    kwargs[f'{i}_time_delta'] = [[self.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
-                for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
-                    if y is not None:
-                        assert isinstance(y, list)
-                        assert isinstance(z, list)
-                        for text, time_delta in zip(y, z):
-                            if text is not None:
-                                tokenized['input_ids'][x].append(
-                                    tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
-                                )
-                                tokenized['token_type_ids'][x].append(
-                                    torch.full(
-                                        (1, tokenized['input_ids'][x][-1].shape[-1]),
-                                        self.token_type_to_token_type_id[i],
-                                        dtype=torch.long,
-                                        device=self.device,
-                                    )
-                                )
-                                tokenized['time_delta'][x].append(
-                                    torch.full(
-                                        (1, tokenized['input_ids'][x][-1].shape[-1]),
-                                        time_delta,
-                                        dtype=torch.float32,
-                                        device=self.device,
-                                    )
-                                )
-
-        tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
-        tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
-        tokenized['time_delta'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, device=self.device) for j in tokenized['time_delta'].values()]
-
-        tokenized['input_ids'] = torch.nn.utils.rnn.pad_sequence(
-            tokenized['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
-        )[:, :, 0]
-        tokenized['token_type_ids'] = torch.nn.utils.rnn.pad_sequence(
-            tokenized['token_type_ids'], batch_first=True, padding_value=0,
-        )[:, :, 0]
-
-        tokenized['attention_mask'] = (tokenized['input_ids'] != tokenizer.pad_token_id).int()
-
-        tokenized['time_delta'] = torch.nn.utils.rnn.pad_sequence(
-            tokenized['time_delta'], batch_first=True, padding_value=0,
-        )
-
-        return tokenized
 
     def prepare_inputs(
         self,
         images,
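The tokenize_text_prompt method removed above is re-added verbatim further down the file (see the later hunk); its core trick is padding ragged per-study token-id tensors and then recovering the attention mask by comparing against the pad id. A toy sketch of that pattern (the token ids are made up; pad_token_id of 4 matches the default visible in the config diff):

# Toy illustration of the pad-then-compare pattern used by tokenize_text_prompt.
import torch

pad_token_id = 4
seqs = [torch.tensor([[11], [12], [13]]), torch.tensor([[21]])]  # one (length, 1) tensor per study
ids = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_token_id)[:, :, 0]
mask = (ids != pad_token_id).int()
print(ids)   # tensor([[11, 12, 13], [21,  4,  4]])
print(mask)  # tensor([[1, 1, 1], [1, 0, 0]])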
@@ -914,7 +718,219 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         assert inputs_embeds.shape[1] == token_type_ids.shape[1]
 
         return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids
+
+    def tokenize_text_prompt(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
+        """
+        Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
+        Time deltas for the input_ids are also prepared here.
+
+        Argument/s:
+            tokenizer - Hugging Face tokenizer.
+
+        Returns:
+            ed - dictionary containing the input_ids, token_type_ids, attention_mask and time_deltas for the ED module columns.
+            cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
+        """
+
+        batch_size = len(kwargs['study_id'])
+
+        tokenized = {
+            'input_ids': {i: [] for i in range(batch_size)},
+            'token_type_ids': {i: [] for i in range(batch_size)},
+            'time_delta': {i: [] for i in range(batch_size)},
+            'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
+        }
+
+        prompt_text_columns = [f'{k}_{j}' if k != 'mimic_cxr_sectioned' else j for k, v in self.tables.items() if 'text_columns' in v for j in (v['text_columns'] if isinstance(v['text_columns'], list) else [v['text_columns']])] + ['prior_findings', 'prior_impression']
+
+        for i in prompt_text_columns:
+            if i in kwargs:
+                if f'{i}_time_delta' not in kwargs:
+                    kwargs[f'{i}_time_delta'] = [[self.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
+                for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
+                    if y is not None:
+                        assert isinstance(y, list)
+                        assert isinstance(z, list)
+                        for text, time_delta in zip(y, z):
+                            if text is not None:
+                                tokenized['input_ids'][x].append(
+                                    tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
+                                )
+                                tokenized['token_type_ids'][x].append(
+                                    torch.full(
+                                        (1, tokenized['input_ids'][x][-1].shape[-1]),
+                                        self.token_type_to_token_type_id[i],
+                                        dtype=torch.long,
+                                        device=self.device,
+                                    )
+                                )
+                                tokenized['time_delta'][x].append(
+                                    torch.full(
+                                        (1, tokenized['input_ids'][x][-1].shape[-1]),
+                                        time_delta,
+                                        dtype=torch.float32,
+                                        device=self.device,
+                                    )
+                                )
+
+        tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
+        tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
+        tokenized['time_delta'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, device=self.device) for j in tokenized['time_delta'].values()]
+
+        tokenized['input_ids'] = torch.nn.utils.rnn.pad_sequence(
+            tokenized['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
+        )[:, :, 0]
+        tokenized['token_type_ids'] = torch.nn.utils.rnn.pad_sequence(
+            tokenized['token_type_ids'], batch_first=True, padding_value=0,
+        )[:, :, 0]
+
+        tokenized['attention_mask'] = (tokenized['input_ids'] != tokenizer.pad_token_id).int()
+
+        tokenized['time_delta'] = torch.nn.utils.rnn.pad_sequence(
+            tokenized['time_delta'], batch_first=True, padding_value=0,
+        )
+
+        return tokenized
+
+    def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
+        mask_value = torch.finfo(time_deltas.dtype).max if self.config.time_delta_monotonic_inversion else torch.finfo(time_deltas.dtype).min
+
+        masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
+        _, col_indices = torch.sort(masked_time_deltas, descending=not self.config.time_delta_monotonic_inversion)
+
+        num_rows, num_cols, _ = time_deltas.shape
+
+        row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
+        position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
+        position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
+        position_ids.masked_fill_(attention_mask == 0, 1)  # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
+
+        return position_ids
+
+    def prepare_index_value_feats(self, table, batch):
+
+        index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
+        index_value_columns = [f'{table}_{i}' for i in index_value_columns] if table != 'mimic_cxr_2_0_0_metadata' else index_value_columns
+
+        # Map to indices with lookup table:
+        if 'index_columns' in self.tables[table]:
+            for i in self.tables[table]['index_columns']:
+                k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
+                batch[k] = [
+                    [self.luts[table][i][str(k)] if k is not None else None for k in j] if j is not None else None for j in batch[k]
+                ]
+
+        batch_index_value_feats_list = []
+        batch_token_type_ids_list = []
+        batch_time_deltas_list = []
+
+        for batch_idx in range(len(batch['study_id'])):
+
+            if any([batch[k][batch_idx] for k in index_value_columns]):
+
+                num_rows = [len(batch[i][batch_idx]) for i in index_value_columns]
+                assert all(x == num_rows[0] for x in num_rows)
+                num_rows = num_rows[0]
+
+                # The y-index and the datetime for each group:
+                if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
+                    y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
+                    datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
+                    assert len(set(y_indices)) == len(datetime)
+                else:
+                    y_indices = [0] * num_rows
+                    datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
+
+                time_deltas = torch.tensor([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])[:, None]
+
+                tensor = torch.zeros(max(y_indices) + 1, self.luts[table]['total'])
+
+                # Index columns to feats:
+                if 'index_columns' in self.tables[table]:
+
+                    for i in self.tables[table]['index_columns']:
+                        k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
+                        y_indices_column = [y_idx for y_idx, x_idx in zip(y_indices, batch[k][batch_idx]) if x_idx is not None]
+                        x_indices_column = [x_idx for x_idx in batch[k][batch_idx] if x_idx is not None]
+
+                        tensor[y_indices_column, x_indices_column] = 1.0
+
+                if 'value_columns' in self.tables[table]:
+                    for i in self.tables[table]['value_columns']:
+
+                        k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
+                        y_indices_column = [y_idx for y_idx, value in zip(y_indices, batch[k][batch_idx]) if value is not None]
+                        x_indices_column = [self.luts[table][i] for value in batch[k][batch_idx] if value is not None]
+                        values = [value for value in batch[k][batch_idx] if value is not None]
+
+                        tensor[y_indices_column, x_indices_column] = torch.tensor(values, dtype=tensor.dtype)
+                        assert not torch.isnan(tensor).any()
+            else:
+                tensor = torch.empty(0, self.luts[table]['total'])
+                time_deltas = torch.empty(0, 1)
+
+            batch_index_value_feats_list.append(tensor)
+            batch_token_type_ids_list.append(torch.full(
+                    [tensor.shape[0]],
+                    self.token_type_to_token_type_id[table],
+                    dtype=torch.long,
+                )
+            )
+            batch_time_deltas_list.append(time_deltas)
+
+            assert tensor.shape[0] == batch_token_type_ids_list[-1].shape[0]
+            assert tensor.shape[0] == time_deltas.shape[0]
+
+        batch_index_value_feats = torch.nn.utils.rnn.pad_sequence(batch_index_value_feats_list, batch_first=True, padding_value=-1)  # Pad value of -1 is not ideal. Need to use something else.
+        batch_token_type_ids = torch.nn.utils.rnn.pad_sequence(batch_token_type_ids_list, batch_first=True, padding_value=0)
+        batch_time_deltas = torch.nn.utils.rnn.pad_sequence(batch_time_deltas_list, batch_first=True, padding_value=0)
+
+        batch_mask = (batch_index_value_feats != -1).any(dim=-1).int()
+
+        return batch_index_value_feats, batch_token_type_ids, batch_time_deltas, batch_mask
+
+    def prepare_text_prompt(self, table, column, batch):
+
+        key = f'{table}_{column}' if not table == 'mimic_cxr_sectioned' else column
+
+        batch_text_list = []
+        batch_time_deltas_list = []
+
+        for batch_idx in range(len(batch['study_id'])):
+            if batch[key][batch_idx]:
+
+                num_rows = len(batch[key][batch_idx])
+
+                # The y-index and the datetime for each group:
+                if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
+                    y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
+                    datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
+                    assert len(set(y_indices)) == len(datetime)
+                else:
+                    y_indices = [0] * num_rows
+                    datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
+
+                # Remove None values:
+                text_rows = batch[key][batch_idx] if isinstance(batch[key][batch_idx], list) else [batch[key][batch_idx]]
+                y_indices = [i for i, j in zip(y_indices, text_rows) if j is not None]
+                text_rows = [i for i in text_rows if i is not None]
+                datetime = [datetime[i] for i in set(y_indices)]
+                if text_rows:
+
+                    # Those in the same group (or those with the same y-index) get joined as the same string:
+                    batch_text_list.append([', '.join([text_rows[j] for j in range(len(y_indices)) if y_indices[j] == k]) + '.' for k in set(y_indices)])
+                    batch_time_deltas_list.append([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])
+
+                    assert len(batch_time_deltas_list[-1]) == len(batch_text_list[-1])
+                else:
+                    batch_text_list.append([])
+                    batch_time_deltas_list.append([])
+            else:
+                batch_text_list.append([])
+                batch_time_deltas_list.append([])
+
+        return batch_text_list, batch_time_deltas_list
 
     @staticmethod
     def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype):
 
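The new position_ids_from_time_deltas_and_attention_mask helper assigns position ids by sorting time deltas rather than by sequence order. A small worked example with made-up values, assuming time_delta_monotonic_inversion is True (ascending sort, max-valued fill for padded slots) and an all-ones attention mask:

# Worked toy example of sort-based position ids (values are illustrative).
import torch

time_deltas = torch.tensor([[[0.9], [0.2], [0.5]]])   # (batch, seq, 1)
attention_mask = torch.tensor([[1, 1, 1]])
mask_value = torch.finfo(time_deltas.dtype).max        # padded slots sort last under inversion

masked = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
_, col_indices = torch.sort(masked, descending=False)  # descending = not inversion

num_rows, num_cols, _ = time_deltas.shape
row_indices = torch.arange(num_rows).view(-1, 1).repeat(1, num_cols).view(-1)
position_ids = torch.zeros_like(col_indices)
position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols)[None, :].expand(num_rows, -1).flatten()
print(position_ids)  # tensor([[2, 0, 1]]) -- the smallest time delta gets position 0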
@@ -983,86 +999,24 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0
 
         return mixed_causality_4d_attention_mask
-
-    # @staticmethod
-    # def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
-
-    #     prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
-    #     report_seq_len = causal_2d_attention_mask.shape[-1]
-
-    #     non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
-    #     causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
-
-    #     # Upper left of attention matrix:
-    #     upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
-    #     upper_left = upper_left * non_causal_2d_attention_mask
-    #     upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)
-
-    #     causal_mask = torch.tril(
-    #         torch.ones(
-    #             (
-    #                 report_seq_len,
-    #                 report_seq_len,
-    #             ),
-    #             dtype=torch.long,
-    #             device=causal_2d_attention_mask.device,
-    #         ),
-    #     )
-
-    #     # Lower right of attention matrix:
-    #     lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
-    #     lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
-    #     lower_right = lower_right * causal_mask
-
-    #     # Upper right of attention matrix:
-    #     upper_right = torch.zeros(
-    #         causal_2d_attention_mask.shape[0],
-    #         1,
-    #         prompt_seq_len,
-    #         report_seq_len,
-    #         dtype=torch.long,
-    #         device=causal_2d_attention_mask.device,
-    #     )
-
-    #     # Lower left of attention matrix:
-    #     lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
-    #     lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)
-
-    #     left = torch.cat((upper_left, lower_left), dim=2)
-    #     right = torch.cat((upper_right, lower_right), dim=2)
-
-
-
-
-
-
-
-    #     non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
-    #     causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
-
-    #     mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
-    #     return mixed_causality_4d_attention_mask
-
-
-
-        masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
-        _, col_indices = torch.sort(masked_time_deltas, descending=not self.config.time_delta_monotonic_inversion)
-
-
-        position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
-        position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
-        position_ids.masked_fill_(attention_mask == 0, 1)  # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
-
-
-    def get_dataset(self, dataset_path, train_transforms=None, test_transforms=None, max_train_images_per_study=None, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
-
-        assert max_train_images_per_study is not None, 'max_train_images_per_study must be defined.'
-        assert test_transforms is not None, 'test_transforms must be defined.'
+    @staticmethod
+    def collate_fn(batch):
+        keys = set().union(*(d.keys() for d in batch))
+        batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
+        batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
+        return batch
+
+    @staticmethod
+    def prepare_dataset(physionet_dir: str, database_dir: str):
+
+        prepare_dataset(physionet_dir=physionet_dir, database_dir=database_dir)
+
+    def get_dataset(self, database_dir, max_train_images_per_study=None, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
+
+        dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')
+
+        assert max_train_images_per_study is not None or test_set_only, 'max_train_images_per_study must be defined if training.'
 
         def train_set_transform(batch):
 
@@ -1081,7 +1035,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
             # Sort based on ViewPosition:
             batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-            batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
+            batch['images'] = [torch.stack([self.train_transforms(j) for j in i]) for i in batch['images']]
             max_size = max(i.shape[0] for i in batch['images'])
 
             batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
@@ -1104,7 +1058,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
             # Sort based on ViewPosition:
             batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-            batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
+            batch['images'] = [torch.stack([self.test_transforms(j) for j in i]) for i in batch['images']]
             max_size = max(i.shape[0] for i in batch['images'])
             batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
             batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
@@ -1177,7 +1131,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         else:
             return test_set
 
-    def get_stage_1_dataset(self,
+    def get_stage_1_dataset(self, database_dir, max_train_images_per_study):
+
+        dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')
 
         def train_set_transform(batch):
 
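The dataset entry points now take database_dir and derive dataset_path internally, and the caller-supplied transform arguments are gone. A hedged usage sketch follows; the paths, the value 5, and the model instance are placeholders, and the test_set_only branch is read from the visible "else: return test_set" context.

# Hypothetical usage of the reworked dataset API (paths and values are placeholders).
model.prepare_dataset(physionet_dir='/data/physionet.org/files', database_dir='/data/cxrmate_ed')

# Stage-1 splits; max_train_images_per_study is now required for training:
train_set, val_set, test_set = model.get_stage_1_dataset(database_dir='/data/cxrmate_ed', max_train_images_per_study=5)

# Test split only; the training-image limit may be omitted in this case:
test_set = model.get_dataset(database_dir='/data/cxrmate_ed', test_set_only=True)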
@@ -1192,7 +1148,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
             # Sort based on ViewPosition:
             batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-            batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
+            batch['images'] = [torch.stack([self.train_transforms(j) for j in i]) for i in batch['images']]
             max_size = max(i.shape[0] for i in batch['images'])
             batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
             batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
@@ -1204,7 +1160,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
             # Sort based on ViewPosition:
             batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-            batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
+            batch['images'] = [torch.stack([self.test_transforms(j) for j in i]) for i in batch['images']]
             max_size = max(i.shape[0] for i in batch['images'])
             batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
             batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
@@ -1256,138 +1212,4 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         test_set = Subset(test_set, indices)
 
         return train_set, val_set, test_set
-
-    def prepare_index_value_feats(self, table, batch):
-
-        index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
-        index_value_columns = [f'{table}_{i}' for i in index_value_columns] if table != 'mimic_cxr_2_0_0_metadata' else index_value_columns
-
-        # Map to indices with lookup table:
-        if 'index_columns' in self.tables[table]:
-            for i in self.tables[table]['index_columns']:
-                k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
-                batch[k] = [
-                    [self.luts[table][i][str(k)] if k is not None else None for k in j] if j is not None else None for j in batch[k]
-                ]
-
-        batch_index_value_feats_list = []
-        batch_token_type_ids_list = []
-        batch_time_deltas_list = []
-
-        for batch_idx in range(len(batch['study_id'])):
-
-            if any([batch[k][batch_idx] for k in index_value_columns]):
-
-                num_rows = [len(batch[i][batch_idx]) for i in index_value_columns]
-                assert all(x == num_rows[0] for x in num_rows)
-                num_rows = num_rows[0]
-
-                # The y-index and the datetime for each group:
-                if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
-                    y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
-                    datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
-                    assert len(set(y_indices)) == len(datetime)
-                else:
-                    y_indices = [0] * num_rows
-                    datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
-
-                time_deltas = torch.tensor([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])[:, None]
-
-                tensor = torch.zeros(max(y_indices) + 1, self.luts[table]['total'])
-
-                # Index columns to feats:
-                if 'index_columns' in self.tables[table]:
-
-                    for i in self.tables[table]['index_columns']:
-                        k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
-                        y_indices_column = [y_idx for y_idx, x_idx in zip(y_indices, batch[k][batch_idx]) if x_idx is not None]
-                        x_indices_column = [x_idx for x_idx in batch[k][batch_idx] if x_idx is not None]
-
-                        tensor[y_indices_column, x_indices_column] = 1.0
-
-                if 'value_columns' in self.tables[table]:
-                    for i in self.tables[table]['value_columns']:
-
-                        k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
-                        y_indices_column = [y_idx for y_idx, value in zip(y_indices, batch[k][batch_idx]) if value is not None]
-                        x_indices_column = [self.luts[table][i] for value in batch[k][batch_idx] if value is not None]
-                        values = [value for value in batch[k][batch_idx] if value is not None]
-
-                        tensor[y_indices_column, x_indices_column] = torch.tensor(values, dtype=tensor.dtype)
-                        assert not torch.isnan(tensor).any()
-            else:
-                tensor = torch.empty(0, self.luts[table]['total'])
-                time_deltas = torch.empty(0, 1)
-
-            batch_index_value_feats_list.append(tensor)
-            batch_token_type_ids_list.append(torch.full(
-                    [tensor.shape[0]],
-                    self.token_type_to_token_type_id[table],
-                    dtype=torch.long,
-                )
-            )
-            batch_time_deltas_list.append(time_deltas)
-
-            assert tensor.shape[0] == batch_token_type_ids_list[-1].shape[0]
-            assert tensor.shape[0] == time_deltas.shape[0]
-
-        batch_index_value_feats = torch.nn.utils.rnn.pad_sequence(batch_index_value_feats_list, batch_first=True, padding_value=-1)  # Pad value of -1 is not ideal. Need to use something else.
-        batch_token_type_ids = torch.nn.utils.rnn.pad_sequence(batch_token_type_ids_list, batch_first=True, padding_value=0)
-        batch_time_deltas = torch.nn.utils.rnn.pad_sequence(batch_time_deltas_list, batch_first=True, padding_value=0)
-
-        batch_mask = (batch_index_value_feats != -1).any(dim=-1).int()
-
-        return batch_index_value_feats, batch_token_type_ids, batch_time_deltas, batch_mask
-
-    def prepare_text_prompt(self, table, column, batch):
-
-        key = f'{table}_{column}' if not table == 'mimic_cxr_sectioned' else column
-
-        batch_text_list = []
-        batch_time_deltas_list = []
-
-        for batch_idx in range(len(batch['study_id'])):
-            if batch[key][batch_idx]:
-
-                num_rows = len(batch[key][batch_idx])
-
-                # The y-index and the datetime for each group:
-                if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
-                    y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
-                    datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
-                    assert len(set(y_indices)) == len(datetime)
-                else:
-                    y_indices = [0] * num_rows
-                    datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
-
-                # Remove None values:
-                text_rows = batch[key][batch_idx] if isinstance(batch[key][batch_idx], list) else [batch[key][batch_idx]]
-                y_indices = [i for i, j in zip(y_indices, text_rows) if j is not None]
-                text_rows = [i for i in text_rows if i is not None]
-                datetime = [datetime[i] for i in set(y_indices)]
-                if text_rows:
-
-                    # Those in the same group (or those with the same y-index) get joined as the same string:
-                    batch_text_list.append([', '.join([text_rows[j] for j in range(len(y_indices)) if y_indices[j] == k]) + '.' for k in set(y_indices)])
-                    batch_time_deltas_list.append([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])
-
-                    assert len(batch_time_deltas_list[-1]) == len(batch_text_list[-1])
-                else:
-                    batch_text_list.append([])
-                    batch_time_deltas_list.append([])
-            else:
-                batch_text_list.append([])
-                batch_time_deltas_list.append([])
-
-        return batch_text_list, batch_time_deltas_list
-
-    @staticmethod
-    def collate_fn(batch):
-        keys = set().union(*(d.keys() for d in batch))
-        batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
-        batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
-        return batch
-
-    @staticmethod
-    def prepare_dataset(physionet_dir: str, database_dir: str):
-        prepare_dataset(physionet_dir=physionet_dir, database_dir=database_dir)
+
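Finally, a smoke-test sketch for exercising the re-uploaded remote-code files end to end. The repository id and the AutoModel mapping are assumptions: the checkpoint is taken to be 'aehrc/cxrmate-ed' and is assumed to map its architecture to CXRMateEDModel via auto_map in config.json.

# Hedged end-to-end load sketch; adjust the checkpoint id to the actual repository.
import transformers

ckpt = 'aehrc/cxrmate-ed'  # assumed repository id
model = transformers.AutoModel.from_pretrained(ckpt, trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
print(type(model).__name__)  # expected: CXRMateEDModel (defined in modelling_cxrmate_ed.py)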