from pathlib import Path
from typing import Union, Optional
import numpy as np
import torch
import tops
import torchvision.transforms.functional as F
from motpy import Detection, MultiObjectTracker
from dp2.utils import load_config
from dp2.infer import build_trained_generator
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection


def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
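    """Load the experiment config at `cfg_path` and build its trained generator."""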
    cfg = load_config(cfg_path)
    G = build_trained_generator(cfg)
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return G


class Anonymizer:
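    """
    Applies a trained generator to every supported detection type in an image.

    One generator can be configured per detection type (CSE-based person,
    segmentation-based person, face, vehicle); detection types without a
    generator are left unmodified.
    """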

    def __init__(
            self,
            detector,
            load_cache: bool = False,
            person_G_cfg: Optional[Union[str, Path]] = None,
            cse_person_G_cfg: Optional[Union[str, Path]] = None,
            face_G_cfg: Optional[Union[str, Path]] = None,
            car_G_cfg: Optional[Union[str, Path]] = None,
    ) -> None:
        self.detector = detector
        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)

    def initialize_tracker(self, fps: float):
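        """Set up a motpy tracker for video input; `dt` is the inter-frame interval."""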
        self.tracker = MultiObjectTracker(dt=1/fps)
        self.track_to_z_idx = dict()

    def reset_tracker(self):
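        """Drop all track-to-seed assignments so subsequent tracks sample fresh latents."""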
        self.track_to_z_idx = dict()

    def forward_G(self,
                  G,
                  batch,
                  multi_modal_truncation: bool,
                  amp: bool,
                  z_idx: Optional[np.ndarray],
                  truncation_value: float,
                  idx: int,
                  all_styles=None):
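        """
        Synthesize a replacement for a single cropped detection.

        If `z_idx` is given, it holds one integer seed per detection; the seed
        at `idx` deterministically generates the latent z, which keeps a
        tracked identity visually consistent across frames.
        """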
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["img"] = batch["img"].float()
        batch["condition"] = batch["mask"].float() * batch["img"]

        with torch.cuda.amp.autocast(amp):
            z = None
            if z_idx is not None:
                state = np.random.RandomState(seed=z_idx[idx])
                z = state.normal(size=(1, G.z_channels)).astype(np.float32)
                z = tops.to_cuda(torch.from_numpy(z))

            if all_styles is not None:
                anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
            elif multi_modal_truncation:
                w_indices = None
                if z_idx is not None:
                    w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                anonymized_im = G.multi_modal_truncate(
                    **batch, truncation_value=truncation_value,
                    w_indices=w_indices,
                    z=z
                )["img"]
            else:
                anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
        # Map generator output from [-1, 1] back to [0, 255].
        anonymized_im = (anonymized_im + 1).div(2).clamp(0, 1).mul(255)
        return anonymized_im

    @torch.no_grad()
    def anonymize_detections(self,
                             im, detection,
                             update_identity=None,
                             **synthesis_kwargs
                             ):
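        """
        Replace every detection of one type in `im` (CHW uint8) in place.

        Each crop is synthesized by the matching generator, resized back to its
        box, stripped of any border padding, and pasted in wherever the mask is
        zero. Detections with `update_identity[idx]` set to False are skipped.
        """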
        G = self.generators[type(detection)]
        if G is None:
            return im
        C, H, W = im.shape
        if update_identity is None:
            update_identity = [True] * len(detection)
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)

            gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
            mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # Crop away padding added where the box extends past the image border.
            pad = [max(-x0, 0), max(-y0, 0), max(x1 - W, 0), max(y1 - H, 0)]

            def remove_pad(x):
                return x[..., pad[1]:x.shape[-2] - pad[3], pad[0]:x.shape[-1] - pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask) > 0.5
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            # The mask marks pixels kept from the input; its complement selects
            # the region to overwrite with synthesized content.
            mask = mask.logical_not()[None].repeat(3, 1, 1)

            im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: Optional[str] = None) -> torch.Tensor:
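        """Run detection (optionally cached) and draw the results on `im`."""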
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        im = im.cpu()
        for det in all_detections:
            im = det.visualize(im)
        return im

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cache_id: Optional[str] = None, track: bool = True, detections=None, **synthesis_kwargs) -> torch.Tensor:
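        """
        Detect and anonymize all supported objects in a CHW uint8 image.

        If `initialize_tracker` has been called and `track` is True, each track
        id is bound to a fixed latent seed so identities stay stable across
        video frames.
        """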
        assert im.dtype == torch.uint8, f"Expected a uint8 image, got {im.dtype}"
        im = tops.to_cuda(im)
        all_detections = detections
        if detections is None:
            if self.load_cache:
                all_detections = self.detector.forward_and_cache(im, cache_id)
            else:
                all_detections = self.detector(im)
        if hasattr(self, "tracker") and track:
            for det in all_detections:
                det.pre_process()
            boxes = np.concatenate([_.boxes for _ in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            # Bind each track id to a persistent random seed so the same identity
            # receives the same latent across frames.
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = np.random.randint(0, 2**32 - 1)
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
            idx_offset = 0

        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                zs = z_idx[idx_offset:idx_offset+len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)

        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
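

# Example usage (a minimal sketch; the detector construction, the config path,
# and the synthesis keyword values below are illustrative assumptions, not
# part of this module):
#
#   from PIL import Image
#   detector = ...  # any detector exposing __call__ / forward_and_cache
#   anonymizer = Anonymizer(detector, face_G_cfg="path/to/face_config.py")  # hypothetical path
#   im = torch.from_numpy(np.array(Image.open("frame.png"))).permute(2, 0, 1)  # CHW uint8
#   anonymized = anonymizer(im, multi_modal_truncation=False, amp=True, truncation_value=0.5)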