ECG_Delineation / model.py
wogh2012's picture
refactor: remove aitiautils
b7914f1
raw
history blame contribute delete
736 Bytes
from res.impl.HRNetV2 import HRNetV2
import torch
class Config:
    """Bare attribute bag for HRNetV2 hyperparameters.

    It declares no fields of its own; callers attach settings directly to an
    instance (e.g. ``config.data_len = 5000``) before handing it to the model
    constructor.
    """
class HRNetV2Wrapper:
    """Build an HRNetV2 model with a fixed configuration and load pretrained weights.

    After construction the ready-to-use module is available as ``self.model``,
    placed on ``device`` and switched to evaluation mode.
    """

    def __init__(self, weights_path="./res/models/hrnetv2/weights.pth", device="cpu"):
        """Create the model and load its state dict.

        Args:
            weights_path: Path to a state-dict file readable by ``torch.load``.
                The default is relative to the *current working directory*
                (presumably the project root, where the ``res`` package lives),
                so pass an explicit path when running from elsewhere.
            device: Device the checkpoint tensors are mapped onto and the model
                is moved to. Defaults to ``"cpu"``, as before.

        Raises:
            FileNotFoundError: If ``weights_path`` does not exist.
            RuntimeError: If the checkpoint keys/shapes do not match the model
                built from the configuration below.
        """
        # These values define the architecture; they must match the settings
        # the checkpoint was trained with, or load_state_dict will likely fail.
        config = Config()
        config.data_len = 5000  # presumably input length in samples — TODO confirm in HRNetV2
        config.kernel_size = 5
        config.dilation = 1
        config.num_stages = 3
        config.num_blocks = 6
        config.num_modules = [1, 1, 1, 4, 3]
        config.use_bottleneck = [1, 0, 0, 0, 0]
        config.stage1_channels = 128
        config.num_channels_init = 48
        config.interpolate_mode = "linear"
        config.output_size = 3  # presumably number of output channels/classes — TODO confirm
        self.model = HRNetV2(config)

        # map_location remaps tensors saved on another device (e.g. CUDA) onto
        # the target device; without it, loading a GPU-saved checkpoint fails
        # on a machine without CUDA.
        weights = torch.load(weights_path, map_location=device)
        self.model.load_state_dict(weights)
        self.model = self.model.to(device).eval()