|
def replace_layer_recursive(model, old_layer, new_layer):
    """Replace the first occurrence of ``old_layer`` in ``model``'s submodule tree.

    The tree is walked depth-first through each object's ``_modules`` mapping
    (the child registry of ``torch.nn.Module`` -- presumably what callers
    pass; TODO confirm). Matching is by identity, so only that exact object
    is swapped out, and only its first occurrence.

    Args:
        model: Root object whose descendants are searched. ``model`` itself
            is never replaced.
        old_layer: The submodule instance to find.
        new_layer: The object to install in its place.

    Returns:
        True if a replacement was made, False otherwise.
    """
    for name, layer in model._modules.items():
        if layer is old_layer:
            # Rebinding an existing key does not resize the dict, so this is
            # safe during iteration; we return immediately anyway.
            model._modules[name] = new_layer
            return True
        # A child slot may hold None (an unset submodule); it has no
        # ``_modules`` to descend into.
        if layer is not None and replace_layer_recursive(layer, old_layer, new_layer):
            return True
    return False
|
|
|
|
|
def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
    """Replace every descendant of ``model`` that is an ``old_layer_type`` instance.

    Every matching slot receives the *same* ``new_layer`` object (no copies
    are made), so all replaced positions share it.

    Args:
        model: Root object exposing ``_modules``; its own type is not
            checked, only its descendants.
        old_layer_type: Class (or tuple of classes) passed to ``isinstance``.
        new_layer: Object installed in each matching slot.
    """
    for name, layer in model._modules.items():
        if isinstance(layer, old_layer_type):
            # Rebinding an existing key is safe while iterating the dict.
            model._modules[name] = new_layer
            # Do not descend into the removed layer. It is no longer part of
            # ``model``, and editing it would mutate an object the caller may
            # still hold. Do not descend into ``new_layer`` either.
            continue
        # None slots (unset submodules) have no children to visit.
        if layer is not None:
            replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
|
|
|
|
|
def find_layer_types_recursive(model, layer_types):
    """Return all descendants of ``model`` whose exact type is in ``layer_types``.

    Matching uses exact type equality (``type(layer) in ...``), so subclasses
    of a listed type are NOT matched -- unlike ``isinstance``.

    Args:
        model: Root object; only its descendants are examined.
        layer_types: A single class, or an iterable of classes. It is
            materialized once up front so a one-shot iterator (e.g. a
            generator) is honored for every layer, not just the first one
            tested.

    Returns:
        List of matching layers, in the order produced by
        ``find_layer_predicate_recursive``.
    """
    if isinstance(layer_types, type):
        wanted = (layer_types,)
    else:
        wanted = tuple(layer_types)

    def predicate(layer):
        return type(layer) in wanted

    return find_layer_predicate_recursive(model, predicate)
|
|
|
|
|
def find_layer_predicate_recursive(model, predicate):
    """Collect every descendant of ``model`` for which ``predicate`` is truthy.

    Traversal is depth-first pre-order over ``_modules``. A matching layer is
    listed before any matches nested inside it. The search always continues
    into a layer's children, whether or not that layer matched. ``model``
    itself is not tested, and ``None`` child slots are skipped entirely.

    Args:
        model: Root object exposing a ``_modules`` mapping.
        predicate: Callable taking a layer; layers for which it returns a
            truthy value are included.

    Returns:
        A new list of matching layers. A layer reachable through several
        paths appears once per path.
    """
    result = []
    for layer in model._modules.values():
        # None marks an unset submodule slot: nothing to test or descend into.
        if layer is None:
            continue
        if predicate(layer):
            result.append(layer)
        result.extend(find_layer_predicate_recursive(layer, predicate))
    return result
|
|