Spaces:
Build error
Build error
'''
Utilities for instrumenting a torch model.
InstrumentedModel will wrap a pytorch model and allow hooking
arbitrary layers to monitor or modify their output directly.

Modified by Erik Härkönen:
- 29.11.2019: Unhooking bugfix
- 25.01.2020: Offset edits, removed old API
'''
import torch, numpy, types | |
from collections import OrderedDict | |
class InstrumentedModel(torch.nn.Module):
    '''
    A wrapper for hooking, probing and intervening in pytorch Modules.

    Layers are addressed by their fully-qualified name as reported by
    `model.named_modules()`.  Wherever a layer name is accepted, a pair
    `(layername, aka)` may be given instead; the `aka` is then used as
    the key under which retained values and edits are stored.

    Example usage:

    ```
    model = load_my_model()
    with InstrumentedModel(model) as inst:
        inst.retain_layer(layername)
        inst.edit_layer(layername, 0.5, target_features)
        inst.edit_layer(layername, offset=offset_tensor)
        inst(inputs)
        original_features = inst.retained_layer(layername)
    ```

    Leaving the `with` block calls `close()`, which unhooks every layer
    and discards retained values, so read them inside the block.
    '''

    def __init__(self, model):
        super(InstrumentedModel, self).__init__()
        self.model = model
        # aka -> last retained (detached) output, or None if not yet run.
        self._retained = OrderedDict()
        # aka -> ablation weight / replacement value / additive offset.
        self._ablation = {}
        self._replacement = {}
        self._offset = {}
        # aka -> layername of the hooked layer.
        self._hooked_layer = {}
        # layername -> (layer, aka, instance-level forward or None).
        self._old_forward = {}

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always restore the wrapped model; exceptions are not suppressed.
        self.close()

    def forward(self, *inputs, **kwargs):
        '''
        Runs the wrapped model; hooked layers intercept their own outputs.
        '''
        return self.model(*inputs, **kwargs)

    def retain_layer(self, layername):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and retain its output each time the model is run.
        A pair (layername, aka) can be provided, and the aka will be used
        as the key for the retained value instead of the layername.
        '''
        self.retain_layers([layername])

    def retain_layers(self, layernames):
        '''
        Retains a list of layers at once.  Each entry is either a layer
        name or a (layername, aka) pair.

        Raises ValueError if a layer is not found in the model.
        '''
        # Materialize first: the names are walked twice, which would
        # silently skip the second pass for a one-shot iterable.
        layernames = list(layernames)
        self.add_hooks(layernames)
        for layername in layernames:
            aka = layername
            if not isinstance(aka, str):
                layername, aka = layername
            if aka not in self._retained:
                self._retained[aka] = None

    def retained_features(self):
        '''
        Returns a dict of all currently retained features.
        The returned OrderedDict is a shallow copy keyed by aka.
        '''
        return OrderedDict(self._retained)

    def retained_layer(self, aka=None, clear=False):
        '''
        Retrieve retained data that was previously hooked by retain_layer.
        Call this after the model is run.  If aka is None, the first
        retained layer is used.  If clear is set, then the retained value
        will be returned and also cleared.

        Raises ValueError if aka is None and no layer is retained, and
        KeyError if the given aka is not retained.
        '''
        if aka is None:
            if not self._retained:
                raise ValueError('No layers are retained')
            # Default to the first retained layer.
            aka = next(iter(self._retained))
        result = self._retained[aka]
        if clear:
            self._retained[aka] = None
        return result

    def edit_layer(self, layername, ablation=None, replacement=None,
                   offset=None):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and modify its output each time the model is run.
        The output of the layer will be modified to be a convex combination
        of the replacement and x interpolated according to the ablation, i.e.:
        `output = x * (1 - a) + (r * a)`.
        Additionally or independently, an offset can be added to the output.

        Arguments left as None do not overwrite a previously set edit.
        Raises ValueError if the layer is not found in the model.
        '''
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        # The default ablation if a replacement is specified is 1.0.
        if ablation is None and replacement is not None:
            ablation = 1.0
        self.add_hooks([(layername, aka)])
        if ablation is not None:
            self._ablation[aka] = ablation
        if replacement is not None:
            self._replacement[aka] = replacement
        if offset is not None:
            self._offset[aka] = offset
        # If needed, could add an arbitrary postprocessing lambda here.

    def remove_edits(self, layername=None, remove_offset=True,
                     remove_replacement=True):
        '''
        Removes edits at the specified layer, or removes edits at all layers
        if no layer name is specified.  `remove_replacement` covers both the
        ablation and the replacement; `remove_offset` covers the offset.
        The layer itself stays hooked.
        '''
        if layername is None:
            if remove_replacement:
                self._ablation.clear()
                self._replacement.clear()
            if remove_offset:
                self._offset.clear()
            return
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        if remove_replacement:
            self._ablation.pop(aka, None)
            self._replacement.pop(aka, None)
        if remove_offset:
            self._offset.pop(aka, None)

    def add_hooks(self, layernames):
        '''
        Sets up a set of layers to be hooked.
        Usually not called directly: use edit_layer or retain_layer instead.

        Layers already hooked under the same aka are skipped.  Raises
        ValueError if any requested layer is not found in the model; in
        that case no new layer is hooked.
        '''
        aka_map = {}
        for name in layernames:
            aka = name
            if not isinstance(aka, str):
                name, aka = name
            if self._hooked_layer.get(aka, None) != name:
                aka_map[name] = aka
        if not aka_map:
            return
        found = {}
        for name, layer in self.model.named_modules():
            if name in aka_map:
                found[name] = layer
        # Validate every name before touching the model so that a bad
        # name does not leave some of the requested layers hooked.
        for name in aka_map:
            if name not in found:
                raise ValueError('Layer %s not found in model' % name)
        for name, layer in found.items():
            self._hook_layer(layer, name, aka_map[name])

    def _hook_layer(self, layer, layername, aka):
        '''
        Internal method to replace a forward method with a closure that
        intercepts the call, and tracks the hook so that it can be reverted.

        Raises ValueError if the aka or the layer is already hooked.
        '''
        if aka in self._hooked_layer:
            raise ValueError('Layer %s already hooked' % aka)
        if layername in self._old_forward:
            raise ValueError('Layer %s already hooked' % layername)
        self._hooked_layer[aka] = layername
        # Remember only an instance-level forward (None when forward comes
        # from the class) so that unhooking can simply drop the override.
        self._old_forward[layername] = (layer, aka,
                                        layer.__dict__.get('forward', None))
        editor = self
        original_forward = layer.forward

        def new_forward(self, *inputs, **kwargs):
            # `self` here is the hooked layer; `editor` is this wrapper.
            original_x = original_forward(*inputs, **kwargs)
            x = editor._postprocess_forward(original_x, aka)
            return x

        layer.forward = types.MethodType(new_forward, layer)

    def _unhook_layer(self, aka):
        '''
        Internal method to remove a hook, restoring the original forward
        method.  Also drops any edits and retained value stored under aka.
        Does nothing if aka is not hooked.
        '''
        if aka not in self._hooked_layer:
            return
        layername = self._hooked_layer[aka]
        layer, check, old_forward = self._old_forward[layername]
        assert check == aka
        if old_forward is None:
            # forward originally came from the class: remove our override.
            if 'forward' in layer.__dict__:
                del layer.__dict__['forward']
        else:
            layer.forward = old_forward
        del self._old_forward[layername]
        del self._hooked_layer[aka]
        self._ablation.pop(aka, None)
        self._replacement.pop(aka, None)
        self._offset.pop(aka, None)
        self._retained.pop(aka, None)

    def _postprocess_forward(self, x, aka):
        '''
        The internal method called by the hooked layers after they are run.
        Retains the unedited output if requested, then applies the
        ablation/replacement edit followed by the offset edit.
        '''
        # Retain output before edits, if desired.
        if aka in self._retained:
            self._retained[aka] = x.detach()
        # Apply replacement edit
        a = make_matching_tensor(self._ablation, aka, x)
        if a is not None:
            x = x * (1 - a)
            v = make_matching_tensor(self._replacement, aka, x)
            if v is not None:
                # Out-of-place add, matching the offset edit below.
                x = x + (v * a)
        # Apply offset edit
        b = make_matching_tensor(self._offset, aka, x)
        if b is not None:
            x = x + b
        return x

    def close(self):
        '''
        Unhooks all hooked layers in the model.
        '''
        # _unhook_layer is keyed by aka, so walk the aka table rather than
        # the layername-keyed _old_forward; the two differ for aliases.
        for aka in list(self._hooked_layer.keys()):
            self._unhook_layer(aka)
        assert len(self._old_forward) == 0
def make_matching_tensor(valuedict, name, data):
    '''
    Looks up `valuedict[name]` and returns it as a tensor compatible with
    `data`: same dtype, same device, and at least as many dimensions.
    Each conversion step writes its result back into `valuedict`, so the
    work is done once per entry.  Returns None if the entry is missing
    or None.

    A value with fewer dimensions than `data` is reshaped to
    (1, *value.shape, 1, ..., 1).  A value that requires grad must
    already match, since converting it would replace the stored tensor;
    otherwise an AssertionError is raised.
    '''
    value = valuedict.get(name, None)
    if value is None:
        return None

    # Accept non-torch data by routing it through numpy.
    if not isinstance(value, torch.Tensor):
        value = torch.from_numpy(numpy.array(value))
        valuedict[name] = value

    # Ensure device and type matches.
    already_matches = (value.device == data.device
                       and value.dtype == data.dtype)
    if not already_matches:
        assert not value.requires_grad, '%s wrong device or type' % (name)
        value = value.to(device=data.device, dtype=data.dtype)
        valuedict[name] = value

    # Ensure dimensions are unsqueezed as needed: one leading axis, then
    # trailing axes to make up the remaining difference in rank.
    rank_gap = len(data.shape) - len(value.shape)
    if rank_gap > 0:
        assert not value.requires_grad, '%s wrong dimensions' % (name)
        padded_shape = (1,) + tuple(value.shape) + (1,) * (rank_gap - 1)
        value = value.view(padded_shape)
        valuedict[name] = value

    return value