File size: 626 Bytes
2f4febc
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail

# MOVE IT SOMERWHERE ELSE
def update_weights_ema(tgt_model, src_model, beta=0.999):
    for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
        self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
    for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
        self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)