import torch

from res.impl.HRNetV2 import HRNetV2


class Config:
    """Bare attribute container; settings are attached to instances at runtime."""


class HRNetV2Wrapper:
    """Build an HRNetV2 with a fixed configuration and load its saved weights.

    Attributes:
        model: The HRNetV2 instance, moved to the requested device and
            switched to evaluation mode.
    """

    def __init__(
        self,
        weights_path: str = "./res/models/hrnetv2/weights.pth",
        device: str = "cpu",
    ) -> None:
        """Construct the network and restore its parameters.

        Args:
            weights_path: Location of the state-dict checkpoint. The default
                is relative to the current working directory.
            device: Device the checkpoint tensors are mapped to and the model
                is moved to. Defaults to ``"cpu"``.
        """
        # These settings are handed to HRNetV2 unchanged.
        # NOTE(review): they presumably have to match the architecture the
        # checkpoint was saved from, otherwise load_state_dict will reject
        # it -- confirm before editing any value.
        config = Config()
        config.data_len = 5000
        config.kernel_size = 5
        config.dilation = 1
        config.num_stages = 3
        config.num_blocks = 6
        config.num_modules = [1, 1, 1, 4, 3]
        config.use_bottleneck = [1, 0, 0, 0, 0]
        config.stage1_channels = 128
        config.num_channels_init = 48
        config.interpolate_mode = "linear"
        config.output_size = 3

        self.model = HRNetV2(config)

        # map_location remaps stored tensors onto the target device while
        # loading; without it a checkpoint saved from CUDA tensors cannot be
        # deserialized on a machine that has no GPU.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        weights = torch.load(weights_path, map_location=device)
        self.model.load_state_dict(weights)
        self.model = self.model.to(device).eval()