anicolson committed (verified)
Commit 688909e · Parent: e1b274e

Upload model
config.json CHANGED
@@ -32,11 +32,19 @@
     "vocab_size": 30000
   },
   "encoder": {
-    "_name_or_path": "",
-    "architectures": null,
-    "model_type": "vit",
+    "_name_or_path": "aehrc/uniformer_base_tl_384",
+    "architectures": [
+      "UniFormerModel"
+    ],
+    "auto_map": {
+      "AutoConfig": "aehrc/uniformer_base_tl_384--configuration_uniformer.UniFormerWithProjectionHeadConfig",
+      "AutoModel": "aehrc/uniformer_base_tl_384--modelling_uniformer.UniFormerModel"
+    },
+    "init_value": 1e-06,
+    "layer_scale": false,
+    "model_type": "uniformer",
     "projection_size": 768,
-    "torch_dtype": null
+    "torch_dtype": "float32"
   },
   "is_encoder_decoder": false,
   "model_type": "cxrmate-ed",
configuration_cxrmate_ed.py CHANGED
@@ -2,8 +2,6 @@ import transformers
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-from .configuration_uniformer import UniFormerWithProjectionHeadConfig
-
 logger = logging.get_logger(__name__)
 
 
@@ -40,9 +38,10 @@ class CXRMateEDConfig(PretrainedConfig):
 
 
         if 'encoder' not in kwargs:
-            self.encoder = UniFormerWithProjectionHeadConfig.from_pretrained(
+            self.encoder = transformers.AutoConfig.from_pretrained(
                 'aehrc/uniformer_base_tl_384',
                 projection_size=768,
+                trust_remote_code=True,
             )
         else:
             self.encoder = kwargs.pop("encoder")
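
The change above replaces the hard import of UniFormerWithProjectionHeadConfig with transformers.AutoConfig.from_pretrained(..., trust_remote_code=True), so the encoder sub-config is fetched dynamically whenever a CXRMateEDConfig is built without an explicit encoder. A hedged sketch of the end-user effect, assuming this repository is published under a hub id such as aehrc/cxrmate-ed (the id is inferred, not stated in the diff):

```python
import transformers

# Hub id 'aehrc/cxrmate-ed' is an assumption for illustration.
config = transformers.AutoConfig.from_pretrained('aehrc/cxrmate-ed', trust_remote_code=True)

print(config.model_type)          # 'cxrmate-ed'
print(config.encoder.model_type)  # 'uniformer' after this commit (was 'vit')
```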
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3479bf42a0f90e144362b0785b0fe9a11078562f61217230e4340a5519e56f48
-size 789958760
+oid sha256:38661d70c87174cf130d3b60c278cd5a07491742aed00eac34b2bca5d795d564
+size 789964216
modelling_cxrmate_ed.py CHANGED
@@ -12,13 +12,12 @@ from torch.utils.data import Subset
 from torchvision.io import decode_image
 from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
 from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_outputs import Seq2SeqLMOutput
+from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 
 from .configuration_cxrmate_ed import CXRMateEDConfig
 from .dataset import PriorsDataset
-from .modelling_uniformer import MultiUniFormerWithProjectionHead
 from .prepare_dataset import prepare_dataset
 from .utils import compute_time_delta
 
@@ -46,6 +45,70 @@ class FNNEncoder(torch.nn.Module):
         return self.down_proj(self.act_fn(self.up_proj(x)))
 
 
+class ProjectionHead(torch.nn.Module):
+
+    def __init__(self, input_size, hidden_size) -> None:
+        super().__init__()
+
+        # Layer normalisation before projection:
+        self.layer_norm = torch.nn.LayerNorm(input_size, eps=1e-6)
+
+        # No bias as following layer normalisation with bias:
+        self.projection = torch.nn.Linear(input_size, hidden_size, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.layer_norm(x)
+        x = self.projection(x)
+        return x
+
+
+class CXRStudyImagesEncoder(torch.nn.Module):
+    def __init__(self, encoder, decoder_config):
+        super().__init__()
+
+        self.encoder = encoder
+        self.config = encoder.config
+        self.adapter = ProjectionHead(self.config.embed_dim[-1], decoder_config.hidden_size)
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ModelOutput]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Flatten the batch and study_id dimensions:
+        assert len(pixel_values.shape) == 5, 'pixel_values must be B, S, C, H, W, where S is the max number of images for a study in the batch.'
+        last_hidden_state = self.encoder(pixel_values.view(-1, *pixel_values.shape[2:])).last_hidden_state
+
+        # Flatten h x w:
+        last_hidden_state = torch.flatten(last_hidden_state, 2) if last_hidden_state.dim() > 3 else last_hidden_state
+
+        # Project the features for each spatial position to the decoder's hidden size using the adapter network:
+        last_hidden_state = self.adapter(last_hidden_state)
+
+        # Concatenate the features for each chest X-ray:
+        last_hidden_state = last_hidden_state.view(pixel_values.shape[0], -1, last_hidden_state.shape[-1])
+
+        # Derive the attention mask from the pixel values:
+        mask = (pixel_values[:, :, 0, 0, 0] != 0.0)[:, :, None]
+        attention_mask = torch.ones(
+            [last_hidden_state.shape[0], pixel_values.shape[1], last_hidden_state.shape[1] // pixel_values.shape[1]],
+            dtype=torch.long,
+            device=mask.device,
+        )
+        attention_mask = attention_mask * mask
+        attention_mask = attention_mask.view(attention_mask.shape[0], -1)
+
+        if not return_dict:
+            return last_hidden_state
+
+        return ModelOutput(last_hidden_state=last_hidden_state, attention_mask=attention_mask)
+
+
+
 class CXRMateEDModel(VisionEncoderDecoderModel):
 
     config_class = CXRMateEDConfig
@@ -77,14 +140,18 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
 
         # Encoder:
         if encoder is None:
-            encoder = MultiUniFormerWithProjectionHead(config=config.encoder)
+            encoder = transformers.AutoModel.from_pretrained(
+                'aehrc/uniformer_base_tl_384',
+                config=config.encoder,
+                trust_remote_code=True,
+            )
 
         # Decoder:
         if decoder is None:
             assert not config.decoder.add_cross_attention
             decoder = transformers.LlamaForCausalLM(config=config.decoder)
 
-        self.encoder = encoder
+        self.encoder = CXRStudyImagesEncoder(encoder, self.config.decoder)
         self.decoder = decoder
 
         if self.encoder.config.to_dict() != self.config.encoder.to_dict():
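
The new CXRStudyImagesEncoder introduced above expects pixel_values of shape (B, S, C, H, W), where S is the maximum number of images per study in the batch and missing images are zero-padded; the attention mask is derived from that padding so padded images contribute no tokens. A small, self-contained sketch of just the masking logic from the forward pass (shapes and values are illustrative, no backbone needed):

```python
import torch

# Illustrative shapes: 2 studies, up to 3 images each, 49 tokens per image after flattening h x w.
B, S, N = 2, 3, 49
pixel_values = torch.randn(B, S, 3, 384, 384)
pixel_values[1, 2] = 0.0  # second study has only two images; the third slot is zero-padding

# Same derivation as in CXRStudyImagesEncoder.forward:
mask = (pixel_values[:, :, 0, 0, 0] != 0.0)[:, :, None]               # (B, S, 1)
attention_mask = torch.ones(B, S, N, dtype=torch.long, device=mask.device)
attention_mask = (attention_mask * mask).view(B, -1)                  # (B, S * N)

assert attention_mask[1, 2 * N:].sum() == 0  # the padded image is masked out
```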