import numpy.typing as npt
import torch
import torch.nn.functional as F

from app.configs import DEVICE
from app.mobile_sam import SamPredictor
from .model import point_selection, MaskWeights
from .loss import calculate_dice_loss, calculate_sigmoid_focal_loss


def train(
    predictor: SamPredictor,
    ref_images: list[npt.NDArray],
    ref_masks: list[npt.NDArray],
    lr: float = 1e-3,
    epochs: int = 200,
) -> tuple[torch.Tensor, torch.Tensor]:
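    """Fine-tune learnable mask-scale weights on reference image/mask pairs.

    Each reference image is encoded once; a target feature is fused from the
    mean and max of its foreground features, and a positive point prompt is
    selected from the cosine-similarity map. The decoder's three output mask
    scales are then blended with learnable weights trained against a
    dice + sigmoid-focal loss.

    Returns the fused mask-scale weights and the averaged target feature.
    """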
    gt_masks = []
    points = []
    target_feats = []
    embeddings = []
    for ref_image, ref_mask in zip(ref_images, ref_masks):
        # Binarize and flatten the ground-truth mask for the loss computation
        gt_mask = torch.from_numpy(ref_mask) > 0
        gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(DEVICE)
        gt_masks.append(gt_mask)

        # Encode the image and cache its embedding (and sizes) so the
        # training loop below can restore it without re-running the encoder
        predictor.set_image(ref_image)
        embeddings.append(
            (predictor.features, predictor.input_size, predictor.original_size)
        )
        ref_mask = predictor.get_mask(ref_mask[:, :, None])
        ref_feat = predictor.features.squeeze().permute(1, 2, 0)

        ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0:2], mode="bilinear")
        ref_mask = ref_mask.squeeze()

        # Target feature extraction
        target_feat = ref_feat[ref_mask > 0]
        target_feat_mean = target_feat.mean(0)
        target_feat_max = torch.max(target_feat, dim=0)[0]
        target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)

        # Cosine similarity
        h, w, C = ref_feat.shape
        target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
        target_feats.append(target_feat)
        ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
        ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
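        # (1, C) @ (C, h * w) -> per-location cosine similarity to the target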
        sim = target_feat @ ref_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = predictor.model.postprocess_masks(
            sim, input_size=predictor.input_size, original_size=predictor.original_size
        ).squeeze()

        # Positive location prior
        topk_xy, topk_label = point_selection(sim, topk=1)
        points.append((topk_xy, topk_label))

    # Average the per-image target features into a single embedding
    target_feat = torch.cat(target_feats, dim=0).mean(dim=0)

    # Learnable mask weights
    mask_weights = MaskWeights().to(DEVICE)
    mask_weights.train()

    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=lr, eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    for _ in range(epochs):
        for i in range(len(gt_masks)):
            gt_mask = gt_masks[i]
            topk_xy, topk_label = points[i]
            # Restore this reference image's cached embedding so the decoder
            # predicts against the correct features (the predictor otherwise
            # keeps only the last image passed to set_image)
            (
                predictor.features,
                predictor.input_size,
                predictor.original_size,
            ) = embeddings[i]
            # Run the decoder
            logits_high, _, _ = predictor.predict(
                point_coords=topk_xy,
                point_labels=topk_label,
                multimask_output=True,
                return_logits=True,
                return_numpy=False,
            )
            logits_high = logits_high.flatten(1)

            # Weighted sum of the three-scale masks; the first weight is
            # derived as 1 minus the learned weights so all three sum to 1
            weights = torch.cat(
                (1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights),
                dim=0,
            )
            logits_high = logits_high * weights
            logits_high = logits_high.sum(0).unsqueeze(0)

            dice_loss = calculate_dice_loss(logits_high, gt_mask)
            focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask, alpha=1.0)
            loss = dice_loss + focal_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Step the cosine schedule once per epoch so it anneals over exactly
        # `epochs` steps, matching the T_max given to CosineAnnealingLR
        scheduler.step()

    mask_weights.eval()
    # Fuse the final weights: the derived first weight plus the learned ones
    weights = torch.cat(
        (1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0
    )
    return weights, target_feat
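

# Minimal usage sketch (an assumption-laden illustration, not part of this
# module's API): it presumes app.mobile_sam re-exports the upstream MobileSAM
# builders via `sam_model_registry`, and that the checkpoint and image paths
# below exist.
#
#     import cv2
#     from app.configs import DEVICE
#     from app.mobile_sam import SamPredictor, sam_model_registry
#
#     sam = sam_model_registry["vit_t"](checkpoint="weights/mobile_sam.pt")
#     predictor = SamPredictor(sam.to(DEVICE).eval())
#
#     ref_image = cv2.cvtColor(cv2.imread("ref.jpg"), cv2.COLOR_BGR2RGB)
#     ref_mask = cv2.imread("ref_mask.png", cv2.IMREAD_GRAYSCALE)
#
#     weights, target_feat = train(predictor, [ref_image], [ref_mask])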