File size: 2,080 Bytes
d2ff88f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ---------------------------------------------------------------------------------------------------
# modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ---------------------------------------------------------------------------------------------------

import mmcv
from mmseg.datasets import build_dataloader, build_dataset
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
# Project-level model registry, created as a child of mmcv's MODELS registry
# (``parent=MMCV_MODELS``).
MODELS = Registry('models', parent=MMCV_MODELS)
# Segmentors use the very same registry object as models (alias, not a copy).
SEGMENTORS = MODELS
# NOTE(review): this import sits below the registry definitions rather than
# with the other imports, presumably because clip_dinoiser_eval (or a module
# it imports) needs MODELS/SEGMENTORS from this package at import time.
# Keep the ordering unless that is confirmed not to be the case.
from .clip_dinoiser_eval import DinoCLIP_Infrencer


def build_seg_dataset(config):
    """Build the test split of a segmentation dataset from a config file.

    Args:
        config: Filename of an mmcv config; its ``data.test`` entry describes
            the dataset to construct.

    Returns:
        The dataset object returned by ``mmseg.datasets.build_dataset``.
    """
    test_split_cfg = mmcv.Config.fromfile(config).data.test
    return build_dataset(test_split_cfg)


def build_seg_dataloader(dataset, dist=True):
    """Build a non-shuffling, one-sample-per-GPU dataloader for evaluation.

    Args:
        dataset: Dataset to iterate over, e.g. the object returned by
            ``build_seg_dataset``.
        dist: Whether to build a distributed loader; passed straight through
            to ``build_dataloader`` as ``dist``.

    Returns:
        The dataloader created by ``mmseg.datasets.build_dataloader``.
    """
    # Batch size is set to 1 to handle varying image sizes (images have
    # different aspect ratios, so they cannot be stacked into one batch).
    # The distributed and non-distributed cases use identical arguments apart
    # from ``dist`` itself, so a single call serves both.
    return build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=2,
        dist=dist,
        shuffle=False,
        persistent_workers=True,
        pin_memory=False,
    )


def build_seg_inference(
    model,
    dataset,
    config,
    seg_config,
):
    """Wrap ``model`` in a ``DinoCLIP_Infrencer`` configured for ``dataset``.

    Args:
        model: The model to wrap; passed as the first positional argument to
            ``DinoCLIP_Infrencer``.
        dataset: Dataset providing ``CLASSES`` (whose length sets
            ``num_classes``) and ``PALETTE``.
        config: Config object whose ``evaluate`` mapping is unpacked as extra
            keyword arguments to ``DinoCLIP_Infrencer``.
        seg_config: Filename of the dataset config; if it defines
            ``test_cfg``, that value is forwarded as the ``test_cfg`` kwarg.

    Returns:
        The ``DinoCLIP_Infrencer`` instance, with ``CLASSES`` and ``PALETTE``
        copied from ``dataset``.
    """
    dataset_cfg = mmcv.Config.fromfile(seg_config)
    class_names = dataset.CLASSES

    # Only forward test_cfg when the dataset config actually provides one.
    optional_kwargs = (
        {"test_cfg": dataset_cfg.test_cfg} if hasattr(dataset_cfg, "test_cfg") else {}
    )

    inferencer = DinoCLIP_Infrencer(
        model,
        num_classes=len(class_names),
        **optional_kwargs,
        **config.evaluate,
    )
    # Expose dataset metadata on the wrapper.
    inferencer.CLASSES = dataset.CLASSES
    inferencer.PALETTE = dataset.PALETTE
    return inferencer