upload everything
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +24 -0
- MedVersa/pytorch_model.bin +3 -0
- README.md +58 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- demo.py +544 -0
- demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png +0 -0
- demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png +0 -0
- demo_ex/Case_00840_0000.nii.gz +3 -0
- demo_ex/Case_01013_0000.nii.gz +3 -0
- demo_ex/ISIC_0032258.jpg +0 -0
- demo_ex/ISIC_0033730.jpg +0 -0
- demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png +0 -0
- demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png +0 -0
- demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png +0 -0
- demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png +0 -0
- environment.yml +479 -0
- inference.py +107 -0
- medomni/__init__.py +31 -0
- medomni/__pycache__/__init__.cpython-311.pyc +0 -0
- medomni/__pycache__/__init__.cpython-39.pyc +0 -0
- medomni/common/__init__.py +0 -0
- medomni/common/__pycache__/__init__.cpython-39.pyc +0 -0
- medomni/common/__pycache__/config.cpython-39.pyc +0 -0
- medomni/common/__pycache__/dist_utils.cpython-39.pyc +0 -0
- medomni/common/__pycache__/logger.cpython-39.pyc +0 -0
- medomni/common/__pycache__/optims.cpython-39.pyc +0 -0
- medomni/common/__pycache__/registry.cpython-39.pyc +0 -0
- medomni/common/__pycache__/utils.cpython-39.pyc +0 -0
- medomni/common/config.py +468 -0
- medomni/common/dist_utils.py +137 -0
- medomni/common/gradcam.py +24 -0
- medomni/common/logger.py +200 -0
- medomni/common/optims.py +119 -0
- medomni/common/registry.py +327 -0
- medomni/common/utils.py +424 -0
- medomni/configs/datasets/medinterp/align.yaml +5 -0
- medomni/configs/default.yaml +5 -0
- medomni/configs/models/medomni.yaml +12 -0
- medomni/conversation/__init__.py +0 -0
- medomni/conversation/__pycache__/__init__.cpython-39.pyc +0 -0
- medomni/conversation/__pycache__/conversation.cpython-39.pyc +0 -0
- medomni/conversation/conversation.py +222 -0
- medomni/datasets/__init__.py +0 -0
- medomni/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- medomni/datasets/__pycache__/data_utils.cpython-39.pyc +0 -0
- medomni/datasets/builders/__init__.py +71 -0
- medomni/datasets/builders/__pycache__/__init__.cpython-39.pyc +0 -0
- medomni/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc +0 -0
- medomni/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc +0 -0
- medomni/datasets/builders/base_dataset_builder.py +234 -0
LICENSE
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright – President and Fellows of Harvard College, 2024. All Rights Reserved.
|
2 |
+
|
3 |
+
Redistribution and use in source and binary forms, with or without
|
4 |
+
modification, are permitted provided that the following conditions are met:
|
5 |
+
|
6 |
+
Redistributions of source code must retain the above copyright notice, this
|
7 |
+
list of conditions and the following disclaimer. Redistributions in binary
|
8 |
+
form must reproduce the above copyright notice, this list of conditions and the
|
9 |
+
following disclaimer in the documentation and/or other materials provided with
|
10 |
+
the distribution. Neither the name "Harvard" nor the names of its contributors
|
11 |
+
may be used to endorse or promote products derived from this software without
|
12 |
+
specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
17 |
+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
18 |
+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
19 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
20 |
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
21 |
+
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
22 |
+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
|
23 |
+
OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
|
24 |
+
OF THE POSSIBILITY OF SUCH DAMAGE.
|
MedVersa/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b3ce596897168d79649e6d6df128a1b409a0cc878092f00667873be6f4b8c9d3
|
3 |
+
size 13993804625
|
README.md
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: hyzhouMedVersa
|
3 |
+
app_file: demo_inter.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 4.24.0
|
6 |
+
---
|
7 |
+
# MedVersa: An orchestrated medical AI system
|
8 |
+
MedVersa is a compound medical AI system that can coordinate multimodal inputs, orchestrate models and tools for varying tasks, and generate multimodal outputs.
|
9 |
+
|
10 |
+
## Environment
|
11 |
+
MedVersa is written in [Python](https://www.python.org/). It is recommended to configure/manage your python environment using conda. To do this, you need to install the [miniconda](https://docs.anaconda.com/free/miniconda/index.html) or [anaconda](https://www.anaconda.com/) first.
|
12 |
+
|
13 |
+
After installing conda, you need to set up a new conda environment for MedVersa using the provided `environment.yml`:
|
14 |
+
``` shell
|
15 |
+
conda env create -f environment.yml
|
16 |
+
conda activate medversa
|
17 |
+
```
|
18 |
+
|
19 |
+
## Inference
|
20 |
+
``` python
|
21 |
+
from utils import *
|
22 |
+
|
23 |
+
# --- Launch Model ---
|
24 |
+
device = 'cuda:0'
|
25 |
+
model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
|
26 |
+
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
|
27 |
+
model.eval()
|
28 |
+
|
29 |
+
# --- Define examples ---
|
30 |
+
examples = [
|
31 |
+
[
|
32 |
+
["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
|
33 |
+
"Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
|
34 |
+
"How would you characterize the findings from <img0>?",
|
35 |
+
"cxr",
|
36 |
+
"report generation",
|
37 |
+
],
|
38 |
+
]
|
39 |
+
# --- Define hyperparams ---
|
40 |
+
num_beams = 1
|
41 |
+
do_sample = True
|
42 |
+
min_length = 1
|
43 |
+
top_p = 0.9
|
44 |
+
repetition_penalty = 1
|
45 |
+
length_penalty = 1
|
46 |
+
temperature = 0.1
|
47 |
+
|
48 |
+
# --- Generate a report for a chest X-ray image ---
|
49 |
+
index = 0
|
50 |
+
demo_ex = examples[index]
|
51 |
+
images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
|
52 |
+
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
53 |
+
print(output_text)
|
54 |
+
```
|
55 |
+
For more details and examples, please refer to `inference.py`.
|
56 |
+
|
57 |
+
## Demo
|
58 |
+
`CUDA_VISIBLE_DEVICES=0 python demo.py --cfg-path medversa.yaml`
|
__pycache__/utils.cpython-39.pyc
ADDED
Binary file (11.2 kB). View file
|
|
demo.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchvision.transforms.functional as TF
|
6 |
+
from torchvision import transforms
|
7 |
+
from PIL import Image
|
8 |
+
import skimage.morphology, skimage.io
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import random
|
12 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
13 |
+
from copy import deepcopy
|
14 |
+
from medomni.common.config import Config
|
15 |
+
from medomni.common.dist_utils import get_rank
|
16 |
+
from medomni.common.registry import registry
|
17 |
+
import torchio as tio
|
18 |
+
import nibabel as nib
|
19 |
+
from scipy import ndimage, misc
|
20 |
+
import time
|
21 |
+
import ipdb
|
22 |
+
|
23 |
+
# Function to parse command line arguments
|
24 |
+
def parse_args():
    """Parse the demo's command-line arguments from ``sys.argv``.

    Returns:
        argparse.Namespace: ``cfg_path`` (required string) and ``options``
        (list of strings, or None when ``--options`` is not supplied).
    """
    cli = argparse.ArgumentParser(description="Demo")
    cli.add_argument(
        "--cfg-path",
        required=True,
        help="path to configuration file.",
    )
    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )
    cli.add_argument("--options", nargs="+", help=options_help)
    return cli.parse_args()
|
34 |
+
|
35 |
+
device = 'cuda:0'
# Launch model
# NOTE: everything below runs at import time: the command line is parsed and
# the pretrained weights are loaded as a side effect of importing this module.
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
# Look up the model class registered under the architecture name in the config.
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
model.eval()
# `global` at module scope has no effect; the name is already module-level.
global global_images
# Image list seen by the previous gradio_interface call (None until the first call).
global_images = None
|
46 |
+
|
47 |
+
def seg_2d_process(image_path, pred_mask, img_size=224):
    """Turn a predicted 2D mask into an (image, mask) pair for display.

    Only the largest connected component of ``pred_mask`` is kept. Its
    contour points are drawn on an ``img_size`` x ``img_size`` canvas, which
    is then resized to the resolution of the source image.

    Args:
        image_path: sequence of file paths; only the first entry is read.
        pred_mask: 2D mask array predicted by the model. It is modified in
            place when it contains any non-zero element.
        img_size: side length of the square canvas the contours are drawn on.

    Returns:
        tuple: ``([PIL.Image in RGB], [mask])`` where ``mask`` is a 2D array
        with the image's (height, width). The mask is all zeros when nothing
        was predicted or no contour could be extracted.
    """
    image = cv2.imread(image_path[0])
    height, width = image.shape[:2]
    rgb_image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
    # Fallback mask. Built from the array shape *before* the image is wrapped
    # in a list, and ordered (rows, cols) so it lines up with the image.
    empty_mask = [np.zeros((height, width))]

    if pred_mask.sum() == 0:
        return rgb_image, empty_mask

    # Keep only the largest connected component (label 0 is background).
    labels = skimage.morphology.label(pred_mask)
    label_count = np.bincount(labels.ravel())
    largest_label = np.argmax(label_count[1:]) + 1
    pred_mask[labels != largest_label] = 0
    pred_mask[labels == largest_label] = 255
    pred_mask = pred_mask.astype(np.uint8)

    contours, _ = cv2.findContours(pred_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return rgb_image, empty_mask

    # NOTE(review): stacking merges all contour points into one array before
    # drawing; kept as in the original - confirm this is the intended rendering.
    contours = np.vstack(contours)
    binary_array = np.zeros((img_size, img_size))
    binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED)
    # cv2.resize takes dsize as (width, height).
    binary_array = cv2.resize(binary_array, (width, height), interpolation=cv2.INTER_NEAREST) / 255
    return rgb_image, [binary_array]
|
73 |
+
|
74 |
+
def seg_3d_process(image_path, seg_mask):
    """Build per-slice display images and masks for a 3D segmentation.

    The volume is loaded with nibabel, windowed with ``window_scan`` and
    moved to slice-first order. Every slice is rotated and flipped the same
    way for the image and for its mask.

    Args:
        image_path: sequence of file paths; only the first entry is loaded.
        seg_mask: predicted 3D mask. When it has any non-zero element it is
            resized to the windowed volume's shape before slicing.

    Returns:
        tuple: ``(image_slices, contour_slices)`` - a list of PIL images and
        a list of 2D mask arrays, one entry per slice. When the mask is
        empty, every slice gets an all-zero mask.
    """
    img = nib.load(image_path[0]).get_fdata()
    image = window_scan(img).transpose(2, 0, 1).astype(np.uint8)

    has_mask = seg_mask.sum() != 0
    if has_mask:
        seg_mask = resize_back_volume_abd(seg_mask, image.shape).astype(np.uint8)
    # The original empty-mask branch referenced an undefined loop index; now
    # both cases share one loop over the slices.
    num_slices = seg_mask.shape[0] if has_mask else image.shape[0]

    image_slices = []
    contour_slices = []
    for i in range(num_slices):
        slice_img = np.fliplr(np.rot90(image[i]))
        image_slices.append(Image.fromarray(slice_img))
        if not has_mask:
            contour_slices.append(np.zeros_like(slice_img))
            continue

        # rot90/fliplr return views; hand OpenCV a contiguous copy.
        slice_mask = np.ascontiguousarray(np.fliplr(np.rot90(seg_mask[i])))
        contours, _ = cv2.findContours(slice_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        if not contours:
            contour_slices.append(np.zeros_like(slice_img))
            continue

        # Draw on a canvas shaped like the rotated mask the contours came from.
        binary_array = np.zeros(slice_mask.shape)
        binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED) / 255
        # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
        binary_array = cv2.resize(
            binary_array,
            (slice_img.shape[1], slice_img.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
        contour_slices.append(binary_array)

    return image_slices, contour_slices
|
101 |
+
|
102 |
+
def det_2d_process(image_path, box):
    """Draw a predicted bounding box on the first input image.

    Args:
        image_path: sequence of file paths; only the first entry is read.
        box: ``None`` for "no box", otherwise four values
            ``(x1, y1, x2, y2)`` that are multiplied by the image width and
            height before drawing.

    Returns:
        list: a single-element list holding the annotated image as an RGB
        ``PIL.Image``.
    """
    image = cv2.imread(image_path[0])
    if box is not None:
        hi, wd, _ = image.shape
        color = tuple(np.random.random(size=3) * 256)
        x1, y1, x2, y2 = int(box[0] * wd), int(box[1] * hi), int(box[2] * wd), int(box[3] * hi)
        image = cv2.rectangle(image, (x1, y1), (x2, y2), color, 10)
    # cv2.imread yields BGR; convert so PIL shows the right colours, as
    # seg_2d_process does.
    return [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
|
112 |
+
|
113 |
+
def window_scan(scan, window_center=50, window_width=400):
    """
    Clamp a scan to an intensity window and rescale it to 0-255.

    Parameters:
    scan (numpy.ndarray): array of scan intensities
    window_center (int): The center of the window
    window_width (int): The width of the window

    Returns:
    numpy.ndarray: uint8 array with the same shape as ``scan``; values at or
    below the window map to 0 and values at or above it map to 255
    """
    half_width = window_width // 2
    low = window_center - half_width
    high = window_center + half_width

    clipped = np.clip(scan, low, high)
    fraction = (clipped - low) / (high - low)
    return (fraction * 255).astype(np.uint8)
|
133 |
+
|
134 |
+
def task_seg_2d(model, preds, hidden_states, image):
    """Decode a 2D segmentation mask from the generated sequence.

    Args:
        model: object exposing ``seg_token_idx_2d``, ``model_seg_2d``
            (``encoder``, ``decoder``, ``segmentation_head``),
            ``text2seg_2d`` and ``text2seg_2d_gn``.
        preds: 1D tensor of generated token ids.
        hidden_states: per-step hidden states; the last entry of each step
            at a segmentation-token position is used.
        image: image tensor fed to the segmentation encoder.

    Returns:
        Boolean numpy array (probability >= 0.5), or None when ``preds``
        contains no 2D segmentation token.
    """
    seg_positions = torch.where(preds == model.seg_token_idx_2d)[0].cpu().numpy()
    # The encoder runs before the emptiness check, as in the original.
    encoder_feats = model.model_seg_2d.encoder(image.unsqueeze(0)[:, 0])
    deepest_feats = encoder_feats[-1]
    picked_states = [hidden_states[pos][-1] for pos in seg_positions]
    if not picked_states:
        return None

    prompt_state = torch.cat(picked_states).squeeze()
    seg_states = model.text2seg_2d(prompt_state).unsqueeze(0)
    # Add the projected text state to every spatial location of the deepest features.
    deepest_feats = deepest_feats + seg_states.unsqueeze(-1).unsqueeze(-1)
    encoder_feats[-1] = model.text2seg_2d_gn(deepest_feats)

    decoded = model.model_seg_2d.decoder(*encoder_feats)
    seg_logits = model.model_seg_2d.segmentation_head(decoded)
    seg_probs = torch.sigmoid(seg_logits)
    return seg_probs.cpu().squeeze().numpy() >= 0.5
|
153 |
+
|
154 |
+
def task_seg_3d(model, preds, hidden_states, img_embeds_list):
    """Decode a 3D segmentation mask from the generated sequence.

    Args:
        model: object exposing ``seg_token_idx_3d``, ``text2seg_3d``,
            ``text2seg_3d_gn`` and ``visual_encoder_3d``.
        preds: 1D tensor of generated token ids.
        hidden_states: per-step hidden states; the last entry of each step
            at a segmentation-token position is used.
        img_embeds_list: list of image features; it is deep-copied, so the
            caller's list is left untouched.

    Returns:
        Boolean numpy array (probability >= 0.5). Falls through and returns
        None when ``preds`` contains no 3D segmentation token.
    """
    feats = deepcopy(img_embeds_list)
    seg_positions = torch.where(preds == model.seg_token_idx_3d)[0].cpu().numpy()
    picked_states = [hidden_states[pos][-1] for pos in seg_positions]
    if not picked_states:
        return None

    prompt_state = torch.cat(picked_states).squeeze().unsqueeze(0)
    seg_states = model.text2seg_3d(prompt_state)
    # Add the projected text state to every voxel of the deepest feature map.
    conditioned = feats[-1] + seg_states.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    feats[-1] = model.text2seg_3d_gn(conditioned)

    seg_logits = model.visual_encoder_3d(encoder_only=False, x_=feats)
    seg_probs = torch.sigmoid(seg_logits)
    return seg_probs.cpu().squeeze().numpy() >= 0.5
|
170 |
+
|
171 |
+
def task_det_2d(model, preds, hidden_states):
    """Decode detection output from the generated sequence.

    Args:
        model: object exposing ``det_token_idx`` and ``text_det``.
        preds: 1D tensor of generated token ids.
        hidden_states: per-step hidden states; the last entry of each step
            at a detection-token position is used.

    Returns:
        numpy array produced by ``model.text_det``, or None when ``preds``
        contains no detection token. (The original tried
        ``torch.zeros_like`` on a numpy array here, which raises TypeError;
        ``det_2d_process`` already treats None as "no box".)
    """
    token_mask = preds == model.det_token_idx
    indices = torch.where(token_mask)[0].cpu().numpy()
    target_states = [hidden_states[ind][-1] for ind in indices]
    if not target_states:
        return None

    target_states = torch.cat(target_states).squeeze()
    det_states = model.text_det(target_states).detach().cpu()
    return det_states.numpy()
|
180 |
+
|
181 |
+
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the sequence ends with any configured stop sequence."""

    def __init__(self, stops=None):
        """Store the stop sequences.

        Args:
            stops: iterable of 1D tensors of token ids. Defaults to no stop
                sequences; ``None`` is used instead of a shared mutable
                default list.
        """
        super().__init__()
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the first sequence in ``input_ids`` ends with a stop sequence.

        Extra keyword arguments are accepted and ignored so callers that
        forward additional context do not break.
        """
        generated = input_ids[0]
        for stop in self.stops:
            stop_len = len(stop)
            # An empty or too-short comparison would broadcast and could
            # report a match (torch.all of an empty tensor is True).
            if stop_len == 0 or stop_len > len(generated):
                continue
            if torch.all(stop == generated[-stop_len:]).item():
                return True
        return False
|
191 |
+
|
192 |
+
def resize_back_volume_abd(img, target_size):
    """Resample a 3D array to ``target_size`` with nearest-neighbour zoom.

    Args:
        img: 3D array; its first three axes are compared with ``target_size``.
        target_size: sequence whose first three entries are the desired
            sizes of the corresponding axes of ``img``.

    Returns:
        The array returned by ``scipy.ndimage.zoom`` with ``order=0``.
    """
    zoom_factors = []
    for axis in range(3):
        shrink_ratio = img.shape[axis] / target_size[axis]
        zoom_factors.append(1 / shrink_ratio)
    return ndimage.zoom(img, tuple(zoom_factors), order=0)
|
211 |
+
|
212 |
+
def resize_volume_abd(img):
    """Clamp a volume to [-200, 300] and resample it to 192 x 192 x 64.

    Args:
        img: 3D array indexed as [width, height, depth]. It is clamped
            IN PLACE, so the caller's array is modified.

    Returns:
        A new array from ``scipy.ndimage.zoom`` (``order=0``) with shape
        (192, 192, 64).
    """
    target_width, target_height, target_depth = 192, 192, 64

    # In-place clamp of the intensity range.
    img[img <= -200] = -200
    img[img >= 300] = 300

    src_width, src_height, src_depth = img.shape[0], img.shape[1], img.shape[2]
    zoom_factors = (
        1 / (src_width / target_width),
        1 / (src_height / target_height),
        1 / (src_depth / target_depth),
    )
    return ndimage.zoom(img, zoom_factors, order=0)
|
234 |
+
|
235 |
+
def load_and_preprocess_image(image):
    """Resize, tensorise and normalise one image for the model.

    Args:
        image: input accepted by the torchvision transforms used below
            (``read_image`` passes an RGB ``PIL.Image``).

    Returns:
        bfloat16 CUDA tensor with a leading batch axis of size 1; the image
        is resized to 224 x 224 before normalisation.
    """
    norm_mean = (0.48145466, 0.4578275, 0.40821073)
    norm_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    tensor = pipeline(image)
    return tensor.type(torch.bfloat16).cuda().unsqueeze(0)
|
245 |
+
|
246 |
+
def load_and_preprocess_volume(image):
    """Load a volume from disk and prepare it for the model.

    Args:
        image: path handed to ``nibabel.load``.

    Returns:
        bfloat16 CUDA tensor: the volume resized by ``resize_volume_abd``,
        permuted to put the last axis first, given a leading axis of size 1
        and z-normalised with torchio.
    """
    raw_volume = nib.load(image).get_fdata()
    depth_first = torch.from_numpy(resize_volume_abd(raw_volume)).permute(2, 0, 1)
    normalise = tio.Compose([
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ])
    normalised = normalise(depth_first.unsqueeze(0))
    return normalised.type(torch.bfloat16).cuda()
|
254 |
+
|
255 |
+
def read_image(image_path):
    """Load an input file and preprocess it according to its extension.

    Extensions are matched case-insensitively, so ``.PNG`` or ``.NII.GZ``
    are accepted as well.

    Args:
        image_path: path string ending in ``.jpg``, ``.jpeg``, ``.png`` or
            ``.nii.gz``.

    Returns:
        The tensor produced by ``load_and_preprocess_image`` (2D images) or
        ``load_and_preprocess_volume`` (``.nii.gz`` volumes).

    Raises:
        ValueError: if the extension is not one of the supported ones.
    """
    lowered_path = image_path.lower()
    if lowered_path.endswith(('.jpg', '.jpeg', '.png')):
        return load_and_preprocess_image(Image.open(image_path).convert('RGB'))
    if lowered_path.endswith('.nii.gz'):
        return load_and_preprocess_volume(image_path)
    raise ValueError("Unsupported file format")
|
262 |
+
|
263 |
+
def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    """Run the module-level ``model`` on one request and post-process the output.

    Args:
        image_path: sequence of input file paths, forwarded to the
            ``*_process`` helpers to render segmentation/detection results.
        image: preprocessed image tensor for the request.
        context: free-text context; wrapped into the prompt as
            ``<context>...</context>`` for report-style prompts and for
            dermatology diagnosis prompts.
        modal: lower-case modality string (the code checks 'ct' and 'derm').
        num_imgs: number of input images.
        prompt: the user's prompt.
        num_beams, do_sample, min_length, top_p, repetition_penalty,
        length_penalty, temperature: forwarded to ``llama_model.generate``.

    Returns:
        tuple: ``(output_image, seg_mask_2d, seg_mask_3d, output_text)``.
        The first three are None unless the generated sequence contains the
        matching segmentation/detection token.
    """
    wants_report = 'report' in prompt or 'finding' in prompt or 'impression' in prompt
    wants_derm_dx = modal == 'derm' and ('diagnosis' in prompt or 'issue' in prompt or 'problem' in prompt)
    if len(context) != 0 and (wants_report or wants_derm_dx):
        prompt = '<context>' + context + '</context>' + prompt

    if modal == 'ct' and 'segment' in prompt.lower():
        # Canonicalise CT segmentation prompts. Keywords are matched on the
        # lower-cased prompt (like the 'segment' check); first match wins.
        lowered_prompt = prompt.lower()
        organ_keywords = (
            ('liver', 'liver'),
            ('spleen', 'spleen'),
            ('kidney', 'kidney'),
            ('pancrea', 'pancreas'),
        )
        for keyword, organ in organ_keywords:
            if keyword in lowered_prompt:
                prompt = f'Segment the {organ}.'
                break

    img_embeds, atts_img, img_embeds_list = model.encode_img(image.unsqueeze(0), [modal])
    placeholder = ['<ImageHere>'] * 9
    prefix = '###Human:' + ''.join([f'<img{i}>' + ''.join(placeholder) + f'</img{i}>' for i in range(num_imgs)])
    img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img, [prefix], [num_imgs])
    prompt += '###Assistant:'
    prompt_tokens = model.llama_tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(image.device)
    new_img_embeds, _ = model.prompt_concat(img_embeds, atts_img, prompt_tokens)

    # Stop sequences stay integer token ids. Casting them to bfloat16 (as the
    # original did) rounds ids to the nearest representable value, so nearby
    # token ids compared equal and could end generation early.
    stop_sequences = [
        torch.tensor([835], device=image.device),
        torch.tensor([2277, 29937], device=image.device),
    ]
    outputs = model.llama_model.generate(
        inputs_embeds=new_img_embeds,
        max_new_tokens=450,
        stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub(stops=stop_sequences)]),
        num_beams=num_beams,
        do_sample=do_sample,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

    hidden_states = outputs.hidden_states
    preds = outputs.sequences[0]
    output_image = None
    seg_mask_2d = None
    seg_mask_3d = None
    if (preds == model.seg_token_idx_2d).any():
        seg_mask = task_seg_2d(model, preds, hidden_states, image)
        output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
    if (preds == model.seg_token_idx_3d).any():
        seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
        output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
    if (preds == model.det_token_idx).any():
        det_box = task_det_2d(model, preds, hidden_states)
        output_image = det_2d_process(image_path, det_box)

    # Length checks keep an empty sequence from raising IndexError.
    if len(preds) and preds[0] == 0:  # Remove unknown token <unk> at the beginning
        preds = preds[1:]
    if len(preds) and preds[0] == 1:  # Remove start token <s> at the beginning
        preds = preds[1:]

    output_text = model.llama_tokenizer.decode(preds, add_special_tokens=False)
    output_text = output_text.split('###')[0].split('Assistant:')[-1].strip()

    # Any dermatology answer containing 'mel' is replaced by a fixed sentence.
    if 'mel' in output_text and modal == 'derm':
        output_text = 'The main diagnosis is melanoma.'
    return output_image, seg_mask_2d, seg_mask_3d, output_text
|
327 |
+
|
328 |
+
def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    """Preprocess the inputs and run ``generate`` without gradients.

    Args:
        images: sequence of image paths; each is loaded via ``read_image``.
        context, prompt: forwarded to ``generate``.
        modality: modality string; its lower-cased form is passed on.
        num_beams, do_sample, min_length, top_p, repetition_penalty,
        length_penalty, temperature: forwarded to ``generate``.

    Returns:
        tuple: ``(generated_image, seg_mask_2d, seg_mask_3d, output_text)``
        exactly as returned by ``generate``.
    """
    num_imgs = len(images)
    modal = modality.lower()
    image_tensors = list(map(read_image, images))
    # Fixed pause kept from the original; it compares the un-lowered modality.
    time.sleep(2 if modality == 'ct' else 1)
    image_tensor = torch.cat(image_tensors)

    with torch.autocast("cuda"), torch.no_grad():
        generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(
            images, image_tensor, context, modal, num_imgs, prompt,
            num_beams, do_sample, min_length, top_p,
            repetition_penalty, length_penalty, temperature,
        )

    return generated_image, seg_mask_2d, seg_mask_3d, output_text
|
343 |
+
|
344 |
+
# Shared state read by render(): 'image' holds the images of the last
# prediction, 'mask' the matching masks (or None).
my_dict = {}


def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    """Gradio callback: validate the inputs, run the model, update the UI.

    Args:
        chatbot: current chat history (list of ``(prompt, answer)`` pairs).
        images: uploaded files; each must expose a ``name`` attribute.
        context, prompt, modality: user inputs forwarded to
            ``generate_predictions``.
        num_beams, do_sample, min_length, top_p, repetition_penalty,
        length_penalty, temperature: generation settings.

    Returns:
        tuple: ``(chat history, (image, annotations), slider update)``.
        The slider maximum is set to the last valid image index.
    """
    global global_images

    if not images:
        rejection = "At least one image is required to proceed."
    elif not prompt or not modality:
        rejection = "Please provide prompt and modality to proceed."
    else:
        rejection = None
    if rejection is not None:
        # Invalid request: show a black placeholder and a one-line explanation.
        blank_image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        global_images = 'none'
        return [(prompt, rejection)], (blank_image, []), gr.update(maximum=0)

    generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
    input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]

    # A 3D mask takes precedence over a 2D one when both are present.
    if seg_mask_3d is not None:
        masks = seg_mask_3d
    else:
        masks = seg_mask_2d

    if generated_images is None:
        # Nothing was rendered by the model: fall back to the inputs.
        output_images = input_images.copy()
        snapshot = (output_images[0], [])
    else:
        output_images = [np.asarray(rendered).astype(np.uint8) for rendered in generated_images]
        first_image = output_images[0]
        annotations = [] if masks is None else [(masks[0], "Mask")]
        snapshot = (first_image, annotations)

    my_dict['image'] = output_images
    my_dict['mask'] = masks

    # Start a fresh conversation whenever the image set changed.
    if global_images != images and (global_images is not None):
        chatbot = []
    chatbot.append((prompt, output_text))
    global_images = images

    return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
|
390 |
+
|
391 |
+
# my_dict = {}
|
392 |
+
# def gradio_interface(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
|
393 |
+
# if not images:
|
394 |
+
# return None, "Error: At least one image is required to proceed."
|
395 |
+
# if not prompt or not task or not modality:
|
396 |
+
# return None, "Error: Please provide prompt, select task and modality to proceed."
|
397 |
+
|
398 |
+
# generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
399 |
+
# output_images = []
|
400 |
+
|
401 |
+
# input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
|
402 |
+
# if generated_images is not None:
|
403 |
+
# for generated_image in generated_images:
|
404 |
+
# output_images.append(np.asarray(generated_image).astype(np.uint8))
|
405 |
+
# snapshot = (output_images[0], [])
|
406 |
+
# if seg_mask_2d is not None:
|
407 |
+
# snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
|
408 |
+
# if seg_mask_3d is not None:
|
409 |
+
# snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
|
410 |
+
# else:
|
411 |
+
# output_images = input_images.copy()
|
412 |
+
# snapshot = (output_images[0], [])
|
413 |
+
|
414 |
+
# my_dict['image'] = output_images
|
415 |
+
# my_dict['mask'] = None
|
416 |
+
# if seg_mask_2d is not None:
|
417 |
+
# my_dict['mask'] = seg_mask_2d
|
418 |
+
# if seg_mask_3d is not None:
|
419 |
+
# my_dict['mask'] = seg_mask_3d
|
420 |
+
|
421 |
+
# return output_text, snapshot, gr.update(maximum=len(output_images)-1)
|
422 |
+
|
423 |
+
def render(x):
    """Return the stored image (and mask, if any) at slider position ``x``.

    Args:
        x: requested index; it is converted to ``int`` and clamped to the
            valid range of the images stored in ``my_dict``.

    Returns:
        tuple: ``(image, annotations)`` where ``annotations`` is ``[]`` or
        ``[(mask, "Mask")]``. When no prediction has been stored yet, a
        black 224 x 224 placeholder is returned instead of raising.
    """
    images = my_dict.get('image')
    if not images:
        # The slider can fire before the first prediction is stored.
        return (np.zeros((224, 224, 3), dtype=np.uint8), [])

    # Sliders may deliver floats; list indices must be ints.
    x = min(max(int(x), 0), len(images) - 1)
    image = images[x]
    masks = my_dict.get('mask')
    if masks is None:
        return (image, [])
    return (image, [(masks[x], "Mask")])
|
435 |
+
|
436 |
+
def update_context_visibility(task):
    """Show the context box only for tasks that use free-text context.

    Args:
        task: task name selected in the UI.

    Returns:
        A ``gr.update`` with ``visible=True`` for "report generation" and
        "classification", ``visible=False`` otherwise.
    """
    needs_context = task in ("report generation", 'classification')
    return gr.update(visible=needs_context)
|
441 |
+
|
442 |
+
def reset_chatbot():
    """Return a fresh, empty list (an empty chat history)."""
    return list()
|
444 |
+
|
445 |
+
# Build the Gradio UI. Components are created in the order they should appear.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # with gr.Row():
    #     gr.Markdown("<link href='https://fonts.googleapis.com/css2?family=Libre+Franklin:wght@400;700&display=swap' rel='stylesheet'>")
    gr.Markdown("# MedVersa")
    with gr.Row():
        # First column: user inputs and generation settings.
        with gr.Column():
            image_input = gr.File(
                label="Upload Images",
                file_count="multiple",
                file_types=["image", "numpy"],
            )
            # task_input = gr.Dropdown(choices=["report generation", "vqa", "localization", "classification"], label="Task")
            context_input = gr.Textbox(
                label="Context",
                placeholder="Enter context here...",
                lines=3,
                visible=True,
            )
            modality_input = gr.Dropdown(
                choices=["cxr", "derm", "ct"],
                label="Modality",
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter prompt here... (images should be referred as <img0>, <img1>, ...)",
                lines=3,
            )
            submit_button = gr.Button("Generate Predictions")
            with gr.Accordion("Advanced Settings", open=False):
                num_beams = gr.Slider(label="Number of Beams", minimum=1, maximum=10, step=1, value=1)
                do_sample = gr.Checkbox(label="Do Sample", value=True)
                min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
                top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.9)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
                length_penalty = gr.Slider(label="Length Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.1)

        # Second column: model outputs.
        with gr.Column():
            # output_text = gr.Textbox(label="Generated Text", lines=10, elem_classes="output-textbox")
            chatbot = gr.Chatbot(label="Chatbox")
            slider = gr.Slider(minimum=0, maximum=64, value=1, step=1)
            output_image = gr.AnnotatedImage(height=448, label="Images")

    # task_input.change(
    #     fn=update_context_visibility,
    #     inputs=task_input,
    #     outputs=context_input
    # )

    # Clicking the button runs the model; the handler's outputs refresh the
    # chat box, the image viewer and the slider.
    generation_inputs = [
        chatbot,
        image_input,
        context_input,
        prompt_input,
        modality_input,
        num_beams,
        do_sample,
        min_length,
        top_p,
        repetition_penalty,
        length_penalty,
        temperature,
    ]
    submit_button.click(
        fn=gradio_interface,
        inputs=generation_inputs,
        outputs=[chatbot, output_image, slider],
    )

    # Moving the slider re-renders the image viewer for the selected index.
    slider.change(fn=render, inputs=[slider], outputs=[output_image])

    # Each example row is: image paths, context, prompt, modality.
    examples = [
        [
            ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
            "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
            "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png", "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png"],
            "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
            "How would you characterize the findings from <img0><img1>?",
            "cxr",
        ],
        [
            ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
            "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
            "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/ISIC_0032258.jpg"],
            "Age:70.\nGender:female.\nLocation:back.",
            "What is primary diagnosis?",
            "derm",
        ],
        [
            ["./demo_ex/Case_01013_0000.nii.gz"],
            "",
            "Segment the liver.",
            "ct",
        ],
        [
            ["./demo_ex/Case_00840_0000.nii.gz"],
            "",
            "Segment the liver.",
            "ct",
        ],
    ]

    example_inputs = [image_input, context_input, prompt_input, modality_input]
    gr.Examples(examples, inputs=example_inputs)

# Start the Gradio app (share=True is passed through to launch()).
demo.launch(share=True)
|
demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png
ADDED
demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png
ADDED
demo_ex/Case_00840_0000.nii.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:27d91a51f4f792740aab30da1416e2a200f637a53e9aa842cf47f2dd96519216
|
3 |
+
size 30618190
|
demo_ex/Case_01013_0000.nii.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63f597a81e594aa0b5d5b67551658f4a8be831ac6189f2f3f644b0a1098fbb09
|
3 |
+
size 30845920
|
demo_ex/ISIC_0032258.jpg
ADDED
demo_ex/ISIC_0033730.jpg
ADDED
demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png
ADDED
demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png
ADDED
demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png
ADDED
demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png
ADDED
environment.yml
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: medversa
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- anaconda
|
7 |
+
- defaults
|
8 |
+
dependencies:
|
9 |
+
- _libgcc_mutex=0.1=main
|
10 |
+
- _openmp_mutex=5.1=1_gnu
|
11 |
+
- abseil-cpp=20211102.0=h27087fc_1
|
12 |
+
- aiosignal=1.3.1=pyhd8ed1ab_0
|
13 |
+
- arrow-cpp=11.0.0=py39h613000e_0
|
14 |
+
- asttokens=2.2.1=pyhd8ed1ab_0
|
15 |
+
- async-timeout=4.0.2=pyhd8ed1ab_0
|
16 |
+
- atk-1.0=2.36.0=ha1a6a79_0
|
17 |
+
- aws-c-common=0.4.57=he6710b0_1
|
18 |
+
- aws-c-event-stream=0.1.6=h2531618_5
|
19 |
+
- aws-checksums=0.1.9=he6710b0_0
|
20 |
+
- aws-sdk-cpp=1.8.185=hce553d0_0
|
21 |
+
- blas=1.0=mkl
|
22 |
+
- boost-cpp=1.70.0=ha2d47e9_1
|
23 |
+
- bottleneck=1.3.5=py39h7deecbd_0
|
24 |
+
- brotlipy=0.7.0=py39h27cfd23_1003
|
25 |
+
- bzip2=1.0.8=h7b6447c_0
|
26 |
+
- c-ares=1.18.1=h7f8727e_0
|
27 |
+
- ca-certificates=2023.12.12=h06a4308_0
|
28 |
+
- cairo=1.16.0=hb05425b_5
|
29 |
+
- certifi=2023.11.17=py39h06a4308_0
|
30 |
+
- cffi=1.15.1=py39h5eee18b_3
|
31 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
32 |
+
- cryptography=41.0.2=py39h774aba0_0
|
33 |
+
- cuda-cudart=11.7.99=0
|
34 |
+
- cuda-cupti=11.7.101=0
|
35 |
+
- cuda-libraries=11.7.1=0
|
36 |
+
- cuda-nvrtc=11.7.99=0
|
37 |
+
- cuda-nvtx=11.7.91=0
|
38 |
+
- cuda-runtime=11.7.1=0
|
39 |
+
- cudatoolkit=11.7.0=hd8887f6_10
|
40 |
+
- curl=7.87.0=h5eee18b_0
|
41 |
+
- dataclasses=0.8=pyhc8e2a94_3
|
42 |
+
- datasets=2.14.3=pyhd8ed1ab_0
|
43 |
+
- dill=0.3.7=pyhd8ed1ab_0
|
44 |
+
- executing=1.2.0=pyhd8ed1ab_0
|
45 |
+
- expat=2.4.9=h6a678d5_0
|
46 |
+
- ffmpeg=4.3=hf484d3e_0
|
47 |
+
- filelock=3.9.0=py39h06a4308_0
|
48 |
+
- flit-core=3.6.0=pyhd3eb1b0_0
|
49 |
+
- font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
|
50 |
+
- font-ttf-inconsolata=2.001=hcb22688_0
|
51 |
+
- font-ttf-source-code-pro=2.030=hd3eb1b0_0
|
52 |
+
- font-ttf-ubuntu=0.83=h8b1ccd4_0
|
53 |
+
- fontconfig=2.14.1=h52c9d5c_1
|
54 |
+
- fonts-anaconda=1=h8fa9717_0
|
55 |
+
- fonts-conda-ecosystem=1=hd3eb1b0_0
|
56 |
+
- freetype=2.12.1=h4a9f257_0
|
57 |
+
- fribidi=1.0.10=h7b6447c_0
|
58 |
+
- frozenlist=1.3.3=py39h5eee18b_0
|
59 |
+
- gdbm=1.18=hd4cb3f1_4
|
60 |
+
- gdk-pixbuf=2.42.10=h5eee18b_0
|
61 |
+
- gettext=0.21.0=hf68c758_0
|
62 |
+
- gflags=2.2.2=he1b5a44_1004
|
63 |
+
- giflib=5.2.1=h5eee18b_3
|
64 |
+
- git=2.34.1=pl5262hc120c5b_0
|
65 |
+
- glib=2.69.1=he621ea3_2
|
66 |
+
- glog=0.5.0=h48cff8f_0
|
67 |
+
- gmp=6.2.1=h295c915_3
|
68 |
+
- gmpy2=2.1.2=py39heeb90bb_0
|
69 |
+
- gnutls=3.6.15=he1e5248_0
|
70 |
+
- gobject-introspection=1.72.0=py39hbb6d50b_2
|
71 |
+
- graphite2=1.3.14=h295c915_1
|
72 |
+
- graphviz=2.50.0=h1b29801_1
|
73 |
+
- grpc-cpp=1.46.1=h33aed49_1
|
74 |
+
- gtk2=2.24.33=h73c1081_2
|
75 |
+
- gts=0.7.6=hb67d8dd_3
|
76 |
+
- harfbuzz=4.3.0=hf52aaf7_1
|
77 |
+
- huggingface_hub=0.16.4=pyhd8ed1ab_0
|
78 |
+
- icu=58.2=he6710b0_3
|
79 |
+
- idna=3.4=py39h06a4308_0
|
80 |
+
- importlib_metadata=6.8.0=hd8ed1ab_0
|
81 |
+
- intel-openmp=2023.1.0=hdb19cb5_46305
|
82 |
+
- jinja2=3.1.2=py39h06a4308_0
|
83 |
+
- jpeg=9e=h5eee18b_1
|
84 |
+
- krb5=1.19.4=h568e23c_0
|
85 |
+
- lame=3.100=h7b6447c_0
|
86 |
+
- lcms2=2.12=h3be6417_0
|
87 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
88 |
+
- lerc=3.0=h295c915_0
|
89 |
+
- libbrotlicommon=1.0.9=h166bdaf_7
|
90 |
+
- libbrotlidec=1.0.9=h166bdaf_7
|
91 |
+
- libbrotlienc=1.0.9=h166bdaf_7
|
92 |
+
- libcublas=11.10.3.66=0
|
93 |
+
- libcufft=10.7.2.124=h4fbf590_0
|
94 |
+
- libcufile=1.7.1.12=0
|
95 |
+
- libcurand=10.3.3.129=0
|
96 |
+
- libcurl=7.87.0=h91b91d3_0
|
97 |
+
- libcusolver=11.4.0.1=0
|
98 |
+
- libcusparse=11.7.4.91=0
|
99 |
+
- libdeflate=1.17=h5eee18b_0
|
100 |
+
- libedit=3.1.20221030=h5eee18b_0
|
101 |
+
- libev=4.33=h7f8727e_1
|
102 |
+
- libevent=2.1.10=h9b69904_4
|
103 |
+
- libffi=3.4.2=h6a678d5_6
|
104 |
+
- libgcc=7.2.0=h69d50b8_2
|
105 |
+
- libgcc-ng=11.2.0=h1234567_1
|
106 |
+
- libgd=2.3.3=h6a678d5_3
|
107 |
+
- libgomp=11.2.0=h1234567_1
|
108 |
+
- libiconv=1.16=h7f8727e_2
|
109 |
+
- libidn2=2.3.4=h5eee18b_0
|
110 |
+
- libnghttp2=1.46.0=hce63b2e_0
|
111 |
+
- libnpp=11.7.4.75=0
|
112 |
+
- libnvjpeg=11.8.0.2=0
|
113 |
+
- libpng=1.6.39=h5eee18b_0
|
114 |
+
- libprotobuf=3.20.3=he621ea3_0
|
115 |
+
- librsvg=2.54.4=h36cc946_2
|
116 |
+
- libssh2=1.10.0=h8f2d780_0
|
117 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
118 |
+
- libtasn1=4.19.0=h5eee18b_0
|
119 |
+
- libthrift=0.15.0=he6d91bd_0
|
120 |
+
- libtiff=4.5.0=h6a678d5_2
|
121 |
+
- libtool=2.4.6=h6a678d5_1009
|
122 |
+
- libunistring=0.9.10=h27cfd23_0
|
123 |
+
- libuuid=1.41.5=h5eee18b_0
|
124 |
+
- libwebp=1.2.4=h11a3e52_1
|
125 |
+
- libwebp-base=1.2.4=h5eee18b_1
|
126 |
+
- libxcb=1.15=h7f8727e_0
|
127 |
+
- libxml2=2.9.14=h74e7548_0
|
128 |
+
- lz4-c=1.9.4=h6a678d5_0
|
129 |
+
- mkl=2023.1.0=h6d00ec8_46342
|
130 |
+
- mkl-service=2.4.0=py39h5eee18b_1
|
131 |
+
- mkl_fft=1.3.6=py39h417a72b_1
|
132 |
+
- mkl_random=1.2.2=py39h417a72b_1
|
133 |
+
- mpc=1.1.0=h10f8cd9_1
|
134 |
+
- mpfr=4.0.2=hb69a4c5_1
|
135 |
+
- mpmath=1.3.0=py39h06a4308_0
|
136 |
+
- ncurses=6.4=h6a678d5_0
|
137 |
+
- nettle=3.7.3=hbbd107a_1
|
138 |
+
- networkx=3.1=py39h06a4308_0
|
139 |
+
- ninja-base=1.10.2=hd09550d_5
|
140 |
+
- numexpr=2.8.4=py39hc78ab66_1
|
141 |
+
- numpy-base=1.25.0=py39hb5e798b_0
|
142 |
+
- openh264=2.1.1=h4ff587b_0
|
143 |
+
- openjpeg=2.4.0=h3ad879b_0
|
144 |
+
- openssl=1.1.1w=h7f8727e_0
|
145 |
+
- orc=1.7.4=hb3bc3d3_1
|
146 |
+
- pango=1.50.7=h05da053_0
|
147 |
+
- pcre=8.45=h295c915_0
|
148 |
+
- pcre2=10.37=he7ceb23_1
|
149 |
+
- perl=5.34.0=h5eee18b_2
|
150 |
+
- pip=23.0.1=py39h06a4308_0
|
151 |
+
- pixman=0.40.0=h7f8727e_1
|
152 |
+
- poppler=0.81.0=h01f5e8b_2
|
153 |
+
- poppler-data=0.4.11=h06a4308_1
|
154 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
155 |
+
- pyopenssl=23.2.0=py39h06a4308_0
|
156 |
+
- pysocks=1.7.1=py39h06a4308_0
|
157 |
+
- python=3.9.16=h7a1cb2a_2
|
158 |
+
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
159 |
+
- python-devtools=0.11.0=pyhd8ed1ab_0
|
160 |
+
- python-graphviz=0.20.1=py39h06a4308_0
|
161 |
+
- python-xxhash=2.0.2=py39h5eee18b_1
|
162 |
+
- python_abi=3.9=2_cp39
|
163 |
+
- pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
|
164 |
+
- pytorch-cuda=11.7=h778d358_5
|
165 |
+
- pytorch-mutex=1.0=cuda
|
166 |
+
- pytz=2023.3=pyhd8ed1ab_0
|
167 |
+
- pyyaml=6.0=py39hb9d737c_4
|
168 |
+
- re2=2022.04.01=h27087fc_0
|
169 |
+
- readline=8.2=h5eee18b_0
|
170 |
+
- sacremoses=0.0.53=pyhd8ed1ab_0
|
171 |
+
- setuptools=66.0.0=py39h06a4308_0
|
172 |
+
- six=1.16.0=pyh6c4a22f_0
|
173 |
+
- snappy=1.1.9=h295c915_0
|
174 |
+
- sqlite=3.41.2=h5eee18b_0
|
175 |
+
- sympy=1.11.1=py39h06a4308_0
|
176 |
+
- tbb=2021.8.0=hdb19cb5_0
|
177 |
+
- tk=8.6.12=h1ccaba5_0
|
178 |
+
- tmux=3.2a=h385fc29_0
|
179 |
+
- tokenizers=0.13.2=py39he7d60b5_1
|
180 |
+
- torchtriton=2.0.0=py39
|
181 |
+
- transformers=4.28.1=pyhd8ed1ab_0
|
182 |
+
- typing_extensions=4.4.0=py39h06a4308_0
|
183 |
+
- utf8proc=2.6.1=h27cfd23_0
|
184 |
+
- wheel=0.38.4=py39h06a4308_0
|
185 |
+
- xz=5.2.10=h5eee18b_1
|
186 |
+
- yaml=0.2.5=h7f98852_2
|
187 |
+
- zlib=1.2.13=h5eee18b_0
|
188 |
+
- zstd=1.5.5=hc292b87_0
|
189 |
+
- pip:
|
190 |
+
- absl-py==2.0.0
|
191 |
+
- accelerate==0.16.0
|
192 |
+
- aiofiles==23.1.0
|
193 |
+
- aiohttp==3.8.4
|
194 |
+
- albumentations==1.3.1
|
195 |
+
- altair==4.2.2
|
196 |
+
- antlr4-python3-runtime==4.9.3
|
197 |
+
- anyio==3.6.2
|
198 |
+
- appdirs==1.4.4
|
199 |
+
- apptools==5.2.1
|
200 |
+
- argon2-cffi==21.3.0
|
201 |
+
- argon2-cffi-bindings==21.2.0
|
202 |
+
- argparse==1.4.0
|
203 |
+
- arrow==1.2.3
|
204 |
+
- attrs==22.2.0
|
205 |
+
- backcall==0.2.0
|
206 |
+
- batchgenerators==0.25
|
207 |
+
- beautifulsoup4==4.12.2
|
208 |
+
- bitsandbytes==0.37.0
|
209 |
+
- bitsandbytes-cuda117==0.26.0.post2
|
210 |
+
- bleach==6.0.0
|
211 |
+
- blis==0.7.9
|
212 |
+
- braceexpand==0.1.7
|
213 |
+
- brotli==1.1.0
|
214 |
+
- cachetools==5.3.1
|
215 |
+
- catalogue==2.0.8
|
216 |
+
- cchardet==2.1.7
|
217 |
+
- chardet==5.1.0
|
218 |
+
- charset-normalizer==3.1.0
|
219 |
+
- click==8.1.3
|
220 |
+
- cmake==3.26.3
|
221 |
+
- comm==0.1.3
|
222 |
+
- commonmark==0.9.1
|
223 |
+
- conda-pack==0.6.0
|
224 |
+
- confection==0.0.4
|
225 |
+
- configobj==5.0.8
|
226 |
+
- conllu==4.5.3
|
227 |
+
- contourpy==1.0.7
|
228 |
+
- cpufeature==0.2.1
|
229 |
+
- cycler==0.11.0
|
230 |
+
- cymem==2.0.7
|
231 |
+
- debugpy==1.6.7
|
232 |
+
- decorator==5.1.1
|
233 |
+
- decord==0.6.0
|
234 |
+
- defusedxml==0.7.1
|
235 |
+
- deprecated==1.2.14
|
236 |
+
- docker-pycreds==0.4.0
|
237 |
+
- efficientnet-pytorch==0.7.1
|
238 |
+
- einops==0.6.1
|
239 |
+
- einops-exts==0.0.4
|
240 |
+
- entrypoints==0.4
|
241 |
+
- envisage==7.0.3
|
242 |
+
- et-xmlfile==1.1.0
|
243 |
+
- exceptiongroup==1.2.0
|
244 |
+
- fairscale==0.4.13
|
245 |
+
- fastapi==0.95.1
|
246 |
+
- fastjsonschema==2.16.3
|
247 |
+
- ffmpy==0.3.0
|
248 |
+
- fonttools==4.38.0
|
249 |
+
- fqdn==1.5.1
|
250 |
+
- fschat==0.1.10
|
251 |
+
- fsspec==2023.4.0
|
252 |
+
- ftfy==6.1.1
|
253 |
+
- future==0.18.3
|
254 |
+
- gitdb==4.0.10
|
255 |
+
- gitpython==3.1.31
|
256 |
+
- google-auth==2.23.3
|
257 |
+
- google-auth-oauthlib==1.0.0
|
258 |
+
- gradio==3.23.0
|
259 |
+
- gradio-client==0.0.8
|
260 |
+
- grpcio==1.59.0
|
261 |
+
- h11==0.14.0
|
262 |
+
- h5py==3.9.0
|
263 |
+
- hjson==3.1.0
|
264 |
+
- httpcore==0.17.0
|
265 |
+
- httpx==0.24.0
|
266 |
+
- humanize==4.8.0
|
267 |
+
- hyperlink==21.0.0
|
268 |
+
- imageio==2.33.0
|
269 |
+
- importlib-metadata==6.6.0
|
270 |
+
- importlib-resources==5.12.0
|
271 |
+
- inflate64==1.0.0
|
272 |
+
- iniconfig==2.0.0
|
273 |
+
- iopath==0.1.10
|
274 |
+
- ipdb==0.13.13
|
275 |
+
- ipykernel==6.22.0
|
276 |
+
- ipython==8.12.0
|
277 |
+
- ipython-genutils==0.2.0
|
278 |
+
- isoduration==20.11.0
|
279 |
+
- jedi==0.18.2
|
280 |
+
- joblib==1.2.0
|
281 |
+
- jsonpointer==2.3
|
282 |
+
- jsonschema==4.17.3
|
283 |
+
- jupyter-client==8.2.0
|
284 |
+
- jupyter-core==5.3.0
|
285 |
+
- jupyter-events==0.6.3
|
286 |
+
- jupyter-server==2.5.0
|
287 |
+
- jupyter-server-terminals==0.4.4
|
288 |
+
- jupyterlab-pygments==0.2.2
|
289 |
+
- kiwisolver==1.4.4
|
290 |
+
- langcodes==3.3.0
|
291 |
+
- lazy-loader==0.3
|
292 |
+
- linecache2==1.0.0
|
293 |
+
- linkify-it-py==2.0.0
|
294 |
+
- lit==16.0.2
|
295 |
+
- llvmlite==0.39.1
|
296 |
+
- markdown==3.5
|
297 |
+
- markdown-it-py==2.2.0
|
298 |
+
- markdown2==2.4.8
|
299 |
+
- markupsafe==2.1.2
|
300 |
+
- matplotlib==3.7.0
|
301 |
+
- matplotlib-inline==0.1.6
|
302 |
+
- mdit-py-plugins==0.3.3
|
303 |
+
- mdurl==0.1.2
|
304 |
+
- mistune==2.0.5
|
305 |
+
- multidict==6.0.4
|
306 |
+
- multiprocess==0.70.15
|
307 |
+
- multivolumefile==0.2.3
|
308 |
+
- munch==4.0.0
|
309 |
+
- murmurhash==1.0.9
|
310 |
+
- mypy-extensions==1.0.0
|
311 |
+
- nbclassic==0.5.6
|
312 |
+
- nbclient==0.7.4
|
313 |
+
- nbconvert==7.3.1
|
314 |
+
- nbformat==5.8.0
|
315 |
+
- nest-asyncio==1.5.6
|
316 |
+
- nh3==0.2.11
|
317 |
+
- nibabel==5.1.0
|
318 |
+
- ninja==1.11.1
|
319 |
+
- nltk==3.8.1
|
320 |
+
- nmslib==2.1.1
|
321 |
+
- notebook==6.5.4
|
322 |
+
- notebook-shim==0.2.3
|
323 |
+
- numba==0.56.4
|
324 |
+
- numpy==1.23.5
|
325 |
+
- nvidia-cublas-cu11==11.10.3.66
|
326 |
+
- nvidia-cuda-cupti-cu11==11.7.101
|
327 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
328 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
329 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
330 |
+
- nvidia-cufft-cu11==10.9.0.58
|
331 |
+
- nvidia-curand-cu11==10.2.10.91
|
332 |
+
- nvidia-cusolver-cu11==11.4.0.1
|
333 |
+
- nvidia-cusparse-cu11==11.7.4.91
|
334 |
+
- nvidia-nccl-cu11==2.14.3
|
335 |
+
- nvidia-nvtx-cu11==11.7.91
|
336 |
+
- oauthlib==3.2.2
|
337 |
+
- omegaconf==2.3.0
|
338 |
+
- open-clip-torch==2.20.0
|
339 |
+
- openai==0.27.0
|
340 |
+
- opencv-python==4.7.0.72
|
341 |
+
- opencv-python-headless==4.8.0.74
|
342 |
+
- openpyxl==3.1.2
|
343 |
+
- orjson==3.8.11
|
344 |
+
- packaging==23.0
|
345 |
+
- pandas==2.0.1
|
346 |
+
- pandocfilters==1.5.0
|
347 |
+
- parso==0.8.3
|
348 |
+
- pathtools==0.1.2
|
349 |
+
- pathy==0.10.1
|
350 |
+
- peft==0.2.0
|
351 |
+
- pexpect==4.8.0
|
352 |
+
- pickleshare==0.7.5
|
353 |
+
- pillow==9.5.0
|
354 |
+
- platformdirs==3.5.0
|
355 |
+
- pluggy==1.3.0
|
356 |
+
- portalocker==2.7.0
|
357 |
+
- preshed==3.0.8
|
358 |
+
- pretrainedmodels==0.7.4
|
359 |
+
- prometheus-client==0.16.0
|
360 |
+
- prompt-toolkit==3.0.38
|
361 |
+
- protobuf==3.20.3
|
362 |
+
- psutil==5.9.4
|
363 |
+
- ptyprocess==0.7.0
|
364 |
+
- pure-eval==0.2.2
|
365 |
+
- py-cpuinfo==9.0.0
|
366 |
+
- py-rsync==0.0.1a0.dev0
|
367 |
+
- py7zr==0.20.8
|
368 |
+
- pyarrow==11.0.0
|
369 |
+
- pyasn1==0.5.0
|
370 |
+
- pyasn1-modules==0.3.0
|
371 |
+
- pybcj==1.0.2
|
372 |
+
- pybind11==2.6.1
|
373 |
+
- pycocoevalcap==1.2
|
374 |
+
- pycocotools==2.0.6
|
375 |
+
- pycryptodomex==3.19.1
|
376 |
+
- pydantic==1.10.7
|
377 |
+
- pydub==0.25.1
|
378 |
+
- pyface==8.0.0
|
379 |
+
- pygments==2.15.1
|
380 |
+
- pynndescent==0.5.10
|
381 |
+
- pyparsing==3.0.9
|
382 |
+
- pyppmd==1.1.0
|
383 |
+
- pyqt5==5.15.10
|
384 |
+
- pyqt5-qt5==5.15.2
|
385 |
+
- pyqt5-sip==12.13.0
|
386 |
+
- pyrsistent==0.19.3
|
387 |
+
- pysbd==0.3.4
|
388 |
+
- pytest==7.4.3
|
389 |
+
- python-json-logger==2.0.7
|
390 |
+
- python-multipart==0.0.6
|
391 |
+
- python-polylabel==0.6
|
392 |
+
- python-rsync==0.1.0
|
393 |
+
- pyzmq==25.0.2
|
394 |
+
- pyzstd==0.15.9
|
395 |
+
- qudida==0.0.4
|
396 |
+
- regex==2022.10.31
|
397 |
+
- requests==2.29.0
|
398 |
+
- requests-oauthlib==1.3.1
|
399 |
+
- rfc3339-validator==0.1.4
|
400 |
+
- rfc3986-validator==0.1.1
|
401 |
+
- rich==12.6.0
|
402 |
+
- rsa==4.9
|
403 |
+
- safetensors==0.3.1
|
404 |
+
- scikit-image==0.22.0
|
405 |
+
- scipy==1.10.1
|
406 |
+
- scispacy==0.5.2
|
407 |
+
- segmentation-models-pytorch==0.3.3
|
408 |
+
- semantic-version==2.10.0
|
409 |
+
- send2trash==1.8.2
|
410 |
+
- sentence-transformers==2.2.2
|
411 |
+
- sentencepiece==0.1.98
|
412 |
+
- sentry-sdk==1.21.0
|
413 |
+
- setproctitle==1.3.2
|
414 |
+
- shapely==2.0.2
|
415 |
+
- shellingham==1.5.4
|
416 |
+
- shortuuid==1.0.11
|
417 |
+
- simpleitk==2.2.1
|
418 |
+
- smart-open==6.3.0
|
419 |
+
- smmap==5.0.0
|
420 |
+
- sniffio==1.3.0
|
421 |
+
- soupsieve==2.4.1
|
422 |
+
- spacy==3.4.4
|
423 |
+
- spacy-legacy==3.0.12
|
424 |
+
- spacy-loggers==1.0.4
|
425 |
+
- srsly==2.4.6
|
426 |
+
- stack-data==0.6.2
|
427 |
+
- starlette==0.26.1
|
428 |
+
- surface-distance-based-measures==0.1
|
429 |
+
- svgwrite==1.4.3
|
430 |
+
- swig==4.1.1
|
431 |
+
- tenacity==8.2.2
|
432 |
+
- tensorboard==2.14.1
|
433 |
+
- tensorboard-data-server==0.7.1
|
434 |
+
- terminado==0.17.1
|
435 |
+
- texttable==1.7.0
|
436 |
+
- thinc==8.1.9
|
437 |
+
- threadpoolctl==3.1.0
|
438 |
+
- tifffile==2023.9.26
|
439 |
+
- timm==0.9.2
|
440 |
+
- tinycss2==1.2.1
|
441 |
+
- tomli==2.0.1
|
442 |
+
- toolz==0.12.0
|
443 |
+
- torchio==0.19.2
|
444 |
+
- torchvision==0.15.2
|
445 |
+
- tornado==6.3.1
|
446 |
+
- tqdm==4.64.1
|
447 |
+
- traceback2==1.4.0
|
448 |
+
- traitlets==5.9.0
|
449 |
+
- traits==6.4.3
|
450 |
+
- traitsui==8.0.0
|
451 |
+
- triton==2.0.0
|
452 |
+
- typer==0.7.0
|
453 |
+
- typing-extensions==4.5.0
|
454 |
+
- typing-inspect==0.8.0
|
455 |
+
- tzdata==2023.3
|
456 |
+
- uc-micro-py==1.0.1
|
457 |
+
- umap-learn==0.5.3
|
458 |
+
- unittest2==1.1.0
|
459 |
+
- unzip==1.0.0
|
460 |
+
- uri-template==1.2.0
|
461 |
+
- urllib3==1.26.15
|
462 |
+
- uvicorn==0.22.0
|
463 |
+
- vtk==9.3.0
|
464 |
+
- wandb==0.15.0
|
465 |
+
- wasabi==0.10.1
|
466 |
+
- wavedrom==2.0.3.post3
|
467 |
+
- wcwidth==0.2.6
|
468 |
+
- webcolors==1.13
|
469 |
+
- webdataset==0.2.48
|
470 |
+
- webencodings==0.5.1
|
471 |
+
- websocket-client==1.5.1
|
472 |
+
- websockets==11.0.2
|
473 |
+
- werkzeug==3.0.0
|
474 |
+
- wrapt==1.16.0
|
475 |
+
- xxhash==3.3.0
|
476 |
+
- yarl==1.8.2
|
477 |
+
- zipp==3.14.0
|
478 |
+
prefix: /home/zhouhy/anaconda3/envs/medversa
|
479 |
+
|
inference.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import *
|
2 |
+
|
3 |
+
# --- Launch the model ---
# NOTE(review): `registry` and `generate_predictions` are presumably supplied
# by the star import from `utils` at the top of this script - confirm.
device = 'cuda:0'
model_cls = registry.get_model_class('medomni')  # 'medomni' is the architecture name
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
model.eval()

# --- Define examples ---
# Each entry is: image paths, context, prompt, modality, task.
examples = [
    [
        ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
        "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
        "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png", "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png"],
        "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
        "How would you characterize the findings from <img0><img1>?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
        "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
        "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        "Age:70.\nGender:female.\nLocation:back.",
        "What is primary diagnosis?",
        "derm",
        "classification",
    ],
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        "Age:70.\nGender:female.\nLocation:back.",
        "Segment the lesion.",
        "derm",
        "segmentation",
    ],
    [
        ["./demo_ex/Case_01013_0000.nii.gz"],
        "",
        "Segment the liver.",
        "ct",
        "segmentation",
    ],
    [
        ["./demo_ex/Case_00840_0000.nii.gz"],
        "",
        "Segment the liver.",
        "ct",
        "segmentation",
    ],
]

# --- Define hyperparams ---
num_beams = 1
do_sample = True
min_length = 1
top_p = 0.9
repetition_penalty = 1
length_penalty = 1
temperature = 0.1
# Forwarded positionally, in this exact order, to every generate_predictions call.
generation_args = (
    num_beams,
    do_sample,
    min_length,
    top_p,
    repetition_penalty,
    length_penalty,
    temperature,
)

# --- Generate a report for a chest X-ray image ---
index = 0
demo_ex = examples[index]
images, context, prompt, modality, task = demo_ex
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(
    model, images, context, prompt, modality, task, *generation_args
)
print(output_text)

# --- Segment the lesion in the dermatology image ---
index = 6
demo_ex = examples[index]
images, context, prompt, modality, task = demo_ex
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(
    model, images, context, prompt, modality, task, *generation_args
)
print(output_text)
print(seg_mask_2d[0].shape)  # H, W

# --- Segment the liver in the abdomen CT scan ---
index = -2  # second-to-last example
demo_ex = examples[index]
images, context, prompt, modality, task = demo_ex
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(
    model, images, context, prompt, modality, task, *generation_args
)
print(output_text)
print(len(seg_mask_3d))  # Number of slices
print(seg_mask_3d[0].shape)  # H, W
|
107 |
+
|
medomni/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
|
13 |
+
from medomni.common.registry import registry
|
14 |
+
|
15 |
+
from medomni.datasets.builders import *
|
16 |
+
from medomni.models import *
|
17 |
+
from medomni.processors import *
|
18 |
+
from medomni.tasks import *
|
19 |
+
|
20 |
+
|
21 |
+
# Absolute path of the directory containing this __init__.py.
root_dir = os.path.dirname(os.path.abspath(__file__))
# Defaults loaded from configs/default.yaml under root_dir.
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

# Register paths with the shared registry under fixed names.
registry.register_path("library_root", root_dir)
# The parent of root_dir, kept as an un-normalised ".." path.
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
# Cache location: env.cache_root from the default config, joined onto repo_root.
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

# Constants registered under fixed names.
registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])
|
medomni/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.69 kB). View file
|
|
medomni/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.01 kB). View file
|
|
medomni/common/__init__.py
ADDED
File without changes
|
medomni/common/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (147 Bytes). View file
|
|
medomni/common/__pycache__/config.cpython-39.pyc
ADDED
Binary file (12.1 kB). View file
|
|
medomni/common/__pycache__/dist_utils.cpython-39.pyc
ADDED
Binary file (3.77 kB). View file
|
|
medomni/common/__pycache__/logger.cpython-39.pyc
ADDED
Binary file (6.46 kB). View file
|
|
medomni/common/__pycache__/optims.cpython-39.pyc
ADDED
Binary file (2.99 kB). View file
|
|
medomni/common/__pycache__/registry.cpython-39.pyc
ADDED
Binary file (8.99 kB). View file
|
|
medomni/common/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (12.6 kB). View file
|
|
medomni/common/config.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import json
|
10 |
+
from typing import Dict
|
11 |
+
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from medomni.common.registry import registry
|
14 |
+
import ipdb
|
15 |
+
|
16 |
+
class Config:
    """Build the merged run/model/dataset configuration for an experiment.

    The YAML file at ``args.cfg_path`` supplies the base settings, every model
    and dataset contributes its registered default config, and the
    command-line ``args.options`` are merged last so they take precedence.
    """

    def __init__(self, args):
        """Load, merge and register the configuration.

        Args:
            args: object exposing ``cfg_path`` (path of the YAML file) and
                ``options`` (list of command-line overrides, or ``None``).
        """
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # [TODO] validate the runner/model/dataset configuration; the runner
        # validator (``_validate_runner_config``) is currently not invoked.

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        """Turn the raw command-line options into an OmegaConf config."""
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def _user_model_type(overrides):
        """Return the model type requested in the user overrides, or None.

        Overrides built from a dot-list are nested (``{"model": {"model_type":
        ...}}``); the flat ``"model.model_type"`` key is still honoured for
        callers that pass it explicitly as a keyword.
        """
        flat = overrides.get("model.model_type", None)
        if flat:
            return flat

        model_overrides = overrides.get("model", None)
        if model_overrides is not None and hasattr(model_overrides, "get"):
            return model_overrides.get("model_type", None)
        return None

    @staticmethod
    def build_model_config(config, **kwargs):
        """Merge the model's default config with the ``model`` section of ``config``.

        Args:
            config: the loaded experiment config; must contain a ``model`` node.
            **kwargs: user overrides; a model type found here takes precedence
                over the one in ``config`` when selecting the default config.

        Returns:
            The merged model configuration (customized config > default config).
        """
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        # Prefer the model type selected by the user, else use the YAML value.
        model_type = Config._user_model_type(kwargs)
        if not model_type:
            model_type = model.get("model_type", None)

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        """Return the ``run`` section of ``config`` wrapped under the ``run`` key."""
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        """Merge every dataset's default config with its section in ``config``.

        Raises:
            KeyError: if ``config`` has no ``datasets`` root key.
        """
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        """Normalize command-line options to ``key=value`` strings.

        Options are accepted either already in ``key=value`` form (decided by
        looking at the first entry) or as alternating ``key value`` entries.
        ``None`` is treated as an empty list.
        """
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        if "=" in opts[0]:
            return opts

        if len(opts) % 2:
            # zip() below drops the unpaired entry; make that visible.
            logging.warning(
                "Option '%s' has no value and will be ignored.", opts[-1]
            )

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        """Return the merged configuration."""
        return self.config

    @property
    def run_cfg(self):
        """The ``run`` section of the merged configuration."""
        return self.config.run

    @property
    def datasets_cfg(self):
        """The ``datasets`` section of the merged configuration."""
        return self.config.datasets

    @property
    def model_cfg(self):
        """The ``model`` section of the merged configuration."""
        return self.config.model

    def pretty_print(self):
        """Log the run, dataset and model sections as indented JSON."""
        logging.info("\n===== Running Parameters =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n====== Dataset Attributes ======")
        datasets = self.config.datasets

        # Every name comes from the datasets node itself, so each one is present.
        for dataset in datasets:
            logging.info(f"\n======== {dataset} =======")
            dataset_config = datasets[dataset]
            logging.info(self._convert_node_to_json(dataset_config))

        logging.info("\n====== Model Attributes ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        """Serialize a config node (interpolations resolved) to sorted, indented JSON."""
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        """Return the merged configuration as plain Python containers."""
        return OmegaConf.to_container(self.config)
|
167 |
+
|
168 |
+
|
169 |
+
def node_to_dict(node):
    """Convert an OmegaConf node into plain Python containers."""
    container = OmegaConf.to_container(node)
    return container
|
171 |
+
|
172 |
+
|
173 |
+
class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined, raise error if not.
        2. when type mismatches are found, the validator will raise an error.
        3. a central place to store and display helpful messages for supported configurations.

    """

    class _Argument:
        """Description of one supported option and its last validated value."""

        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        # Maps option name -> _Argument.
        self.arguments = dict()

        # Filled by validate(); maps option name -> converted value.
        self.parsed_args = None

    def __getitem__(self, key):
        """Return the validated value of ``key``; requires a prior validate()."""
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """Check a dict-like config against the registered arguments.

        Every key must be a registered argument. When an argument declares a
        type, the value is converted with it; when it declares choices, the
        raw value must be one of them. Converted values are stored on the
        arguments and become available through ``validator[key]``.

        Args:
            config: mapping of option name to value; ``None`` is accepted and
                returned untouched.

        Returns:
            The ``config`` object that was passed in.

        Raises:
            AssertionError: for an unknown option or a value outside ``choices``.
            ValueError: when a value cannot be converted to the declared type.
        """
        if config is None:
            return config

        parsed = {}
        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

            argument = self.arguments[k]

            if argument.type is not None:
                # typing generics (e.g. Dict[str, float]) cannot be called;
                # convert with the runtime class they stand for instead.
                convert = getattr(argument.type, "__origin__", None) or argument.type
                try:
                    argument.val = convert(v)
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        f"{k} is not a valid {argument.type}."
                    ) from err
            else:
                argument.val = v

            if argument.choices is not None:
                assert (
                    v in argument.choices
                ), f"""{k} must be one of {argument.choices}."""

            parsed[k] = argument.val

        # Publish only after the whole config passed.
        self.parsed_args = parsed

        return config

    def format_arguments(self):
        """Return the sorted argument names rendered as a list string."""
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())
|
259 |
+
|
260 |
+
|
261 |
+
def create_runner_config_validator():
    """Build the validator describing every supported runner option.

    The choices for ``lr_sched`` and ``task`` are read from the registry at
    call time.

    Returns:
        A ``ConfigValidator`` with all runner arguments registered.
    """
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of training dataset. This is used in iteration-based runner.
        Do not support for epoch-based runner because how to define an epoch becomes tricky.
        Default: None""",
    )
    validator.add_argument(
        "max_iters", type=float, help="Maximum number of iterations to run."
    )
    validator.add_argument(
        "max_epoch", type=int, help="Maximum number of epochs to run."
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr", type=float, help="Minimum learning rate (after decay)."
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr", type=float, help="Starting learning rate for warmup."
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument("weight_decay", type=float, help="Weight decay rate.")
    # add arguments for training batch size
    validator.add_argument("batch_size_train", type=int, help="Training batch size.")
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument("num_workers", help="Number of workers for data loading.")
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument("seed", type=int, help="Random seed.")
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits", type=list, help="Splits to use for training."
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len", type=int, help="Maximal length of text output."
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len", type=int, help="Minimal length of text output."
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams", type=int, help="Number of beams used for beam search."
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        # "generate" was previously misspelled here, which rejected the real value.
        choices=["generate", "rank"],
        help="""Inference method to use for question answering. If rank, requires a answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator
|
medomni/common/dist_utils.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import functools
|
10 |
+
import os
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
import timm.models.hub as timm_hub
|
15 |
+
|
16 |
+
|
17 |
+
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process

    The builtin ``print`` is replaced by a wrapper that only prints when
    ``is_master`` is true or when the call passes ``force=True``. Calling this
    function again replaces the previous wrapper instead of stacking on top of
    it, so the latest ``is_master`` value always wins.
    """
    import builtins as __builtin__

    # If print is already one of our wrappers, recover the print it wrapped.
    builtin_print = getattr(
        __builtin__.print, "_wrapped_builtin_print", __builtin__.print
    )

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    # Remember the wrapped print so a later call can unwrap it.
    print._wrapped_builtin_print = builtin_print

    __builtin__.print = print
|
31 |
+
|
32 |
+
|
33 |
+
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
39 |
+
|
40 |
+
|
41 |
+
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
|
45 |
+
|
46 |
+
|
47 |
+
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
|
51 |
+
|
52 |
+
|
53 |
+
def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
|
55 |
+
|
56 |
+
|
57 |
+
def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables.

    Rank information is taken from ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` when
    the first two are set, otherwise from ``SLURM_PROCID``. When neither is
    present, ``args.distributed`` is set to False and nothing else happens.

    Args:
        args: mutable settings object; ``rank``, ``gpu``, ``distributed`` and
            ``dist_backend`` are written to it, and ``dist_url`` and
            ``world_size`` are read from it.
    """
    env = os.environ

    if "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"

    banner = "| distributed init (rank {}, world {}): {}".format(
        args.rank, args.world_size, args.dist_url
    )
    print(banner, flush=True)

    # allow auto-downloading and de-compressing
    generous_timeout = datetime.timedelta(days=365)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=generous_timeout,
    )
    torch.distributed.barrier()

    # Only rank 0 keeps printing from here on.
    setup_for_distributed(args.rank == 0)
|
91 |
+
|
92 |
+
|
93 |
+
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or no
    process group has been initialized (non-distributed training).
    """
    # Check availability first, and use only the public is_initialized() API
    # rather than comparing version strings.
    initialized = dist.is_available() and dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size
|
105 |
+
|
106 |
+
|
107 |
+
def main_process(func):
    """Decorator: run ``func`` only on rank 0; other ranks get ``None``."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
115 |
+
|
116 |
+
|
117 |
+
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.

    Returns the path of the cached file inside timm's cache directory.
    """
    # Only the main process fetches the file.
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    # The other processes wait here until the download is complete.
    if is_dist_avail_and_initialized():
        dist.barrier()

    # a hack to sync the file path across processes: rebuild the path from the
    # URL instead of using the downloader's return value.
    url_path = torch.hub.urlparse(url).path
    return os.path.join(timm_hub.get_cache_dir(), os.path.basename(url_path))
|
medomni/common/gradcam.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from matplotlib import pyplot as plt
|
3 |
+
from scipy.ndimage import filters
|
4 |
+
from skimage import transform as skimage_transform
|
5 |
+
|
6 |
+
|
7 |
+
def getAttMap(img, attMap, blur=True, overlap=True):
    """Resize an attention map to the image and optionally overlay it.

    Args:
        img: image array; its first two dimensions give the target size.
            # NOTE(review): the overlay math suggests an (H, W, 3) float image
            # with values in [0, 1] -- confirm against callers.
        attMap: 2-D attention map; it is no longer modified in place.
        blur: if true, smooth the resized map with a Gaussian filter whose
            sigma is 2% of the larger image side, then re-normalize.
        overlap: if true, return the map blended over ``img`` with the "jet"
            colormap; otherwise return the normalized map itself.

    Returns:
        The blended image when ``overlap`` is true, else the processed map.
    """
    # Build new arrays instead of using in-place operators so the caller's
    # attention map is left untouched.
    attMap = attMap - attMap.min()
    peak = attMap.max()
    if peak > 0:
        attMap = attMap / peak
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        # Same zero guard as above: a constant map must not divide by zero.
        peak = attMap.max()
        if peak > 0:
            attMap /= peak
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    # Drop the alpha channel of the colormap output.
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        weight = (attMap**0.7).reshape(attMap.shape + (1,))
        attMap = 1 * (1 - weight) * img + weight * attMapV
    return attMap
|
medomni/common/logger.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import logging
|
10 |
+
import time
|
11 |
+
from collections import defaultdict, deque
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
|
16 |
+
from medomni.common import dist_utils
|
17 |
+
|
18 |
+
|
19 |
+
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        """
        Args:
            window_size: number of most recent values kept for the windowed
                statistics (median, avg, max, value).
            fmt: format string for ``str()``; may reference ``median``,
                ``avg``, ``global_avg``, ``max`` and ``value``.
        """
        if fmt is None:
            # Only fields supplied by __str__ may be referenced here.
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``, counting it ``n`` times in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        # Sum count and total across processes.
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the values currently in the window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean of every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
|
80 |
+
|
81 |
+
|
82 |
+
class MetricLogger(object):
    """Collect named meters and print periodic progress lines while iterating."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue on first use.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword.

        Tensors are converted with ``.item()``. Numbers update the meter of
        that name; any other value replaces the meter entry as-is.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if isinstance(v, (float, int)):
                self.meters[k].update(v)
            else:
                self.meters[k] = v

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup failed. Read "meters"
        # through __dict__ so that a missing "meters" attribute cannot
        # re-enter this method endlessly.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        """Return ``name: global average`` for every non-string meter."""
        loss_str = []
        for name, meter in self.meters.items():
            if not isinstance(meter, str):
                loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every non-string meter across processes."""
        for meter in self.meters.values():
            if not isinstance(meter, str):
                meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing entry."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line is printed every ``print_freq`` items and for the last item,
        and a total-time summary is printed after the loop.

        Args:
            iterable: sized iterable (``len()`` is used for the ETA).
            print_freq: number of items between two progress lines.
            header: optional prefix for every printed line.
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        num_items = len(iterable)
        cuda_available = torch.cuda.is_available()
        space_fmt = ":" + str(len(str(num_items))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if cuda_available:
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Time for the whole iteration, including the consumer's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = dict(
                    eta=eta_string,
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if cuda_available:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_items, **fields))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                # max() keeps the summary from dividing by zero on an empty iterable.
                header, total_time_str, total_time / max(num_items, 1)
            )
        )
|
187 |
+
|
188 |
+
|
189 |
+
class AttrDict(dict):
    """A dict whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so attribute
        # access and item access share the same storage.
        self.__dict__ = self
|
193 |
+
|
194 |
+
|
195 |
+
def setup_logger():
    """Configure root logging: INFO on the main process, WARN elsewhere."""
    if dist_utils.is_main_process():
        level = logging.INFO
    else:
        level = logging.WARN

    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
|
medomni/common/optims.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
from medomni.common.registry import registry
|
11 |
+
|
12 |
+
|
13 |
+
@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    """Linear warmup during epoch 0, exponential per-epoch decay afterwards.

    ``step`` writes the computed learning rate into every entry of
    ``optimizer.param_groups``.
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        """Store the schedule parameters.

        A negative ``warmup_start_lr`` means "start the warmup at
        ``init_lr``". Extra keyword arguments are accepted and ignored.
        """
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr
        self.init_lr = init_lr

        self.decay_rate = decay_rate

        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        """Update the optimizer's learning rate for the given epoch/step."""
        if cur_epoch != 0:
            # After the first epoch only the epoch index matters.
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )
            return

        warmup_lr_schedule(
            step=cur_step,
            optimizer=self.optimizer,
            max_step=self.warmup_steps,
            init_lr=self.warmup_start_lr,
            max_lr=self.init_lr,
        )
|
54 |
+
|
55 |
+
|
56 |
+
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    """Linear warmup followed by cosine decay, both measured in global steps.

    The global step is ``cur_epoch * iters_per_epoch + cur_step``. While it is
    below ``warmup_steps`` the learning rate ramps linearly from
    ``warmup_start_lr`` to ``init_lr``; afterwards it follows a cosine curve
    towards ``min_lr`` over ``max_epoch * iters_per_epoch`` steps.
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        """Store the schedule parameters.

        Args:
            optimizer: Object exposing ``param_groups`` (a list of dicts with
                an ``"lr"`` entry) that the schedule helpers write to.
            max_epoch: Number of epochs the cosine curve spans.
            iters_per_epoch: Iterations per epoch, used to build the global
                step.
            min_lr: Learning rate at the end of the cosine curve.
            init_lr: Learning rate at the end of warmup / start of decay.
            warmup_steps: Number of global steps spent warming up.
            warmup_start_lr: First warmup learning rate; a negative value
                means "use ``init_lr``".
            **kwargs: Accepted and ignored.
        """
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        """Set the optimizer's learning rate for the given epoch and step."""
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                # Ramp on the global step: with the per-epoch step, a warmup
                # longer than one epoch would restart at each epoch boundary.
                # For warmups shorter than an epoch the two are identical.
                step=total_cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
|
97 |
+
|
98 |
+
|
99 |
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Set every param group's lr to the cosine-decayed value.

    The value is ``init_lr`` at ``epoch == 0`` and ``min_lr`` at
    ``epoch == max_epoch``.
    """
    half_range = (init_lr - min_lr) * 0.5
    cosine_term = 1.0 + math.cos(math.pi * epoch / max_epoch)
    new_lr = half_range * cosine_term + min_lr
    for group in optimizer.param_groups:
        group["lr"] = new_lr
|
106 |
+
|
107 |
+
|
108 |
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Set every param group's lr to the linearly warmed-up value.

    The ramp goes from ``init_lr`` at step 0 towards ``max_lr`` and is capped
    at ``max_lr``; a ``max_step`` below 1 is treated as 1.
    """
    ramp_length = max(max_step, 1)
    ramped = init_lr + (max_lr - init_lr) * step / ramp_length
    new_lr = min(max_lr, ramped)
    for group in optimizer.param_groups:
        group["lr"] = new_lr
|
113 |
+
|
114 |
+
|
115 |
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Set every param group's lr to ``init_lr * decay_rate**epoch``.

    The result is never allowed to fall below ``min_lr``.
    """
    decayed = init_lr * (decay_rate**epoch)
    new_lr = max(min_lr, decayed)
    for group in optimizer.param_groups:
        group["lr"] = new_lr
|
medomni/common/registry.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
class Registry:
    r"""Shared registry of named classes, paths and arbitrary state.

    Everything is stored in the class-level ``mapping`` dict, so the class
    and all of its instances see the same registrations.
    """

    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def _store_unique(cls, mapping_key, name, obj):
        r"""Store ``obj`` under ``name`` in ``mapping[mapping_key]``.

        Returns ``obj`` so the ``register_*`` decorators can return it.

        Raises:
            KeyError: If ``name`` is already present in that mapping.
        """
        table = cls.mapping[mapping_key]
        if name in table:
            raise KeyError(
                "Name '{}' already registered for {}.".format(name, table[name])
            )
        table[name] = obj
        return obj

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'

        Args:
            name: Key with which the builder will be registered.

        Usage::

            from medomni.common.registry import registry
            from medomni.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            @registry.register_builder("my_dataset")
            class MyDatasetBuilder(BaseDatasetBuilder):
                ...
        """

        def wrap(builder_cls):
            # Imported at decoration time rather than module import time
            # (presumably to avoid a circular import; confirm before moving).
            from medomni.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            return cls._store_unique("builder_name_mapping", name, builder_cls)

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage::

            from medomni.common.registry import registry

            @registry.register_task("my_task")
            class MyTask(BaseTask):
                ...
        """

        def wrap(task_cls):
            from medomni.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            return cls._store_unique("task_name_mapping", name, task_cls)

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the model will be registered.

        Usage::

            from medomni.common.registry import registry

            @registry.register_model("my_model")
            class MyModel(BaseModel):
                ...
        """

        def wrap(model_cls):
            from medomni.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            return cls._store_unique("model_name_mapping", name, model_cls)

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        Args:
            name: Key with which the processor will be registered.

        Usage::

            from medomni.common.registry import registry

            @registry.register_processor("my_processor")
            class MyProcessor(BaseProcessor):
                ...
        """

        def wrap(processor_cls):
            from medomni.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            return cls._store_unique("processor_name_mapping", name, processor_cls)

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Register a learning-rate scheduler to registry with key 'name'

        No base-class check is performed.

        Args:
            name: Key with which the scheduler will be registered.

        Usage::

            from medomni.common.registry import registry

            @registry.register_lr_scheduler("my_scheduler")
            class MyScheduler:
                ...
        """

        def wrap(lr_sched_cls):
            return cls._store_unique("lr_scheduler_name_mapping", name, lr_sched_cls)

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Register a runner to registry with key 'name'

        No base-class check is performed.

        Args:
            name: Key with which the runner will be registered.

        Usage::

            from medomni.common.registry import registry

            @registry.register_runner("my_runner")
            class MyRunner:
                ...
        """

        def wrap(runner_cls):
            return cls._store_unique("runner_name_mapping", name, runner_cls)

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.
            path: The path string to store.

        Raises:
            KeyError: If ``name`` is already registered as a path.

        Usage::

            from medomni.common.registry import registry

            registry.register_path("cache_root", "/tmp/cache")
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        A dotted name such as ``"a.b.c"`` is stored as nested dicts
        (``state["a"]["b"]["c"]``); missing intermediate dicts are created.
        An existing value under the same name is overwritten.

        Args:
            name: Key with which the item will be registered.
            obj: The item to store.

        Usage::

            from medomni.common.registry import registry

            registry.register("config", {})
        """
        *parents, leaf = name.split(".")
        current = cls.mapping["state"]

        for part in parents:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[leaf] = obj

    @classmethod
    def get_builder_class(cls, name):
        """Return the dataset builder registered as ``name``, or None."""
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        """Return the model class registered as ``name``, or None."""
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        """Return the task class registered as ``name``, or None."""
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        """Return the processor class registered as ``name``, or None."""
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        """Return the lr scheduler class registered as ``name``, or None."""
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        """Return the runner class registered as ``name``, or None."""
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        """Return the sorted names of all registered runners."""
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        """Return the sorted names of all registered models."""
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        """Return the sorted names of all registered tasks."""
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        """Return the sorted names of all registered processors."""
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        """Return the sorted names of all registered lr schedulers."""
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        """Return the sorted names of all registered dataset builders."""
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        """Return the path registered as ``name``, or None."""
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Dotted names are resolved through nested mappings. If any part is
        missing, or an intermediate value has no ``get`` method, ``default``
        is returned.

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Default: False
        """
        original_name = name
        value = cls.mapping["state"]
        for subname in name.split("."):
            lookup = getattr(value, "get", None)
            if lookup is None:
                # The path continues past a non-mapping value (e.g. asking
                # for "a.b" when "a" holds an int): treat it as missing
                # instead of raising AttributeError.
                value = default
                break
            value = lookup(subname, default)
            if value is default:
                break

        # The warning is only emitted when a "writer" object has been
        # registered in the state.
        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        A name that is an exact top-level key is popped directly. Otherwise
        a dotted name is resolved through the nested dicts that
        ``register`` creates and only the leaf entry is removed.

        Args:
            name: Key which needs to be removed.

        Returns:
            The removed value, or None if nothing was registered under
            ``name``.

        Usage::

            from medomni.common.registry import registry

            config = registry.unregister("config")
        """
        state = cls.mapping["state"]
        if name in state or not isinstance(name, str) or "." not in name:
            return state.pop(name, None)

        *parents, leaf = name.split(".")
        current = state
        for part in parents:
            current = current.get(part)
            if not isinstance(current, dict):
                return None
        return current.pop(leaf, None)
|
326 |
+
|
327 |
+
registry = Registry()
|
medomni/common/utils.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import io
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
import shutil
|
15 |
+
import urllib
|
16 |
+
import urllib.error
|
17 |
+
import urllib.request
|
18 |
+
from typing import Optional
|
19 |
+
from urllib.parse import urlparse
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import pandas as pd
|
23 |
+
import yaml
|
24 |
+
from iopath.common.download import download
|
25 |
+
from iopath.common.file_io import file_lock, g_pathmgr
|
26 |
+
from medomni.common.registry import registry
|
27 |
+
from torch.utils.model_zoo import tqdm
|
28 |
+
from torchvision.datasets.utils import (
|
29 |
+
check_integrity,
|
30 |
+
download_file_from_google_drive,
|
31 |
+
extract_archive,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def now():
    """Return the current local time formatted as ``YYYYmmddHHM``.

    The final minute digit is dropped, so the stamp only changes every ten
    minutes.
    """
    from datetime import datetime

    stamp = datetime.now().strftime("%Y%m%d%H%M")
    return stamp[:-1]
|
39 |
+
|
40 |
+
|
41 |
+
def is_url(url_or_filename):
    """Return True if the argument parses with an ``http`` or ``https`` scheme.

    NOTE: a second ``is_url`` defined further down in this module replaces
    this definition at import time.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
44 |
+
|
45 |
+
|
46 |
+
def get_cache_path(rel_path):
    """Join ``rel_path`` onto the registered ``cache_root`` path and expand ``~``."""
    cache_root = registry.get_path("cache_root")
    joined = os.path.join(cache_root, rel_path)
    return os.path.expanduser(joined)
|
48 |
+
|
49 |
+
|
50 |
+
def get_abs_path(rel_path):
    """Join ``rel_path`` onto the registered ``library_root`` path."""
    library_root = registry.get_path("library_root")
    return os.path.join(library_root, rel_path)
|
52 |
+
|
53 |
+
|
54 |
+
def load_json(filename):
    """Read ``filename`` and return the parsed JSON document.

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, so non-ASCII content loads the same way everywhere.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)
|
57 |
+
|
58 |
+
|
59 |
+
# The following are adapted from torchvision and vissl
|
60 |
+
# torchvision: https://github.com/pytorch/vision
|
61 |
+
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
|
62 |
+
|
63 |
+
|
64 |
+
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Returns True when the directory already existed or was created, and
    False when the attempt raised; the failure is printed, not re-raised.

    NOTE: a second ``makedir`` defined further down in this module replaces
    this definition at import time.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        # Exception, not BaseException: KeyboardInterrupt and SystemExit
        # must propagate instead of being reported as a directory error.
        print(f"Error creating directory: {dir_path}")
    return is_success
|
76 |
+
|
77 |
+
|
78 |
+
def get_redirected_url(url: str):
    """
    Issue a GET for ``url`` with redirects enabled and return the final URL
    if any redirect happened, otherwise the original ``url``.
    """
    import requests

    with requests.Session() as session, session.get(
        url, stream=True, allow_redirects=True
    ) as response:
        # A non-empty history means at least one redirect was followed.
        return response.url if response.history else url
|
91 |
+
|
92 |
+
|
93 |
+
def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive.

    Only the URL path is inspected, so share links carrying a query string
    (e.g. ".../view?usp=sharing"), a fragment or a trailing slash are
    accepted as well.

    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp

    Raises:
        AssertionError: If the last path component is not "view".
    """
    splits = urlparse(view_url).path.rstrip("/").split("/")
    assert splits[-1] == "view"
    # The file id is the path component right before "view".
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"
|
106 |
+
|
107 |
+
|
108 |
+
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive into ``output_path/output_file_name``.

    A first request collects any ``download_warning*`` cookie and appends
    its value to the URL as ``&confirm=...``; a second request streams the
    content to disk while updating a progress bar.
    """
    import requests

    with requests.Session() as session:

        # First request: pick up the confirmation token, if any.
        with session.get(url, stream=True, allow_redirects=True) as probe:
            for cookie_name, cookie_value in probe.cookies.items():
                if cookie_name.startswith("download_warning"):
                    url = url + "&confirm=" + cookie_value

        # Second request: stream the file content to disk.
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            target = os.path.join(output_path, output_file_name)
            expected_bytes = int(response.headers.get("Content-length", 0))
            with open(target, "wb") as out_file:
                from tqdm import tqdm

                with tqdm(total=expected_bytes) as progress_bar:
                    blocks = response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE)
                    for block in blocks:
                        out_file.write(block)
                        progress_bar.update(len(block))
|
139 |
+
|
140 |
+
|
141 |
+
def _get_google_drive_file_id(url: str) -> Optional[str]:
|
142 |
+
parts = urlparse(url)
|
143 |
+
|
144 |
+
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
|
145 |
+
return None
|
146 |
+
|
147 |
+
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
|
148 |
+
if match is None:
|
149 |
+
return None
|
150 |
+
|
151 |
+
return match.group("id")
|
152 |
+
|
153 |
+
|
154 |
+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    """Stream ``url`` into ``filename`` in ``chunk_size``-byte reads.

    The request is sent with the User-Agent "vissl" and a progress bar is
    advanced by the number of bytes actually received.
    """
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # read() returns bytes, so end-of-stream is b"" (not "").
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    fh.write(chunk)
                    # The last chunk is usually shorter than chunk_size.
                    pbar.update(len(chunk))
|
165 |
+
|
166 |
+
|
167 |
+
def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.

    Nothing is downloaded if the target file already passes
    ``check_integrity``. Google Drive URLs are delegated to
    ``download_file_from_google_drive``.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
                                  If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: If the downloaded file fails the integrity check.
    """
    root = os.path.expanduser(root)
    filename = filename or os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # Skip the download when a verified copy is already on disk.
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # Resolve redirects before deciding how to fetch the file.
    url = get_redirected_url(url)

    drive_id = _get_google_drive_file_id(url)
    if drive_id is not None:
        return download_file_from_google_drive(drive_id, root, filename, md5)

    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError):  # type: ignore[attr-defined]
        if url[:5] != "https":
            raise
        # NOTE(review): retrying over plain http drops transport security;
        # the md5 check below is the only remaining protection.
        url = url.replace("https:", "http:")
        print(
            "Failed download. Trying https -> http instead. Downloading "
            + url
            + " to "
            + fpath
        )
        _urlretrieve(url, fpath)

    # Verify whatever ended up on disk.
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
|
219 |
+
|
220 |
+
|
221 |
+
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download an archive with ``download_url`` and unpack it.

    ``extract_root`` defaults to the (user-expanded) ``download_root`` and
    ``filename`` defaults to the basename of ``url``. ``remove_finished`` is
    forwarded to ``extract_archive``.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive_path = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive_path, extract_root))
    extract_archive(archive_path, extract_root, remove_finished)
|
240 |
+
|
241 |
+
|
242 |
+
def cache_url(url: str, cache_dir: str) -> str:
    """
    Download ``url`` into ``cache_dir`` unless a cached copy already exists,
    and return the local path.

    The directory layout under ``cache_dir`` mirrors the URL path. The
    existence check and the download happen while holding a file lock on
    the target path.
    """
    url_path = urlparse(url).path
    dirname = os.path.join(cache_dir, os.path.dirname(url_path.lstrip("/")))
    makedir(dirname)

    filename = url.rsplit("/", 1)[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        already_cached = os.path.isfile(cached)
        if not already_cached:
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached
|
258 |
+
|
259 |
+
|
260 |
+
# TODO (prigoyal): convert this into RAII-style API
|
261 |
+
def create_file_symlink(file1, file2):
    """
    Call ``g_pathmgr.symlink(file1, file2)``, first removing ``file2`` if it
    already exists.

    Any exception is logged at INFO level and swallowed, so callers cannot
    tell from the return value whether the link was created.
    """
    try:
        stale_target = g_pathmgr.exists(file2)
        if stale_target:
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")
|
273 |
+
|
274 |
+
|
275 |
+
def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Save ``data`` to ``filename`` in the format selected by its extension.
    Supported:
        .pkl, .pickle, .npy, .json, .yaml
    For .json, one JSON document plus a newline is written; by default it
    is appended to the file, pass ``append_to_json=False`` to overwrite.

    Raises:
        Exception: If the extension is not one of the supported ones.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext in (".pkl", ".pickle"):
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        # Append and overwrite differ only in the open mode.
        json_mode = "a" if append_to_json else "w"
        with g_pathmgr.open(filename, json_mode) as fopen:
            fopen.write(json.dumps(data, sort_keys=True) + "\n")
            fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            fopen.write(yaml.dump(data))
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")
|
311 |
+
|
312 |
+
|
313 |
+
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Load data from ``filename`` in the format selected by its extension.
    Supported:
        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
    For .npy files with ``mmap_mode`` set, loading is first attempted
    through g_pathmgr; a ValueError triggers a retry on the plain path, and
    any other exception from the first attempt triggers a load without
    mmap.

    Raises:
        Exception: If the extension is not one of the supported ones.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in (".pkl", ".pickle"):
        # NOTE(review): unpickling runs arbitrary code; only use on trusted files.
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy" and not mmap_mode:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".npy":
        npy_options = {"allow_pickle": allow_pickle, "encoding": "latin1"}
        try:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, mmap_mode=mmap_mode, **npy_options)
        except ValueError as e:
            logging.info(
                f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
            )
            data = np.load(filename, mmap_mode=mmap_mode, **npy_options)
            logging.info("Successfully loaded without g_pathmgr")
        except Exception:
            # Only reached for non-ValueError failures of the first attempt.
            logging.info("Could not mmap without g_pathmgr. Trying without mmap")
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, **npy_options)
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data
|
372 |
+
|
373 |
+
|
374 |
+
def abspath(resource_path: str):
    """
    Return ``resource_path`` unchanged when it starts with a ``scheme://``
    prefix (e.g. "http://" or "manifold://"), otherwise its absolute
    filesystem path.
    """
    has_scheme = re.compile(r"^\w+://").match(resource_path) is not None
    if has_scheme:
        return resource_path
    return os.path.abspath(resource_path)
|
384 |
+
|
385 |
+
|
386 |
+
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Args:
        dir_path: directory to create; checked and created through
            ``g_pathmgr``.

    Returns:
        True when the directory already existed or was created, False when
        the check or the creation raised an error (the failure is logged).
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        # Deliberately best-effort, but only for ordinary errors:
        # KeyboardInterrupt / SystemExit must still propagate.
        logging.info(f"Error creating directory: {dir_path}")
    return is_success
|
398 |
+
|
399 |
+
|
400 |
+
def is_url(input_url):
    """
    Tell whether ``input_url`` looks like a web URL, i.e. starts with
    http:// or https:// (case-insensitive).
    """
    matched = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE)
    return matched is not None
|
406 |
+
|
407 |
+
|
408 |
+
def cleanup_dir(dir):
    """
    Remove a directory tree if it exists; do nothing otherwise.

    Handy for reclaiming storage used by training artifacts such as
    checkpoints or cached data.
    """
    if not os.path.exists(dir):
        return
    logging.info(f"Deleting directory: {dir}")
    shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")
|
417 |
+
|
418 |
+
|
419 |
+
def get_file_size(filename):
    """
    Return the size of ``filename`` in MB (bytes divided by 1024**2).
    """
    bytes_per_mb = float(1024**2)
    return os.path.getsize(filename) / bytes_per_mb
|
medomni/configs/datasets/medinterp/align.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets:
|
2 |
+
med:
|
3 |
+
data_type: images
|
4 |
+
build_info:
|
5 |
+
storage: json_files/medinterp
|
medomni/configs/default.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
env:
|
2 |
+
# For default users
|
3 |
+
# cache_root: "cache"
|
4 |
+
# For internal use with persistent storage
|
5 |
+
cache_root: "/export/home/.cache/medomni"
|
medomni/configs/models/medomni.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: medomni
|
3 |
+
|
4 |
+
# vision encoder
|
5 |
+
precision: "fp16"
|
6 |
+
freeze_vit: True
|
7 |
+
|
8 |
+
# Llama
|
9 |
+
llama_model: "meta-llama/Llama-2-7b-chat-hf"
|
10 |
+
|
11 |
+
# generation configs
|
12 |
+
prompt: ""
|
medomni/conversation/__init__.py
ADDED
File without changes
|
medomni/conversation/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (145 Bytes). View file
|
|
medomni/conversation/__pycache__/conversation.cpython-39.pyc
ADDED
Binary file (7.3 kB). View file
|
|
medomni/conversation/conversation.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
|
7 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
8 |
+
|
9 |
+
import dataclasses
|
10 |
+
from enum import auto, Enum
|
11 |
+
from typing import List, Tuple, Any
|
12 |
+
|
13 |
+
from medomni.common.registry import registry
|
14 |
+
import ipdb
|
15 |
+
|
16 |
+
|
17 |
+
class SeparatorStyle(Enum):
    """Available ways of separating turns when a prompt is rendered."""
    # Explicit values equal to what consecutive auto() calls yield.
    SINGLE = 1
    TWO = 2
|
21 |
+
|
22 |
+
|
23 |
+
@dataclasses.dataclass
class Conversation:
    """Container for the full history of one conversation."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the system text and every message into one prompt string.

        With SINGLE style each non-empty message is followed by ``sep``;
        with TWO style messages alternate between ``sep`` and ``sep2``.
        A message with empty/None text contributes only ``"<role>:"``.

        Raises:
            ValueError: if ``sep_style`` is neither SINGLE nor TWO.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            separators = [self.sep, self.sep]
        elif self.sep_style == SeparatorStyle.TWO:
            separators = [self.sep, self.sep2]
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        pieces = [self.system + separators[0]]
        for turn, (role, message) in enumerate(self.messages):
            if message:
                pieces.append(role + ": " + message + separators[turn % 2])
            else:
                pieces.append(role + ":")
        return "".join(pieces)

    def append_message(self, role, message):
        """Add a ``[role, message]`` entry at the end of the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair up messages after ``offset`` as ``[first, second]`` lists.

        Even-positioned messages open a new pair (second slot ``None``),
        odd-positioned ones fill the second slot of the latest pair.
        """
        pairs = []
        for position, (_, text) in enumerate(self.messages[self.offset:]):
            if position % 2:
                pairs[-1][-1] = text
            else:
                pairs.append([text, None])
        return pairs

    def copy(self):
        """Return a new Conversation whose message list is copied entry by entry."""
        duplicated = [[role, text] for role, text in self.messages]
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=duplicated,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Return the conversation's main fields as a plain dictionary."""
        fields = {}
        fields["system"] = self.system
        # "system_img": self.system_img,
        fields["roles"] = self.roles
        fields["messages"] = self.messages
        fields["offset"] = self.offset
        fields["sep"] = self.sep
        fields["sep2"] = self.sep2
        fields["conv_id"] = self.conv_id
        return fields
|
93 |
+
|
94 |
+
|
95 |
+
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the generated ids end with any given stop sequence."""

    def __init__(self, stops=None, encounters=1):
        """
        Args:
            stops: iterable of 1-D token-id tensors; generation stops when
                the first sequence in the batch ends with any of them.
            encounters: accepted for interface compatibility; not used.
        """
        super().__init__()
        # A fresh list per instance instead of one shared mutable default.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the first sequence of ``input_ids`` ends with a stop sequence.

        Extra keyword arguments are accepted and ignored so the criterion
        keeps working when the caller forwards additional arguments.
        """
        generated = input_ids[0]
        for stop in self.stops:
            stop_len = len(stop)
            # An empty stop sequence, or one longer than what has been
            # generated so far, cannot match; comparing anyway would rely on
            # broadcasting mismatched shapes.
            if stop_len == 0 or stop_len > generated.shape[0]:
                continue
            if torch.all(stop == generated[-stop_len:]).item():
                return True

        return False
|
107 |
+
|
108 |
+
|
109 |
+
# Default conversation template for image-grounded chat: a system prompt that
# announces the image, "Human"/"Assistant" roles, and "###" as the single
# separator between turns.
CONV_VISION = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Act as a clinician and answer my questions.",
    # Alternative system prompts kept for reference:
    # "You will be able to see the image once I provide it to you. Please answer my questions.",
    # system="",
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
|
120 |
+
|
121 |
+
class Chat:
    """Thin chat driver around a model exposing ``llama_model``,
    ``llama_tokenizer`` and ``encode_img``."""

    def __init__(self, model, vis_processor, device='cuda:0'):
        """
        Args:
            model: object providing ``llama_model``, ``llama_tokenizer`` and
                ``encode_img``.
            vis_processor: callable turning a PIL image into a tensor.
            device: device the inputs and stop-token tensors are moved to.
        """
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        # '###' can be encoded in two different ways.
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        """Add the user's ``text`` to ``conv``.

        If the last message is from the first role and its text ends with
        ``'</Img>'`` (i.e. it is an image placeholder), the text is appended
        to that message; otherwise a new message is added.
        """
        last = conv.messages[-1] if conv.messages else None
        if (
            last is not None
            and last[0] == conv.roles[0]
            # Guard against a non-string (e.g. None) text before inspecting it.
            and isinstance(last[1], str)
            and last[1].endswith('</Img>')  # last message is image.
        ):
            last[1] = ' '.join([last[1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        """Generate the reply for the current state of ``conv``.

        Side effect: after generation the history of ``conv`` is discarded
        and replaced by a single image-placeholder message, so every call
        effectively starts a fresh exchange about the same image.

        Returns:
            ``(output_text, output_token)`` - the decoded reply and the
            generated token ids as a NumPy array.
        """
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        # Drop the oldest context so that prompt + new tokens fits max_length.
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        with torch.autocast("cuda"):
            outputs = self.model.llama_model.generate(
                inputs_embeds=embs,
                max_new_tokens=max_new_tokens,
                stopping_criteria=self.stopping_criteria,
                num_beams=num_beams,
                do_sample=True,
                min_length=min_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                temperature=temperature,
            )
        output_token = outputs[0]
        # Each check first makes sure a token is left: stripping the only
        # token would otherwise make the next index lookup fail.
        if len(output_token) > 0 and output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if len(output_token) > 0 and output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        # Reset the history so the next question starts from the image only.
        conv.messages = []
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
        """Encode ``image``, append the embedding to ``img_list`` and add an
        image-placeholder message to ``conv``.

        ``image`` may be a file path, a PIL image, or a tensor (a 3-D tensor
        gets a leading batch dimension).

        Returns:
            The string ``"Received."``.
        """
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)
        # NOTE(review): any other type is handed to encode_img unchanged.

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        msg = "Received."
        return msg

    def get_context_emb(self, conv, img_list):
        """Build the mixed text/image embedding sequence for ``conv``.

        The prompt is split at every ``<ImageHere>`` placeholder, each text
        segment is tokenized and embedded, and the image embeddings from
        ``img_list`` are interleaved between the segments.
        """
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        # NOTE: with a LoRA-wrapped model the original authors used
        # self.model.llama_model.model.base_model.embed_tokens instead.
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
|
222 |
+
|
medomni/datasets/__init__.py
ADDED
File without changes
|
medomni/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (149 Bytes). View file
|
|
medomni/datasets/__pycache__/data_utils.cpython-39.pyc
ADDED
Binary file (5.95 kB). View file
|
|
medomni/datasets/builders/__init__.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
from medomni.datasets.builders.base_dataset_builder import load_dataset_config
|
9 |
+
from medomni.datasets.builders.image_text_pair_builder import (
|
10 |
+
CCSBUBuilder,
|
11 |
+
LaionBuilder,
|
12 |
+
CCSBUAlignBuilder
|
13 |
+
)
|
14 |
+
from medomni.common.registry import registry
|
15 |
+
|
16 |
+
__all__ = [
|
17 |
+
"CCSBUBuilder",
|
18 |
+
"LaionBuilder",
|
19 |
+
"CCSBUAlignBuilder"
|
20 |
+
]
|
21 |
+
|
22 |
+
def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """
    Build the datasets registered under ``name``.

    Args:
        name: builder name to look up in the registry.
        cfg_path: optional path to a dataset config; when None the builder
            receives ``None`` as its config.
        vis_path: optional override for the storage path of the visual data.
        data_type: entry of ``build_info`` whose storage is overridden; when
            None the builder config's ``data_type`` is used.

    Returns:
        Whatever ``builder.build_datasets()`` returns.

    Example

    >>> dataset = load_dataset("coco_caption", cfg_path=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])

    """
    cfg = None if cfg_path is None else load_dataset_config(cfg_path)

    # Test for an unknown name explicitly rather than catching TypeError around
    # the constructor call, which reported genuine constructor errors as
    # "not found". Presumably the registry returns None for unknown names
    # (the original relied on calling None raising TypeError) - TODO confirm.
    builder_cls = registry.get_builder_class(name)
    if builder_cls is None:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join([str(k) for k in dataset_zoo.get_names()])
        )
        # Same exception exit(1) raises, without needing the site-only builtin.
        raise SystemExit(1)

    builder = builder_cls(cfg)

    if vis_path is not None:
        if data_type is None:
            # use default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    dataset = builder.build_datasets()
    return dataset
|
58 |
+
|
59 |
+
|
60 |
+
class DatasetZoo:
    """Maps every registered builder name to the keys of its DATASET_CONFIG_DICT."""

    def __init__(self) -> None:
        registered = registry.mapping["builder_name_mapping"]
        zoo = {}
        for builder_name, builder_cls in sorted(registered.items()):
            zoo[builder_name] = list(builder_cls.DATASET_CONFIG_DICT.keys())
        self.dataset_zoo = zoo

    def get_names(self):
        """Return the builder names as a list."""
        return [*self.dataset_zoo]
|
69 |
+
|
70 |
+
|
71 |
+
# Module-level instance built at import time; load_dataset() lists its names
# when the requested dataset cannot be found.
dataset_zoo = DatasetZoo()
|
medomni/datasets/builders/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (2.35 kB). View file
|
|
medomni/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc
ADDED
Binary file (6.06 kB). View file
|
|
medomni/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc
ADDED
Binary file (3.82 kB). View file
|
|
medomni/datasets/builders/base_dataset_builder.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is from
|
3 |
+
Copyright (c) 2022, salesforce.com, inc.
|
4 |
+
All rights reserved.
|
5 |
+
SPDX-License-Identifier: BSD-3-Clause
|
6 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
7 |
+
"""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import shutil
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
import torch.distributed as dist
|
16 |
+
from torchvision.datasets.utils import download_url
|
17 |
+
|
18 |
+
import medomni.common.utils as utils
|
19 |
+
from medomni.common.dist_utils import is_dist_avail_and_initialized, is_main_process
|
20 |
+
from medomni.common.registry import registry
|
21 |
+
from medomni.processors.base_processor import BaseProcessor
|
22 |
+
|
23 |
+
class BaseDatasetBuilder:
    """Base class that fetches annotations and builds per-split datasets.

    Subclasses are expected to set ``train_dataset_cls`` / ``eval_dataset_cls``
    and provide ``DATASET_CONFIG_DICT`` (read by ``default_config_path``).
    """
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        """
        Args:
            cfg: None (load the class's default config), a path string
                (load that config), or an already-loaded config object.
        """
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

    def build_datasets(self):
        """Download data on the main process, synchronize, then build the datasets."""
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        """Replace the default processors with those described in the config.

        For each of ``vis_processor`` / ``text_processor`` present in the
        config, both the "train" and "eval" entries are rebuilt; a missing
        sub-entry results in ``None`` for that processor.
        """
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        """Instantiate the processor named by ``cfg.name``; return None when ``cfg`` is None."""
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of the config registered under ``type``."""
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        """Fetch annotations, then check the visual data location."""
        self._download_ann()
        self._download_vis()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the vision-language datasets should have annotations of unified format.

        storage_path can be relative or absolute; a relative path is prefixed
        with env.cache_root to make the full path. It must name a file: an
        existing directory is rejected with ValueError.

        A ``url`` entry that is an existing local file is copied to the
        storage path (unless the destination already exists); anything else
        is downloaded.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                # exist_ok avoids the check-then-create race of a separate
                # os.path.exists() test.
                os.makedirs(dirname, exist_ok=True)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info(f"Using existing file {dst}.")
                else:
                    if os.path.isdir(storage_path):
                        # a directory cannot be used as the download target.
                        raise ValueError(
                            f"Expecting storage_path to be a file path, got directory {storage_path}"
                        )
                    else:
                        filename = os.path.basename(storage_path)

                    download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):
        """Warn when the storage path of the visual inputs does not exist (nothing is downloaded)."""

        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create by split datasets inheriting torch.utils.data.Datasets.

        # build() can be dataset-specific. Overwrite to customize.

        Returns:
            dict mapping each of the "train"/"val"/"test" splits present in
            the annotation config to its dataset instance.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = {}
        for split in ann_info.keys():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"

            # processors: the "train" ones for the train split, "eval" otherwise
            phase = "train" if is_train else "eval"
            vis_processor = self.vis_processors[phase]
            text_processor = self.text_processors[phase]

            # annotation path
            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]

            abs_ann_paths = []
            for ann_path in ann_paths:
                if not os.path.isabs(ann_path):
                    ann_path = utils.get_cache_path(ann_path)
                abs_ann_paths.append(ann_path)
            ann_paths = abs_ann_paths

            # visual data storage path
            vis_path = os.path.join(vis_info.storage, split)

            if not os.path.isabs(vis_path):
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn(f"storage path {vis_path} does not exist.")

            # create datasets
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets
|
228 |
+
|
229 |
+
|
230 |
+
def load_dataset_config(cfg_path):
    """Load ``cfg_path`` and return the config of the first entry under its ``datasets`` key."""
    datasets_cfg = OmegaConf.load(cfg_path).datasets
    dataset_names = list(datasets_cfg.keys())
    first_name = dataset_names[0]
    return datasets_cfg[first_name]
|