Upload model
- config.json +12 -4
- configuration_cxrmate_ed.py +2 -3
- model.safetensors +2 -2
- modelling_cxrmate_ed.py +71 -4
config.json
CHANGED
@@ -32,11 +32,19 @@
     "vocab_size": 30000
   },
   "encoder": {
-    "_name_or_path": "",
-    "architectures":
+    "_name_or_path": "aehrc/uniformer_base_tl_384",
+    "architectures": [
+      "UniFormerModel"
+    ],
+    "auto_map": {
+      "AutoConfig": "aehrc/uniformer_base_tl_384--configuration_uniformer.UniFormerWithProjectionHeadConfig",
+      "AutoModel": "aehrc/uniformer_base_tl_384--modelling_uniformer.UniFormerModel"
+    },
+    "init_value": 1e-06,
+    "layer_scale": false,
+    "model_type": "uniformer",
     "projection_size": 768,
-    "torch_dtype":
+    "torch_dtype": "float32"
   },
   "is_encoder_decoder": false,
   "model_type": "cxrmate-ed",
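The new encoder entry pins the full UniFormer configuration inline and adds an auto_map that points AutoConfig and AutoModel at code hosted in the aehrc/uniformer_base_tl_384 repository, so resolving it requires trust_remote_code. A minimal sketch of resolving the encoder config on its own, assuming the Hub repository is reachable:

import transformers

# auto_map resolves to UniFormerWithProjectionHeadConfig defined in the
# aehrc/uniformer_base_tl_384 repository, hence trust_remote_code=True.
encoder_config = transformers.AutoConfig.from_pretrained(
    'aehrc/uniformer_base_tl_384',
    projection_size=768,
    trust_remote_code=True,
)

print(encoder_config.model_type)       # 'uniformer'
print(encoder_config.projection_size)  # 768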
configuration_cxrmate_ed.py
CHANGED
@@ -2,8 +2,6 @@ import transformers
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-from .configuration_uniformer import UniFormerWithProjectionHeadConfig
-
 logger = logging.get_logger(__name__)
 
 
@@ -40,9 +38,10 @@ class CXRMateEDConfig(PretrainedConfig):
 
 
         if 'encoder' not in kwargs:
-            self.encoder =
+            self.encoder = transformers.AutoConfig.from_pretrained(
                 'aehrc/uniformer_base_tl_384',
                 projection_size=768,
+                trust_remote_code=True,
             )
         else:
             self.encoder = kwargs.pop("encoder")
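With the local configuration_uniformer import gone, CXRMateEDConfig now builds its default encoder config from the Hub at construction time, while a pre-built config can still be supplied through the encoder keyword. A sketch of the two paths, assuming the class is imported from this repository's configuration_cxrmate_ed module and that the decoder default (not shown in this hunk) resolves without arguments:

import transformers

from configuration_cxrmate_ed import CXRMateEDConfig  # import path is an assumption

# Default path: the encoder config is fetched from aehrc/uniformer_base_tl_384 via AutoConfig.
config = CXRMateEDConfig()

# Explicit path: passing an encoder config skips the Hub lookup in __init__.
encoder_config = transformers.AutoConfig.from_pretrained(
    'aehrc/uniformer_base_tl_384',
    projection_size=768,
    trust_remote_code=True,
)
config = CXRMateEDConfig(encoder=encoder_config)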
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:38661d70c87174cf130d3b60c278cd5a07491742aed00eac34b2bca5d795d564
+size 789964216
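The safetensors entry is a Git LFS pointer, so only the content hash and byte size change here; the weights themselves live in LFS storage. A minimal sketch of checking a downloaded copy against the new pointer, assuming model.safetensors sits in the current directory:

import hashlib
import os

path = 'model.safetensors'  # hypothetical local path to the downloaded weights

# Stream the file so the ~790 MB checkpoint is not held in memory at once.
digest = hashlib.sha256()
with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):
        digest.update(chunk)

assert os.path.getsize(path) == 789964216
assert digest.hexdigest() == '38661d70c87174cf130d3b60c278cd5a07491742aed00eac34b2bca5d795d564'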
modelling_cxrmate_ed.py
CHANGED
@@ -12,13 +12,12 @@ from torch.utils.data import Subset
 from torchvision.io import decode_image
 from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
 from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_outputs import Seq2SeqLMOutput
+from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 
 from .configuration_cxrmate_ed import CXRMateEDConfig
 from .dataset import PriorsDataset
-from .modelling_uniformer import MultiUniFormerWithProjectionHead
 from .prepare_dataset import prepare_dataset
 from .utils import compute_time_delta
 
@@ -46,6 +45,70 @@ class FNNEncoder(torch.nn.Module):
         return self.down_proj(self.act_fn(self.up_proj(x)))
 
 
+class ProjectionHead(torch.nn.Module):
+
+    def __init__(self, input_size, hidden_size) -> None:
+        super().__init__()
+
+        # Layer normalisation before projection:
+        self.layer_norm = torch.nn.LayerNorm(input_size, eps=1e-6)
+
+        # No bias as following layer normalisation with bias:
+        self.projection = torch.nn.Linear(input_size, hidden_size, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.layer_norm(x)
+        x = self.projection(x)
+        return x
+
+
+class CXRStudyImagesEncoder(torch.nn.Module):
+    def __init__(self, encoder, decoder_config):
+        super().__init__()
+
+        self.encoder = encoder
+        self.config = encoder.config
+        self.adapter = ProjectionHead(self.config.embed_dim[-1], decoder_config.hidden_size)
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ModelOutput]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Flatten the batch and study_id dimensions:
+        assert len(pixel_values.shape) == 5, 'pixel_values must be B, S, C, H, W, where S is the max number of images for a study in the batch.'
+        last_hidden_state = self.encoder(pixel_values.view(-1, *pixel_values.shape[2:])).last_hidden_state
+
+        # Flatten h x w:
+        last_hidden_state = torch.flatten(last_hidden_state, 2) if last_hidden_state.dim() > 3 else last_hidden_state
+
+        # Project the features for each spatial position to the decoder's hidden size using the adapter network:
+        last_hidden_state = self.adapter(last_hidden_state)
+
+        # Concatenate the features for each chest X-ray:
+        last_hidden_state = last_hidden_state.view(pixel_values.shape[0], -1, last_hidden_state.shape[-1])
+
+        # Derive the attention mask from the pixel values:
+        mask = (pixel_values[:, :, 0, 0, 0] != 0.0)[:, :, None]
+        attention_mask = torch.ones(
+            [last_hidden_state.shape[0], pixel_values.shape[1], last_hidden_state.shape[1] // pixel_values.shape[1]],
+            dtype=torch.long,
+            device=mask.device,
+        )
+        attention_mask = attention_mask * mask
+        attention_mask = attention_mask.view(attention_mask.shape[0], -1)
+
+        if not return_dict:
+            return last_hidden_state
+
+        return ModelOutput(last_hidden_state=last_hidden_state, attention_mask=attention_mask)
+
+
+
 class CXRMateEDModel(VisionEncoderDecoderModel):
 
     config_class = CXRMateEDConfig
@@ -77,14 +140,18 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
 
         # Encoder:
         if encoder is None:
-            encoder =
+            encoder = transformers.AutoModel.from_pretrained(
+                'aehrc/uniformer_base_tl_384',
+                config=config.encoder,
+                trust_remote_code=True,
+            )
 
         # Decoder:
         if decoder is None:
             assert not config.decoder.add_cross_attention
             decoder = transformers.LlamaForCausalLM(config=config.decoder)
 
-        self.encoder = encoder
+        self.encoder = CXRStudyImagesEncoder(encoder, self.config.decoder)
         self.decoder = decoder
 
         if self.encoder.config.to_dict() != self.config.encoder.to_dict():
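The net effect of these changes is that the UniFormer encoder is loaded through AutoModel with trust_remote_code and wrapped in CXRStudyImagesEncoder, which expects a 5-D pixel_values tensor of shape (batch, study images, channels, height, width) and treats all-zero images as padding when building the attention mask. A toy sketch of the expected shapes, using a hypothetical stub in place of the real UniFormer encoder (DummyEncoder, its embed_dim values, and the 144-token output are placeholders, and the import path is an assumption):

import torch
from types import SimpleNamespace
from transformers.modeling_outputs import BaseModelOutput

from modelling_cxrmate_ed import CXRStudyImagesEncoder  # import path is an assumption


class DummyEncoder(torch.nn.Module):
    # Stand-in for the UniFormer encoder: returns (N, L, embed_dim) features.

    def __init__(self, embed_dim=512, num_tokens=144):
        super().__init__()
        self.config = SimpleNamespace(embed_dim=[64, 128, 320, embed_dim])
        self.num_tokens = num_tokens

    def forward(self, pixel_values):
        features = torch.randn(pixel_values.shape[0], self.num_tokens, self.config.embed_dim[-1])
        return BaseModelOutput(last_hidden_state=features)


decoder_config = SimpleNamespace(hidden_size=768)  # stand-in for config.decoder
encoder = CXRStudyImagesEncoder(DummyEncoder(), decoder_config)

# Two studies with up to three images each; the last image of the second study is padding.
pixel_values = torch.rand(2, 3, 3, 384, 384)
pixel_values[1, 2] = 0.0

out = encoder(pixel_values, return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([2, 432, 768]): 3 images x 144 tokens, projected to 768
print(out.attention_mask.shape)     # torch.Size([2, 432]); positions of the padded image are 0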