medical
AleksanderObuchowski commited on
Commit
1640642
·
verified ·
1 Parent(s): 5f10544

Add files using upload-large-folder tool

Browse files
Files changed (9) hide show
  1. .gitignore +10 -0
  2. .python-version +1 -0
  3. README.md +0 -3
  4. example.py +44 -0
  5. flask_app.py +57 -0
  6. medimageinsightmodel.py +239 -0
  7. pyproject.toml +29 -0
  8. requirements.txt +18 -0
  9. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.8.19
README.md CHANGED
@@ -1,3 +0,0 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
example.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Initialize classifier
2
+ from medimageinsightmodel import MedImageInsight
3
+ import base64
4
+
5
+
6
+ classifier = MedImageInsight(
7
+ model_dir="2024.09.27",
8
+ vision_model_name="medimageinsigt-v1.0.0.pt",
9
+ language_model_name="language_model.pth"
10
+ )
11
+
12
+ def read_image(image_path):
13
+ with open(image_path, "rb") as f:
14
+ return f.read()
15
+
16
+ # Load model
17
+ classifier.load_model()
18
+
19
+ import urllib.request
20
+
21
+ image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
22
+ image_path = "CXR145_IM-0290-1001.png"
23
+
24
+ urllib.request.urlretrieve(image_url, image_path)
25
+ print(f"Image downloaded to {image_path}")
26
+
27
+
28
+ image = base64.encodebytes(read_image(image_path)).decode("utf-8")
29
+
30
+ # Example inference
31
+ images = [image]
32
+ labels = ["normal", "Pneumonia", "unclear"]
33
+
34
+ #Zero-shot classification
35
+ results = classifier.predict(images, labels)
36
+ print(results)
37
+
38
+ #Image embeddings
39
+ results = classifier.encode(images = images)
40
+ print(results)
41
+
42
+ #Text embeddings
43
+ results = classifier.encode(texts = labels)
44
+ print(results)
flask_app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ import uvicorn
5
+ from medimageinsightmodel import MedImageInsight
6
+ import base64
7
+
8
+ # Initialize FastAPI app
9
+ app = FastAPI(title="Medical Image Analysis API")
10
+
11
+ # Initialize model
12
+ classifier = MedImageInsight(
13
+ model_dir="2024.09.27",
14
+ vision_model_name="medimageinsigt-v1.0.0.pt",
15
+ language_model_name="language_model.pth"
16
+ )
17
+ classifier.load_model()
18
+
19
+
20
+ class ClassificationRequest(BaseModel):
21
+ images: List[str] # Base64 encoded images
22
+ labels: List[str]
23
+ multilabel : bool = False
24
+
25
+ class EmbeddingRequest(BaseModel):
26
+ images: List[str] = None # Base64 encoded images
27
+ texts: List[str] = None
28
+
29
+ @app.post("/predict")
30
+ async def predict(request: ClassificationRequest):
31
+ try:
32
+ results = classifier.predict(
33
+ images=request.images,
34
+ labels=request.labels,
35
+ multilabel = request.multilabel
36
+ )
37
+ return {"predictions": results}
38
+ except Exception as e:
39
+ raise HTTPException(status_code=500, detail=str(e))
40
+
41
+ @app.post("/encode")
42
+ async def encode(request: EmbeddingRequest):
43
+ try:
44
+ results = classifier.encode(images=request.images, texts= request.texts)
45
+ results["image_embeddings"] = results["image_embeddings"].tolist() if results["image_embeddings"] is not None else None
46
+ results["text_embeddings"] = results["text_embeddings"].tolist() if results["text_embeddings"] is not None else None
47
+
48
+ return results
49
+ except Exception as e:
50
+ raise HTTPException(status_code=500, detail=str(e))
51
+
52
+ @app.get("/health")
53
+ async def health():
54
+ return {"status": "healthy"}
55
+
56
+ if __name__ == "__main__":
57
+ uvicorn.run(app, host="0.0.0.0", port=8000)
medimageinsightmodel.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Medical Image Classification model wrapper class that loads the model, preprocesses inputs and performs inference."""
2
+
3
+ import torch
4
+ from PIL import Image
5
+ import pandas as pd
6
+ from typing import List, Tuple
7
+ import os
8
+ import tempfile
9
+ import base64
10
+ import io
11
+
12
+ from MedImageInsight.UniCLModel import build_unicl_model
13
+ from MedImageInsight.Utils.Arguments import load_opt_from_config_files
14
+ from MedImageInsight.ImageDataLoader import build_transforms
15
+ from MedImageInsight.LangEncoder import build_tokenizer
16
+
17
+
18
+ class MedImageInsight:
19
+ """Wrapper class for medical image classification model."""
20
+
21
+ def __init__(
22
+ self,
23
+ model_dir: str,
24
+ vision_model_name: str,
25
+ language_model_name: str
26
+ ) -> None:
27
+ """Initialize the medical image classifier.
28
+
29
+ Args:
30
+ model_dir: Directory containing model files and config
31
+ vision_model_name: Name of the vision model
32
+ language_model_name: Name of the language model
33
+ """
34
+ self.model_dir = model_dir
35
+ self.vision_model_name = vision_model_name
36
+ self.language_model_name = language_model_name
37
+ self.model = None
38
+ self.device = None
39
+ self.tokenize = None
40
+ self.preprocess = None
41
+ self.opt = None
42
+
43
+ def load_model(self) -> None:
44
+ """Load the model and necessary components."""
45
+ try:
46
+ # Load configuration
47
+ config_path = os.path.join(self.model_dir, 'config.yaml')
48
+ self.opt = load_opt_from_config_files([config_path])
49
+
50
+ # Set paths
51
+ self.opt['LANG_ENCODER']['PRETRAINED_TOKENIZER'] = os.path.join(
52
+ self.model_dir,
53
+ 'language_model',
54
+ 'clip_tokenizer_4.16.2'
55
+ )
56
+ self.opt['UNICL_MODEL']['PRETRAINED'] = os.path.join(
57
+ self.model_dir,
58
+ 'vision_model',
59
+ self.vision_model_name
60
+ )
61
+
62
+ # Initialize components
63
+ self.preprocess = build_transforms(self.opt, False)
64
+ self.model = build_unicl_model(self.opt)
65
+
66
+ # Set device
67
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
+ self.model.to(self.device)
69
+
70
+ # Load tokenizer
71
+ self.tokenize = build_tokenizer(self.opt['LANG_ENCODER'])
72
+ self.max_length = self.opt['LANG_ENCODER']['CONTEXT_LENGTH']
73
+
74
+ print(f"Model loaded successfully on device: {self.device}")
75
+
76
+ except Exception as e:
77
+ print("Failed to load the model:")
78
+ raise e
79
+
80
+ @staticmethod
81
+ def decode_base64_image(base64_str: str) -> Image.Image:
82
+ """Decode base64 string to PIL Image and ensure RGB format.
83
+
84
+ Args:
85
+ base64_str: Base64 encoded image string
86
+
87
+ Returns:
88
+ PIL Image object in RGB format
89
+ """
90
+ try:
91
+ # Remove header if present
92
+ if ',' in base64_str:
93
+ base64_str = base64_str.split(',')[1]
94
+
95
+ image_bytes = base64.b64decode(base64_str)
96
+ image = Image.open(io.BytesIO(image_bytes))
97
+
98
+ # Convert grayscale (L) or grayscale with alpha (LA) to RGB
99
+ if image.mode in ('L', 'LA'):
100
+ image = image.convert('RGB')
101
+
102
+ return image
103
+ except Exception as e:
104
+ raise ValueError(f"Failed to decode base64 image: {str(e)}")
105
+
106
+ def predict(self, images: List[str], labels: List[str], multilabel: bool = False) -> List[dict]:
107
+ """Perform zero shot classification on the input images.
108
+
109
+ Args:
110
+ images: List of base64 encoded image strings
111
+ labels: List of candidate labels for classification
112
+
113
+ Returns:
114
+ DataFrame with columns ["probabilities", "labels"]
115
+ """
116
+ if not self.model:
117
+ raise RuntimeError("Model not loaded. Call load_model() first.")
118
+
119
+ if not labels:
120
+ raise ValueError("No labels provided")
121
+
122
+ # Create temporary directory for processing
123
+ with tempfile.TemporaryDirectory() as tmp_dir:
124
+ # Process images
125
+ image_list = []
126
+ for img_base64 in images:
127
+ try:
128
+ img = self.decode_base64_image(img_base64)
129
+ image_list.append(img)
130
+ except Exception as e:
131
+ raise ValueError(f"Failed to process image: {str(e)}")
132
+
133
+ # Run inference
134
+ probs = self.run_inference_batch(image_list, labels, multilabel)
135
+ probs_np = probs.cpu().numpy()
136
+ results = []
137
+ for prob_row in probs_np:
138
+ # Create label-prob pairs and sort by probability
139
+ label_probs = [(label, float(prob)) for label, prob in zip(labels, prob_row)]
140
+ label_probs.sort(key=lambda x: x[1], reverse=True)
141
+
142
+ # Create ordered dictionary from sorted pairs
143
+ results.append({
144
+ label: prob
145
+ for label, prob in label_probs
146
+ })
147
+
148
+ return results
149
+
150
+ def encode(self, images: List[str] = None, texts: List[str] = None):
151
+
152
+ output = {
153
+ "image_embeddings" : None,
154
+ "text_embeddings" : None,
155
+ }
156
+
157
+ if not self.model:
158
+ raise RuntimeError("Model not loaded. Call load_model() first.")
159
+
160
+ if not images and not texts:
161
+ raise ValueError("You must provide either images or texts")
162
+
163
+ if images is not None:
164
+ with tempfile.TemporaryDirectory() as tmp_dir:
165
+ # Process images
166
+ image_list = []
167
+ for img_base64 in images:
168
+ try:
169
+ img = self.decode_base64_image(img_base64)
170
+ image_list.append(img)
171
+ except Exception as e:
172
+ raise ValueError(f"Failed to process image: {str(e)}")
173
+ images = torch.stack([self.preprocess(img) for img in image_list]).to(self.device)
174
+ with torch.no_grad():
175
+ output["image_embeddings"] = self.model.encode_image(images).cpu().numpy()
176
+
177
+ if texts is not None:
178
+ text_tokens = self.tokenize(
179
+ texts,
180
+ padding='max_length',
181
+ max_length=self.max_length,
182
+ truncation=True,
183
+ return_tensors='pt'
184
+ )
185
+
186
+ # Move text tensors to the correct device
187
+ text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()}
188
+ output["text_embeddings"] = self.model.encode_text(text_tokens).cpu().numpy()
189
+
190
+
191
+ return output
192
+
193
+ def run_inference_batch(
194
+ self,
195
+ images: List[Image.Image],
196
+ texts: List[str],
197
+ multilabel: bool = False
198
+ ) -> torch.Tensor:
199
+ """Perform inference on batch of input images.
200
+
201
+ Args:
202
+ images: List of PIL Image objects
203
+ texts: List of text labels
204
+ multilabel: If True, use sigmoid for multilabel classification.
205
+ If False, use softmax for single-label classification.
206
+
207
+ Returns:
208
+ Tensor of prediction probabilities
209
+ """
210
+ # Prepare inputs
211
+ images = torch.stack([self.preprocess(img) for img in images]).to(self.device)
212
+
213
+ # Process text
214
+ text_tokens = self.tokenize(
215
+ texts,
216
+ padding='max_length',
217
+ max_length=self.max_length,
218
+ truncation=True,
219
+ return_tensors='pt'
220
+ )
221
+
222
+ # Move text tensors to the correct device
223
+ text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()}
224
+
225
+ # Run inference
226
+ with torch.no_grad():
227
+ outputs = self.model(image=images, text=text_tokens)
228
+ logits_per_image = outputs[0] @ outputs[1].t() * outputs[2]
229
+
230
+ if multilabel:
231
+ # Use sigmoid for independent probabilities per label
232
+ probs = torch.sigmoid(logits_per_image)
233
+ else:
234
+ # Use softmax for single-label classification
235
+ probs = logits_per_image.softmax(dim=1)
236
+
237
+ return probs
238
+
239
+
pyproject.toml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "MedImageInsights"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = "==3.8.19"
7
+ dependencies = [
8
+ "mlflow==2.14.3",
9
+ "cffi==1.17.1",
10
+ "cloudpickle==3.0.0",
11
+ "colorama==0.4.6",
12
+ "einops==0.8.0",
13
+ "ftfy==6.2.3",
14
+ "fvcore==0.1.5.post20221221",
15
+ "mup==1.0.0",
16
+ "numpy==1.24.4",
17
+ "packaging==24.1",
18
+ "pandas==2.0.3",
19
+ "pyyaml==6.0.2",
20
+ "requests==2.32.3",
21
+ "sentencepiece==0.2.0",
22
+ "tenacity==9.0.0",
23
+ "timm==1.0.9",
24
+ "tornado==6.4.1",
25
+ "transformers==4.46.0",
26
+ # "huggingface-hub==0.26.1",
27
+ "fastapi[standard]>=0.115.3",
28
+ # "opencv-python>=4.10.0.84",
29
+ ]
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mlflow==2.14.3
2
+ cffi==1.17.1
3
+ cloudpickle==3.0.0
4
+ colorama==0.4.6
5
+ einops==0.8.0
6
+ ftfy==6.2.3
7
+ fvcore==0.1.5.post20221221
8
+ mup==1.0.0
9
+ numpy==1.24.4
10
+ packaging==24.1
11
+ pandas==2.0.3
12
+ pyyaml==6.0.2
13
+ requests==2.32.3
14
+ sentencepiece==0.2.0
15
+ tenacity==9.0.0
16
+ timm==1.0.9
17
+ tornado==6.4.1
18
+ transformers==4.16.2
uv.lock ADDED
The diff for this file is too large to render. See raw diff