anicolson committed on
Commit: 43d63f0
Parent: 8d6b94b

Upload model

Files changed (2)
  1. configuration_cxrmate_ed.py +0 -164
  2. modelling_cxrmate_ed.py +267 -445
configuration_cxrmate_ed.py CHANGED
@@ -46,167 +46,3 @@ class CXRMateEDConfig(transformers.PretrainedConfig):
46
  text_config = CONFIG_MAPPING[text_config['model_type']](**text_config)
47
 
48
  self.text_config = text_config
49
-
50
- # class CXRMateEDConfig(transformers.PretrainedConfig):
51
-
52
- # model_type = 'cxrmate-ed'
53
-
54
- # # def __init__(
55
- # # self,
56
- # # index_value_encoder_intermediate_size: int = 2048,
57
- # # include_time_delta: bool = True,
58
- # # time_delta_monotonic_inversion: bool = True,
59
- # # add_time_deltas: bool = True,
60
- # # history: int = 0,
61
- # # tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
62
- # # prompt_report_sections_filter: list = ['indication', 'history'],
63
- # # pad_token_id: int = 4,
64
- # # **kwargs: Any,
65
- # # ) -> None:
66
- # # super().__init__(**kwargs)
67
- # # self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
68
- # # self.include_time_delta = include_time_delta
69
- # # self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
70
- # # self.add_time_deltas = add_time_deltas
71
- # # self.history = history
72
- # # self.tables_filter = tables_filter
73
- # # self.prompt_report_sections_filter = prompt_report_sections_filter
74
- # # self.pad_token_id = pad_token_id
75
-
76
- # # self.hidden_size = self.text_config.hidden_size
77
-
78
- # def __init__(
79
- # self,
80
- # vision_config=None,
81
- # text_config=None,
82
- # # ignore_index=-100,
83
- # # image_token_index=32000,
84
- # # projector_hidden_act="gelu",
85
- # # vision_feature_select_strategy="default",
86
- # # vision_feature_layer=-2,
87
- # # image_seq_length=576,
88
- # index_value_encoder_intermediate_size: int = 2048,
89
- # include_time_delta: bool = True,
90
- # time_delta_monotonic_inversion: bool = True,
91
- # add_time_deltas: bool = True,
92
- # history: int = 0,
93
- # tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
94
- # prompt_report_sections_filter: list = ['indication', 'history'],
95
- # pad_token_id: int = 4,
96
- # **kwargs,
97
- # ):
98
- # transformers.PretrainedConfig.__init__(self, **kwargs)
99
-
100
- # self.vision_config = vision_config
101
- # self.text_config = text_config
102
- # self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
103
- # self.include_time_delta = include_time_delta
104
- # self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
105
- # self.add_time_deltas = add_time_deltas
106
- # self.history = history
107
- # self.tables_filter = tables_filter
108
- # self.prompt_report_sections_filter = prompt_report_sections_filter
109
- # self.pad_token_id = pad_token_id
110
-
111
- # self.ignore_index = ignore_index
112
- # self.image_token_index = image_token_index
113
- # self.projector_hidden_act = projector_hidden_act
114
- # self.image_seq_length = image_seq_length
115
-
116
- # if vision_feature_select_strategy not in ["default", "full"]:
117
- # raise ValueError(
118
- # "vision_feature_select_strategy should be one of 'default', 'full'."
119
- # f"Got: {vision_feature_select_strategy}"
120
- # )
121
-
122
- # self.vision_feature_select_strategy = vision_feature_select_strategy
123
- # self.vision_feature_layer = vision_feature_layer
124
-
125
- # if isinstance(vision_config, dict):
126
- # vision_config["model_type"] = (
127
- # vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
128
- # )
129
- # vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
130
- # elif vision_config is None:
131
- # vision_config = CONFIG_MAPPING["clip_vision_model"](
132
- # intermediate_size=4096,
133
- # hidden_size=1024,
134
- # patch_size=14,
135
- # image_size=336,
136
- # num_hidden_layers=24,
137
- # num_attention_heads=16,
138
- # vocab_size=32000,
139
- # projection_dim=768,
140
- # )
141
-
142
-
143
- # if isinstance(text_config, dict):
144
- # text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
145
- # text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
146
- # elif text_config is None:
147
- # text_config = CONFIG_MAPPING["llama"]()
148
-
149
- # super().__init__(**kwargs)
150
-
151
-
152
- # import transformers
153
- # from transformers.configuration_utils import PretrainedConfig
154
- # from transformers.utils import logging
155
-
156
- # logger = logging.get_logger(__name__)
157
-
158
-
159
- # class CXRMateEDConfig(PretrainedConfig):
160
-
161
- # model_type = "cxrmate-ed"
162
-
163
- # def __init__(self, **kwargs):
164
- # super().__init__(**kwargs)
165
-
166
- # if 'decoder' not in kwargs:
167
-
168
- # self.decoder = transformers.LlamaConfig(
169
- # vocab_size=30000,
170
- # hidden_size=768,
171
- # intermediate_size=3072,
172
- # num_attention_heads=12,
173
- # num_hidden_layers=6,
174
- # max_position_embeddings=2048,
175
- # )
176
- # self.decoder.is_decoder = True
177
-
178
- # self.decoder.index_value_encoder_intermediate_size = 2048
179
- # self.decoder.include_time_delta = True
180
- # self.decoder.time_delta_monotonic_inversion = True
181
- # self.decoder.add_time_deltas = True
182
- # self.decoder.history = 0
183
- # self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"]
184
- # self.decoder.prompt_report_sections_filter = ["indication", "history"]
185
- # self.decoder.pad_token_id = 4
186
-
187
- # else:
188
- # self.decoder = kwargs.pop("decoder")
189
-
190
-
191
- # if 'encoder' not in kwargs:
192
- # self.encoder = transformers.AutoConfig.from_pretrained(
193
- # 'aehrc/uniformer_base_tl_384',
194
- # projection_size=768,
195
- # trust_remote_code=True,
196
- # )
197
- # else:
198
- # self.encoder = kwargs.pop("encoder")
199
-
200
-
201
- # self.is_encoder_decoder = True
202
-
203
- # @classmethod
204
- # def from_encoder_decoder_configs(
205
- # cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
206
- # ) -> PretrainedConfig:
207
-
208
- # logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
209
- # decoder_config.is_decoder = True
210
- # decoder_config.add_cross_attention = True
211
-
212
- # return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
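The two context lines retained in this file (old lines 46-48) rely on transformers' CONFIG_MAPPING to rebuild a nested sub-config from a plain dict. Below is a minimal sketch of that pattern; the 'llama' sub-config and its sizes are illustrative values, not taken from this commit.

    # Sketch only: CONFIG_MAPPING maps a `model_type` string to its config class, which is
    # how CXRMateEDConfig turns a serialised text_config dict back into a config object.
    from transformers import CONFIG_MAPPING

    text_config = {"model_type": "llama", "hidden_size": 768, "num_hidden_layers": 6}
    text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
    print(type(text_config).__name__)  # LlamaConfig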
 
modelling_cxrmate_ed.py CHANGED
@@ -8,13 +8,13 @@ import datasets
8
  import torch
9
  import transformers
10
  from huggingface_hub import hf_hub_download
 
11
  from torch.nn import CrossEntropyLoss
12
  from torch.utils.data import Subset
13
  from torchvision.io import decode_image
14
- from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
15
- from transformers.configuration_utils import PretrainedConfig
16
  from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
17
- from transformers.modeling_utils import PreTrainedModel
18
  from transformers.utils import check_min_version, logging
19
 
20
  from .configuration_cxrmate_ed import CXRMateEDConfig
@@ -187,162 +187,39 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
187
 
188
  self.inf_time_delta_value = self.time_delta_map(float('inf'))
189
 
190
- self.post_init()
191
-
192
- # @classmethod
193
- # def from_encoder_decoder_pretrained(
194
- # cls,
195
- # encoder_pretrained_model_name_or_path: str = None,
196
- # decoder_pretrained_model_name_or_path: str = None,
197
- # *model_args,
198
- # **kwargs,
199
- # ) -> PreTrainedModel:
200
- # r"""
201
- # Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
202
- # checkpoints.
203
-
204
-
205
- # The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
206
- # the model, you need to first set it back in training mode with `model.train()`.
207
-
208
- # Params:
209
- # encoder_pretrained_model_name_or_path (`str`, *optional*):
210
- # Information necessary to initiate the image encoder. Can be either:
211
-
212
- # - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
213
- # example is `google/vit-base-patch16-224-in21k`.
214
- # - A path to a *directory* containing model weights saved using
215
- # [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
216
- # - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
217
- # this case, `from_tf` should be set to `True` and a configuration object should be provided as
218
- # `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
219
- # PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
220
-
221
- # decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
222
- # Information necessary to initiate the text decoder. Can be either:
223
-
224
- # - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
225
- # - A path to a *directory* containing model weights saved using
226
- # [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
227
- # - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
228
- # this case, `from_tf` should be set to `True` and a configuration object should be provided as
229
- # `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
230
- # PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
231
-
232
- # model_args (remaining positional arguments, *optional*):
233
- # All remaning positional arguments will be passed to the underlying model's `__init__` method.
234
-
235
- # kwargs (remaining dictionary of keyword arguments, *optional*):
236
- # Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
237
- # `output_attentions=True`).
238
-
239
- # - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
240
- # - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
241
- # - To update the parent model configuration, do not use a prefix for each configuration parameter.
242
-
243
- # Behaves differently depending on whether a `config` is provided or automatically loaded.
244
-
245
- # Example:
246
-
247
- # ```python
248
- # >>> from transformers import VisionEncoderDecoderModel
249
-
250
- # >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
251
- # >>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
252
- # ... "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
253
- # ... )
254
- # >>> # saving model after fine-tuning
255
- # >>> model.save_pretrained("./vit-bert")
256
- # >>> # load fine-tuned model
257
- # >>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
258
- # ```"""
259
-
260
- # kwargs_encoder = {
261
- # argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
262
- # }
263
-
264
- # kwargs_decoder = {
265
- # argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
266
- # }
267
-
268
- # # remove encoder, decoder kwargs from kwargs
269
- # for key in kwargs_encoder.keys():
270
- # del kwargs["encoder_" + key]
271
- # for key in kwargs_decoder.keys():
272
- # del kwargs["decoder_" + key]
273
-
274
- # # Load and initialize the encoder and decoder
275
- # # The distinction between encoder and decoder at the model level is made
276
- # # by the value of the flag `is_decoder` that we need to set correctly.
277
- # encoder = kwargs_encoder.pop("model", None)
278
- # if encoder is None:
279
- # if encoder_pretrained_model_name_or_path is None:
280
- # raise ValueError(
281
- # "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
282
- # "to be defined."
283
- # )
284
-
285
- # if "config" not in kwargs_encoder:
286
- # encoder_config, kwargs_encoder = transformers.AutoConfig.from_pretrained(
287
- # encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
288
- # )
289
-
290
- # if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
291
- # logger.info(
292
- # f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
293
- # "from a decoder model. Cross-attention and casual mask are disabled."
294
- # )
295
- # encoder_config.is_decoder = False
296
- # encoder_config.add_cross_attention = False
297
-
298
- # kwargs_encoder["config"] = encoder_config
299
-
300
- # encoder = transformers.AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
301
-
302
- # decoder = kwargs_decoder.pop("model", None)
303
- # if decoder is None:
304
- # if decoder_pretrained_model_name_or_path is None:
305
- # raise ValueError(
306
- # "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
307
- # "to be defined."
308
- # )
309
-
310
- # if "config" not in kwargs_decoder:
311
- # decoder_config, kwargs_decoder = transformers.AutoConfig.from_pretrained(
312
- # decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
313
- # )
314
-
315
- # if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
316
- # logger.info(
317
- # f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
318
- # f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
319
- # f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
320
- # )
321
- # decoder_config.is_decoder = True
322
- # decoder_config.add_cross_attention = False
323
-
324
- # kwargs_decoder["config"] = decoder_config
325
-
326
- # if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
327
- # logger.warning(
328
- # f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
329
- # f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
330
- # "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
331
- # "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
332
- # "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
333
- # )
334
-
335
- # decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
336
-
337
- # # instantiate config with corresponding kwargs
338
- # config = CXRMateEDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
339
-
340
- # # make sure input & output embeddings is not tied
341
- # config.tie_word_embeddings = False
342
 
343
- # config.is_encoder_decoder = False
344
-
345
- # return cls(encoder=encoder, decoder=decoder, config=config)
346
 
347
  def forward(
348
  self,
@@ -712,80 +589,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
712
  sections[j].append(section_string)
713
 
714
  return tuple(sections.values())
715
-
716
- def tokenize_text_prompt(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
717
- """
718
- Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
719
- Time deltas for the input_ids are also prepared here.
720
-
721
- Argument/s:
722
- tokenizer - Hugging Face tokenizer.
723
-
724
- Returns:
725
- ed - dictionary containing the input_ids, token_type_ids, attention_mask and time_deltas for the ED module columns.
726
- cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
727
- """
728
-
729
- batch_size = len(kwargs['study_id'])
730
-
731
- tokenized = {
732
- 'input_ids': {i: [] for i in range(batch_size)},
733
- 'token_type_ids': {i: [] for i in range(batch_size)},
734
- 'time_delta': {i: [] for i in range(batch_size)},
735
- 'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
736
- }
737
-
738
- prompt_text_columns = [f'{k}_{j}' if k != 'mimic_cxr_sectioned' else j for k, v in self.tables.items() if 'text_columns' in v for j in (v['text_columns'] if isinstance(v['text_columns'], list) else [v['text_columns']])] + ['prior_findings', 'prior_impression']
739
-
740
- for i in prompt_text_columns:
741
- if i in kwargs:
742
- if f'{i}_time_delta' not in kwargs:
743
- kwargs[f'{i}_time_delta'] = [[self.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
744
- for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
745
- if y is not None:
746
- assert isinstance(y, list)
747
- assert isinstance(z, list)
748
- for text, time_delta in zip(y, z):
749
- if text is not None:
750
- tokenized['input_ids'][x].append(
751
- tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
752
- )
753
- tokenized['token_type_ids'][x].append(
754
- torch.full(
755
- (1, tokenized['input_ids'][x][-1].shape[-1]),
756
- self.token_type_to_token_type_id[i],
757
- dtype=torch.long,
758
- device=self.device,
759
- )
760
- )
761
- tokenized['time_delta'][x].append(
762
- torch.full(
763
- (1, tokenized['input_ids'][x][-1].shape[-1]),
764
- time_delta,
765
- dtype=torch.float32,
766
- device=self.device,
767
- )
768
- )
769
 
770
- tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
771
- tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
772
- tokenized['time_delta'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, device=self.device) for j in tokenized['time_delta'].values()]
773
-
774
- tokenized['input_ids'] = torch.nn.utils.rnn.pad_sequence(
775
- tokenized['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
776
- )[:, :, 0]
777
- tokenized['token_type_ids'] = torch.nn.utils.rnn.pad_sequence(
778
- tokenized['token_type_ids'], batch_first=True, padding_value=0,
779
- )[:, :, 0]
780
-
781
- tokenized['attention_mask'] = (tokenized['input_ids'] != tokenizer.pad_token_id).int()
782
-
783
- tokenized['time_delta'] = torch.nn.utils.rnn.pad_sequence(
784
- tokenized['time_delta'], batch_first=True, padding_value=0,
785
- )
786
-
787
- return tokenized
788
-
789
  def prepare_inputs(
790
  self,
791
  images,
@@ -914,7 +718,219 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
914
  assert inputs_embeds.shape[1] == token_type_ids.shape[1]
915
 
916
  return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids
918
  @staticmethod
919
  def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype):
920
 
@@ -983,86 +999,24 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
983
  mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0
984
 
985
  return mixed_causality_4d_attention_mask
986
-
987
- # @staticmethod
988
- # def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
989
-
990
- # prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
991
- # report_seq_len = causal_2d_attention_mask.shape[-1]
992
-
993
- # non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
994
- # causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
995
-
996
- # # Upper left of attention matrix:
997
- # upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
998
- # upper_left = upper_left * non_causal_2d_attention_mask
999
- # upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)
1000
-
1001
- # causal_mask = torch.tril(
1002
- # torch.ones(
1003
- # (
1004
- # report_seq_len,
1005
- # report_seq_len,
1006
- # ),
1007
- # dtype=torch.long,
1008
- # device=causal_2d_attention_mask.device,
1009
- # ),
1010
- # )
1011
-
1012
- # # Lower right of attention matrix:
1013
- # lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
1014
- # lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
1015
- # lower_right = lower_right * causal_mask
1016
-
1017
- # # Upper right of attention matrix:
1018
- # upper_right = torch.zeros(
1019
- # causal_2d_attention_mask.shape[0],
1020
- # 1,
1021
- # prompt_seq_len,
1022
- # report_seq_len,
1023
- # dtype=torch.long,
1024
- # device=causal_2d_attention_mask.device,
1025
- # )
1026
-
1027
- # # Lower left of attention matrix:
1028
- # lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
1029
- # lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)
1030
-
1031
- # left = torch.cat((upper_left, lower_left), dim=2)
1032
- # right = torch.cat((upper_right, lower_right), dim=2)
1033
 
1034
- # mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
1035
-
1036
- # return mixed_causality_4d_attention_mask
1037
-
1038
- # @staticmethod
1039
- # def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):
1040
-
1041
- # non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
1042
- # causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
1043
-
1044
- # mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
1045
- # return mixed_causality_4d_attention_mask
1046
 
1047
- def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
1048
- mask_value = torch.finfo(time_deltas.dtype).max if self.config.time_delta_monotonic_inversion else torch.finfo(time_deltas.dtype).min
1049
-
1050
- masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
1051
- _, col_indices = torch.sort(masked_time_deltas, descending=not self.config.time_delta_monotonic_inversion)
1052
 
1053
- num_rows, num_cols, _ = time_deltas.shape
1054
 
1055
- row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
1056
- position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
1057
- position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
1058
- position_ids.masked_fill_(attention_mask == 0, 1) # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
1059
 
1060
- return position_ids
1061
-
1062
- def get_dataset(self, dataset_path, train_transforms=None, test_transforms=None, max_train_images_per_study=None, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
1063
 
1064
- assert max_train_images_per_study is not None, 'max_train_images_per_study must be defined.'
1065
- assert test_transforms is not None, 'test_transforms must be defined.'
1066
 
1067
  def train_set_transform(batch):
1068
 
@@ -1081,7 +1035,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1081
 
1082
  # Sort based on ViewPosition:
1083
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1084
- batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
1085
  max_size = max(i.shape[0] for i in batch['images'])
1086
 
1087
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
@@ -1104,7 +1058,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1104
 
1105
  # Sort based on ViewPosition:
1106
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1107
- batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
1108
  max_size = max(i.shape[0] for i in batch['images'])
1109
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1110
  batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
@@ -1177,7 +1131,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1177
  else:
1178
  return test_set
1179
 
1180
- def get_stage_1_dataset(self, dataset_path, train_transforms, test_transforms, max_train_images_per_study):
 
 
1181
 
1182
  def train_set_transform(batch):
1183
 
@@ -1192,7 +1148,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1192
 
1193
  # Sort based on ViewPosition:
1194
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1195
- batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
1196
  max_size = max(i.shape[0] for i in batch['images'])
1197
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1198
  batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
@@ -1204,7 +1160,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1204
 
1205
  # Sort based on ViewPosition:
1206
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1207
- batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
1208
  max_size = max(i.shape[0] for i in batch['images'])
1209
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1210
  batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
@@ -1256,138 +1212,4 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1256
  test_set = Subset(test_set, indices)
1257
 
1258
  return train_set, val_set, test_set
1259
-
1260
- def prepare_index_value_feats(self, table, batch):
1261
-
1262
- index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
1263
- index_value_columns = [f'{table}_{i}' for i in index_value_columns] if table != 'mimic_cxr_2_0_0_metadata' else index_value_columns
1264
-
1265
- # Map to indices with lookup table:
1266
- if 'index_columns' in self.tables[table]:
1267
- for i in self.tables[table]['index_columns']:
1268
- k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
1269
- batch[k] = [
1270
- [self.luts[table][i][str(k)] if k is not None else None for k in j] if j is not None else None for j in batch[k]
1271
- ]
1272
-
1273
- batch_index_value_feats_list = []
1274
- batch_token_type_ids_list = []
1275
- batch_time_deltas_list = []
1276
-
1277
- for batch_idx in range(len(batch['study_id'])):
1278
-
1279
- if any([batch[k][batch_idx] for k in index_value_columns]):
1280
-
1281
- num_rows = [len(batch[i][batch_idx]) for i in index_value_columns]
1282
- assert all(x == num_rows[0] for x in num_rows)
1283
- num_rows = num_rows[0]
1284
-
1285
- # The y-index and the datetime for each group:
1286
- if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
1287
- y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
1288
- datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
1289
- assert len(set(y_indices)) == len(datetime)
1290
- else:
1291
- y_indices = [0] * num_rows
1292
- datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
1293
-
1294
- time_deltas = torch.tensor([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])[:, None]
1295
-
1296
- tensor = torch.zeros(max(y_indices) + 1, self.luts[table]['total'])
1297
-
1298
- # Index columns to feats:
1299
- if 'index_columns' in self.tables[table]:
1300
-
1301
- for i in self.tables[table]['index_columns']:
1302
- k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
1303
- y_indices_column = [y_idx for y_idx, x_idx in zip(y_indices, batch[k][batch_idx]) if x_idx is not None]
1304
- x_indices_column = [x_idx for x_idx in batch[k][batch_idx] if x_idx is not None]
1305
-
1306
- tensor[y_indices_column, x_indices_column] = 1.0
1307
-
1308
- if 'value_columns' in self.tables[table]:
1309
- for i in self.tables[table]['value_columns']:
1310
-
1311
- k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
1312
- y_indices_column = [y_idx for y_idx, value in zip(y_indices, batch[k][batch_idx]) if value is not None]
1313
- x_indices_column = [self.luts[table][i] for value in batch[k][batch_idx] if value is not None]
1314
- values = [value for value in batch[k][batch_idx] if value is not None]
1315
-
1316
- tensor[y_indices_column, x_indices_column] = torch.tensor(values, dtype=tensor.dtype)
1317
- assert not torch.isnan(tensor).any()
1318
- else:
1319
- tensor = torch.empty(0, self.luts[table]['total'])
1320
- time_deltas = torch.empty(0, 1)
1321
-
1322
- batch_index_value_feats_list.append(tensor)
1323
- batch_token_type_ids_list.append(torch.full(
1324
- [tensor.shape[0]],
1325
- self.token_type_to_token_type_id[table],
1326
- dtype=torch.long,
1327
- )
1328
- )
1329
- batch_time_deltas_list.append(time_deltas)
1330
-
1331
- assert tensor.shape[0] == batch_token_type_ids_list[-1].shape[0]
1332
- assert tensor.shape[0] == time_deltas.shape[0]
1333
-
1334
- batch_index_value_feats = torch.nn.utils.rnn.pad_sequence(batch_index_value_feats_list, batch_first=True, padding_value=-1) # Pad value of -1 is not ideal. Need to use something else.
1335
- batch_token_type_ids = torch.nn.utils.rnn.pad_sequence(batch_token_type_ids_list, batch_first=True, padding_value=0)
1336
- batch_time_deltas = torch.nn.utils.rnn.pad_sequence(batch_time_deltas_list, batch_first=True, padding_value=0)
1337
-
1338
- batch_mask = (batch_index_value_feats != -1).any(dim=-1).int()
1339
-
1340
- return batch_index_value_feats, batch_token_type_ids, batch_time_deltas, batch_mask
1341
-
1342
- def prepare_text_prompt(self, table, column, batch):
1343
-
1344
- key = f'{table}_{column}' if not table == 'mimic_cxr_sectioned' else column
1345
-
1346
- batch_text_list = []
1347
- batch_time_deltas_list = []
1348
-
1349
- for batch_idx in range(len(batch['study_id'])):
1350
- if batch[key][batch_idx]:
1351
-
1352
- num_rows = len(batch[key][batch_idx])
1353
-
1354
- # The y-index and the datetime for each group:
1355
- if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
1356
- y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
1357
- datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
1358
- assert len(set(y_indices)) == len(datetime)
1359
- else:
1360
- y_indices = [0] * num_rows
1361
- datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
1362
-
1363
- # Remove None values:
1364
- text_rows = batch[key][batch_idx] if isinstance(batch[key][batch_idx], list) else [batch[key][batch_idx]]
1365
- y_indices = [i for i, j in zip(y_indices, text_rows) if j is not None]
1366
- text_rows = [i for i in text_rows if i is not None]
1367
- datetime = [datetime[i] for i in set(y_indices)]
1368
- if text_rows:
1369
-
1370
- # Those in the same group (or those with the same y-index) get joined as the same string:
1371
- batch_text_list.append([', '.join([text_rows[j] for j in range(len(y_indices)) if y_indices[j] == k]) + '.' for k in set(y_indices)])
1372
- batch_time_deltas_list.append([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])
1373
-
1374
- assert len(batch_time_deltas_list[-1]) == len(batch_text_list[-1])
1375
- else:
1376
- batch_text_list.append([])
1377
- batch_time_deltas_list.append([])
1378
- else:
1379
- batch_text_list.append([])
1380
- batch_time_deltas_list.append([])
1381
-
1382
- return batch_text_list, batch_time_deltas_list
1383
-
1384
- @staticmethod
1385
- def collate_fn(batch):
1386
- keys = set().union(*(d.keys() for d in batch))
1387
- batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
1388
- batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
1389
- return batch
1390
-
1391
- @staticmethod
1392
- def prepare_dataset(physionet_dir: str, database_dir: str):
1393
- prepare_dataset(physionet_dir=physionet_dir, database_dir=database_dir)
 
8
  import torch
9
  import transformers
10
  from huggingface_hub import hf_hub_download
11
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12
  from torch.nn import CrossEntropyLoss
13
  from torch.utils.data import Subset
14
  from torchvision.io import decode_image
15
+ from torchvision.transforms import v2
16
+ from transformers import PreTrainedTokenizerFast
17
  from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
 
18
  from transformers.utils import check_min_version, logging
19
 
20
  from .configuration_cxrmate_ed import CXRMateEDConfig
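The newly added imports (timm's ImageNet statistics and torchvision.transforms.v2) are used further down in this commit to build the model's internal image pipelines. A small hedged sketch of that combination on a dummy single-channel tensor follows; the 384 crop size is only an example, since the real value comes from config.vision_config.image_size.

    # Sketch of the preprocessing style introduced by this commit: a single-channel image is
    # expanded to 3 channels, resized and cropped, converted to float32 in [0, 1], and
    # normalised with the ImageNet mean and standard deviation.
    import torch
    from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    from torchvision.transforms import v2

    transforms = v2.Compose(
        [
            v2.Grayscale(num_output_channels=3),
            v2.Resize(size=384, antialias=True, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(size=[384, 384]),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
    dummy_image = torch.randint(0, 256, (1, 512, 400), dtype=torch.uint8)  # 1-channel uint8 image
    print(transforms(dummy_image).shape)  # torch.Size([3, 384, 384])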
 
187
 
188
  self.inf_time_delta_value = self.time_delta_map(float('inf'))
189
 
190
+ # Image transformations:
191
+ self.train_transforms = v2.Compose(
192
+ [
193
+ v2.Grayscale(num_output_channels=3),
194
+ v2.Resize(
195
+ size=self.config.vision_config.image_size,
196
+ antialias=True,
197
+ interpolation=v2.InterpolationMode.BICUBIC,
198
+ ),
199
+ v2.RandomCrop(
200
+ size=[self.config.vision_config.image_size, self.config.vision_config.image_size],
201
+ pad_if_needed=True,
202
+ ),
203
+ v2.RandomRotation(degrees=5),
204
+ v2.ToDtype(torch.float32, scale=True),
205
+ v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
206
+ ]
207
+ )
208
+ self.test_transforms = v2.Compose(
209
+ [
210
+ v2.Grayscale(num_output_channels=3),
211
+ v2.Resize(
212
+ size=self.config.vision_config.image_size,
213
+ antialias=True,
214
+ interpolation=v2.InterpolationMode.BICUBIC,
215
+ ),
216
+ v2.CenterCrop(size=[self.config.vision_config.image_size, self.config.vision_config.image_size]),
217
+ v2.ToDtype(torch.float32, scale=True),
218
+ v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
219
+ ]
220
+ )
 
221
 
222
+ self.post_init()
 
 
223
 
224
  def forward(
225
  self,
 
589
  sections[j].append(section_string)
590
 
591
  return tuple(sections.values())
 
592
 
593
  def prepare_inputs(
594
  self,
595
  images,
 
718
  assert inputs_embeds.shape[1] == token_type_ids.shape[1]
719
 
720
  return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids
721
+
722
+ def tokenize_text_prompt(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
723
+ """
724
+ Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
725
+ Time deltas for the input_ids are also prepared here.
726
+
727
+ Argument/s:
728
+ tokenizer - Hugging Face tokenizer.
729
+
730
+ Returns:
731
+ ed - dictionary containing the input_ids, token_type_ids, attention_mask and time_deltas for the ED module columns.
732
+ cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
733
+ """
734
+
735
+ batch_size = len(kwargs['study_id'])
736
+
737
+ tokenized = {
738
+ 'input_ids': {i: [] for i in range(batch_size)},
739
+ 'token_type_ids': {i: [] for i in range(batch_size)},
740
+ 'time_delta': {i: [] for i in range(batch_size)},
741
+ 'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
742
+ }
743
+
744
+ prompt_text_columns = [f'{k}_{j}' if k != 'mimic_cxr_sectioned' else j for k, v in self.tables.items() if 'text_columns' in v for j in (v['text_columns'] if isinstance(v['text_columns'], list) else [v['text_columns']])] + ['prior_findings', 'prior_impression']
745
+
746
+ for i in prompt_text_columns:
747
+ if i in kwargs:
748
+ if f'{i}_time_delta' not in kwargs:
749
+ kwargs[f'{i}_time_delta'] = [[self.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
750
+ for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
751
+ if y is not None:
752
+ assert isinstance(y, list)
753
+ assert isinstance(z, list)
754
+ for text, time_delta in zip(y, z):
755
+ if text is not None:
756
+ tokenized['input_ids'][x].append(
757
+ tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
758
+ )
759
+ tokenized['token_type_ids'][x].append(
760
+ torch.full(
761
+ (1, tokenized['input_ids'][x][-1].shape[-1]),
762
+ self.token_type_to_token_type_id[i],
763
+ dtype=torch.long,
764
+ device=self.device,
765
+ )
766
+ )
767
+ tokenized['time_delta'][x].append(
768
+ torch.full(
769
+ (1, tokenized['input_ids'][x][-1].shape[-1]),
770
+ time_delta,
771
+ dtype=torch.float32,
772
+ device=self.device,
773
+ )
774
+ )
775
+
776
+ tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
777
+ tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
778
+ tokenized['time_delta'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, device=self.device) for j in tokenized['time_delta'].values()]
779
+
780
+ tokenized['input_ids'] = torch.nn.utils.rnn.pad_sequence(
781
+ tokenized['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
782
+ )[:, :, 0]
783
+ tokenized['token_type_ids'] = torch.nn.utils.rnn.pad_sequence(
784
+ tokenized['token_type_ids'], batch_first=True, padding_value=0,
785
+ )[:, :, 0]
786
+
787
+ tokenized['attention_mask'] = (tokenized['input_ids'] != tokenizer.pad_token_id).int()
788
+
789
+ tokenized['time_delta'] = torch.nn.utils.rnn.pad_sequence(
790
+ tokenized['time_delta'], batch_first=True, padding_value=0,
791
+ )
792
+
793
+ return tokenized
794
 
795
+ def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
796
+ mask_value = torch.finfo(time_deltas.dtype).max if self.config.time_delta_monotonic_inversion else torch.finfo(time_deltas.dtype).min
797
+
798
+ masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
799
+ _, col_indices = torch.sort(masked_time_deltas, descending=not self.config.time_delta_monotonic_inversion)
800
+
801
+ num_rows, num_cols, _ = time_deltas.shape
802
+
803
+ row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
804
+ position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
805
+ position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
806
+ position_ids.masked_fill_(attention_mask == 0, 1) # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
807
+
808
+ return position_ids
809
+
810
+ def prepare_index_value_feats(self, table, batch):
811
+
812
+ index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
813
+ index_value_columns = [f'{table}_{i}' for i in index_value_columns] if table != 'mimic_cxr_2_0_0_metadata' else index_value_columns
814
+
815
+ # Map to indices with lookup table:
816
+ if 'index_columns' in self.tables[table]:
817
+ for i in self.tables[table]['index_columns']:
818
+ k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
819
+ batch[k] = [
820
+ [self.luts[table][i][str(k)] if k is not None else None for k in j] if j is not None else None for j in batch[k]
821
+ ]
822
+
823
+ batch_index_value_feats_list = []
824
+ batch_token_type_ids_list = []
825
+ batch_time_deltas_list = []
826
+
827
+ for batch_idx in range(len(batch['study_id'])):
828
+
829
+ if any([batch[k][batch_idx] for k in index_value_columns]):
830
+
831
+ num_rows = [len(batch[i][batch_idx]) for i in index_value_columns]
832
+ assert all(x == num_rows[0] for x in num_rows)
833
+ num_rows = num_rows[0]
834
+
835
+ # The y-index and the datetime for each group:
836
+ if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
837
+ y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
838
+ datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
839
+ assert len(set(y_indices)) == len(datetime)
840
+ else:
841
+ y_indices = [0] * num_rows
842
+ datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
843
+
844
+ time_deltas = torch.tensor([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])[:, None]
845
+
846
+ tensor = torch.zeros(max(y_indices) + 1, self.luts[table]['total'])
847
+
848
+ # Index columns to feats:
849
+ if 'index_columns' in self.tables[table]:
850
+
851
+ for i in self.tables[table]['index_columns']:
852
+ k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
853
+ y_indices_column = [y_idx for y_idx, x_idx in zip(y_indices, batch[k][batch_idx]) if x_idx is not None]
854
+ x_indices_column = [x_idx for x_idx in batch[k][batch_idx] if x_idx is not None]
855
+
856
+ tensor[y_indices_column, x_indices_column] = 1.0
857
+
858
+ if 'value_columns' in self.tables[table]:
859
+ for i in self.tables[table]['value_columns']:
860
+
861
+ k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
862
+ y_indices_column = [y_idx for y_idx, value in zip(y_indices, batch[k][batch_idx]) if value is not None]
863
+ x_indices_column = [self.luts[table][i] for value in batch[k][batch_idx] if value is not None]
864
+ values = [value for value in batch[k][batch_idx] if value is not None]
865
+
866
+ tensor[y_indices_column, x_indices_column] = torch.tensor(values, dtype=tensor.dtype)
867
+ assert not torch.isnan(tensor).any()
868
+ else:
869
+ tensor = torch.empty(0, self.luts[table]['total'])
870
+ time_deltas = torch.empty(0, 1)
871
+
872
+ batch_index_value_feats_list.append(tensor)
873
+ batch_token_type_ids_list.append(torch.full(
874
+ [tensor.shape[0]],
875
+ self.token_type_to_token_type_id[table],
876
+ dtype=torch.long,
877
+ )
878
+ )
879
+ batch_time_deltas_list.append(time_deltas)
880
+
881
+ assert tensor.shape[0] == batch_token_type_ids_list[-1].shape[0]
882
+ assert tensor.shape[0] == time_deltas.shape[0]
883
+
884
+ batch_index_value_feats = torch.nn.utils.rnn.pad_sequence(batch_index_value_feats_list, batch_first=True, padding_value=-1) # Pad value of -1 is not ideal. Need to use something else.
885
+ batch_token_type_ids = torch.nn.utils.rnn.pad_sequence(batch_token_type_ids_list, batch_first=True, padding_value=0)
886
+ batch_time_deltas = torch.nn.utils.rnn.pad_sequence(batch_time_deltas_list, batch_first=True, padding_value=0)
887
+
888
+ batch_mask = (batch_index_value_feats != -1).any(dim=-1).int()
889
+
890
+ return batch_index_value_feats, batch_token_type_ids, batch_time_deltas, batch_mask
891
+
892
+ def prepare_text_prompt(self, table, column, batch):
893
+
894
+ key = f'{table}_{column}' if not table == 'mimic_cxr_sectioned' else column
895
+
896
+ batch_text_list = []
897
+ batch_time_deltas_list = []
898
+
899
+ for batch_idx in range(len(batch['study_id'])):
900
+ if batch[key][batch_idx]:
901
+
902
+ num_rows = len(batch[key][batch_idx])
903
+
904
+ # The y-index and the datetime for each group:
905
+ if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
906
+ y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
907
+ datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
908
+ assert len(set(y_indices)) == len(datetime)
909
+ else:
910
+ y_indices = [0] * num_rows
911
+ datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
912
+
913
+ # Remove None values:
914
+ text_rows = batch[key][batch_idx] if isinstance(batch[key][batch_idx], list) else [batch[key][batch_idx]]
915
+ y_indices = [i for i, j in zip(y_indices, text_rows) if j is not None]
916
+ text_rows = [i for i in text_rows if i is not None]
917
+ datetime = [datetime[i] for i in set(y_indices)]
918
+ if text_rows:
919
+
920
+ # Those in the same group (or those with the same y-index) get joined as the same string:
921
+ batch_text_list.append([', '.join([text_rows[j] for j in range(len(y_indices)) if y_indices[j] == k]) + '.' for k in set(y_indices)])
922
+ batch_time_deltas_list.append([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])
923
+
924
+ assert len(batch_time_deltas_list[-1]) == len(batch_text_list[-1])
925
+ else:
926
+ batch_text_list.append([])
927
+ batch_time_deltas_list.append([])
928
+ else:
929
+ batch_text_list.append([])
930
+ batch_time_deltas_list.append([])
931
+
932
+ return batch_text_list, batch_time_deltas_list
933
+
934
  @staticmethod
935
  def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype):
936
 
 
999
  mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0
1000
 
1001
  return mixed_causality_4d_attention_mask
 
1002
 
1003
+ @staticmethod
1004
+ def collate_fn(batch):
1005
+ keys = set().union(*(d.keys() for d in batch))
1006
+ batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
1007
+ batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
1008
+ return batch
 
 
 
 
 
 
1009
 
1010
+ @staticmethod
1011
+ def prepare_dataset(physionet_dir: str, database_dir: str):
 
 
 
1012
 
1013
+ prepare_dataset(physionet_dir=physionet_dir, database_dir=database_dir)
1014
 
1015
+ def get_dataset(self, database_dir, max_train_images_per_study=None, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
 
 
 
1016
 
1017
+ dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')
 
 
1018
 
1019
+ assert max_train_images_per_study is not None or test_set_only, 'max_train_images_per_study must be defined if training.'
 
1020
 
1021
  def train_set_transform(batch):
1022
 
 
1035
 
1036
  # Sort based on ViewPosition:
1037
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1038
+ batch['images'] = [torch.stack([self.train_transforms(j) for j in i]) for i in batch['images']]
1039
  max_size = max(i.shape[0] for i in batch['images'])
1040
 
1041
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
 
1058
 
1059
  # Sort based on ViewPosition:
1060
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1061
+ batch['images'] = [torch.stack([self.test_transforms(j) for j in i]) for i in batch['images']]
1062
  max_size = max(i.shape[0] for i in batch['images'])
1063
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1064
  batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
 
1131
  else:
1132
  return test_set
1133
 
1134
+ def get_stage_1_dataset(self, database_dir, max_train_images_per_study):
1135
+
1136
+ dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')
1137
 
1138
  def train_set_transform(batch):
1139
 
 
1148
 
1149
  # Sort based on ViewPosition:
1150
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1151
+ batch['images'] = [torch.stack([self.train_transforms(j) for j in i]) for i in batch['images']]
1152
  max_size = max(i.shape[0] for i in batch['images'])
1153
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1154
  batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
 
1160
 
1161
  # Sort based on ViewPosition:
1162
  batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1163
+ batch['images'] = [torch.stack([self.test_transforms(j) for j in i]) for i in batch['images']]
1164
  max_size = max(i.shape[0] for i in batch['images'])
1165
  batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1166
  batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
 
1212
  test_set = Subset(test_set, indices)
1213
 
1214
  return train_set, val_set, test_set
1215
+
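Taken together, the changes above move the image transforms onto the model and simplify the dataset methods: get_dataset() and get_stage_1_dataset() now take database_dir (the dataset path is derived internally) and no longer accept transform arguments. A hedged usage sketch follows; the import path, checkpoint source, directory paths and image count are placeholders, and the assumption that the non-test branch of get_dataset() returns train/validation/test splits mirrors get_stage_1_dataset() rather than being shown explicitly in this diff.

    # Sketch only: exercising the updated dataset interface after this commit.
    from modelling_cxrmate_ed import CXRMateEDModel  # assumes the repository module is importable

    model = CXRMateEDModel.from_pretrained("<repo-id-or-local-path>")  # placeholder checkpoint

    # One-off preparation of the MIMIC-CXR / MIMIC-IV-ED database (paths are placeholders):
    model.prepare_dataset(physionet_dir="/path/to/physionet", database_dir="/path/to/database")

    # Training splits; max_train_images_per_study must be set when a training set is built:
    train_set, val_set, test_set = model.get_dataset(
        database_dir="/path/to/database",  # contains mimic_iv_ed_mimic_cxr_jpg_dataset
        max_train_images_per_study=5,      # illustrative value
    )

    # Test split only; the new assert allows max_train_images_per_study to be omitted here:
    test_set = model.get_dataset(database_dir="/path/to/database", test_set_only=True)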