from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
import tops
import torchvision.transforms.functional as F
from motpy import Detection, MultiObjectTracker

from dp2.utils import load_config
from dp2.infer import build_trained_generator
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection


def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
    cfg = load_config(cfg_path)
    G = build_trained_generator(cfg)
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return G


class Anonymizer:
    """Anonymizes detected faces, persons, and vehicles in an image by
    replacing each detection with output from a matching generator."""

    def __init__(
        self,
        detector,
        load_cache: bool = False,
        person_G_cfg: Optional[Union[str, Path]] = None,
        cse_person_G_cfg: Optional[Union[str, Path]] = None,
        face_G_cfg: Optional[Union[str, Path]] = None,
        car_G_cfg: Optional[Union[str, Path]] = None,
    ) -> None:
        self.detector = detector
        # One generator per detection type. A detection type with no generator
        # configured is left untouched by anonymize_detections.
        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)

    def initialize_tracker(self, fps: float):
        self.tracker = MultiObjectTracker(dt=1 / fps)
        self.track_to_z_idx = dict()

    def reset_tracker(self):
        self.track_to_z_idx = dict()

    def forward_G(
        self,
        G,
        batch,
        multi_modal_truncation: bool,
        amp: bool,
        z_idx: Optional[np.ndarray],
        truncation_value: float,
        idx: int,
        all_styles=None,
    ):
        # Map the uint8 image from [0, 255] to [-1, 1], the input range of the generator.
        batch["img"] = F.normalize(batch["img"].float(), [0.5 * 255] * 3, [0.5 * 255] * 3)
        batch["condition"] = batch["mask"].float() * batch["img"]
        with torch.cuda.amp.autocast(amp):
            z = None
            if z_idx is not None:
                # Derive the latent deterministically from the per-track index so a
                # tracked identity keeps the same synthesized appearance across frames.
                state = np.random.RandomState(seed=z_idx[idx])
                z = state.normal(size=(1, G.z_channels)).astype(np.float32)
                z = tops.to_cuda(torch.from_numpy(z))
            if all_styles is not None:
                anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
            elif multi_modal_truncation:
                w_indices = None
                if z_idx is not None:
                    w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                anonymized_im = G.multi_modal_truncate(
                    **batch, truncation_value=truncation_value,
                    w_indices=w_indices,
                    z=z,
                )["img"]
            else:
                anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
        # Map the output from [-1, 1] back to [0, 255].
        anonymized_im = (anonymized_im + 1).div(2).clamp(0, 1).mul(255)
        return anonymized_im

    @torch.no_grad()
    def anonymize_detections(
        self,
        im,
        detection,
        update_identity=None,
        **synthesis_kwargs,
    ):
        G = self.generators[type(detection)]
        if G is None:
            return im
        C, H, W = im.shape
        if update_identity is None:
            update_identity = [True for _ in range(len(detection))]
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)
            # Resize the synthesized crop and its mask back to the detection box size.
            gim = F.resize(anonymized_im[0], (y1 - y0, x1 - x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
            mask = F.resize(batch["mask"][0], (y1 - y0, x1 - x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # Remove the padding added where the box extended past the image border.
            pad = [max(-x0, 0), max(-y0, 0)]
            pad = [*pad, max(x1 - W, 0), max(y1 - H, 0)]

            def remove_pad(x):
                return x[..., pad[1]:x.shape[-2] - pad[3], pad[0]:x.shape[-1] - pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask) > 0.5
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            # The mask marks pixels to keep from the original; invert it to select
            # the pixels to overwrite with synthesized content.
            mask = mask.logical_not()[None].repeat(3, 1, 1)
            im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: Optional[str] = None) -> torch.Tensor:
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        im = im.cpu()
        for det in all_detections:
            im = det.visualize(im)
        return im

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cache_id: Optional[str] = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor:
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = detections
        if detections is None:
            if self.load_cache:
                all_detections = self.detector.forward_and_cache(im, cache_id)
            else:
                all_detections = self.detector(im)
        if hasattr(self, "tracker") and track:
            for detection in all_detections:
                detection.pre_process()
            boxes = np.concatenate([detection.boxes for detection in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            # Assign each track a fixed latent seed so the same identity gets the
            # same anonymized appearance in every frame.
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = np.random.randint(0, 2**32 - 1)
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
            idx_offset = 0
        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                zs = z_idx[idx_offset:idx_offset + len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
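

# Usage sketch (illustrative, not part of the original module): a smoke test
# with a stub detector that reports no detections, so no generator checkpoints
# are required and `forward` returns the frame unchanged. The stub class and
# dummy frame below are assumptions for the example; a real run would pass a
# dp2 detector and generator config paths to `Anonymizer`.
if __name__ == "__main__":
    class _StubDetector:
        def __call__(self, im):
            return []  # pretend nothing was detected

    anonymizer = Anonymizer(detector=_StubDetector())
    frame = torch.zeros((3, 256, 256), dtype=torch.uint8)  # dummy CHW uint8 frame
    out = anonymizer(frame)
    print(out.shape, out.dtype)  # torch.Size([3, 256, 256]) torch.uint8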