File size: 629 Bytes
32ca76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from models.invblock import INV_block


class Hinet(torch.nn.Module):
    """Sequential stack of ``INV_block`` modules run forward or in reverse.

    In the forward direction every block is applied in construction order.
    In the reverse direction the blocks are applied last-to-first and each
    one is called with ``rev=True``. Whether that actually inverts the
    forward pass depends on ``INV_block`` (defined elsewhere). NOTE(review):
    confirm each ``INV_block`` is exactly invertible under ``rev=True``.

    Args:
        in_channel: Forwarded unchanged to every ``INV_block`` constructor.
        num_layers: Number of ``INV_block`` instances to stack.
    """

    def __init__(self, in_channel=2, num_layers=16):
        super().__init__()
        # Keep the attribute name ``inv_blocks``: it determines the
        # state_dict keys (``inv_blocks.<i>.*``) used by saved checkpoints.
        stacked = [INV_block(in_channel) for _ in range(num_layers)]
        self.inv_blocks = torch.nn.ModuleList(stacked)

    def forward(self, x1, x2, rev=False):
        """Push the pair ``(x1, x2)`` through the block stack.

        Args:
            x1: Cover input, per the original author's comment. Shape and
                dtype are whatever ``INV_block`` expects (not visible here).
            x2: Secret input, per the original author's comment.
            rev: When truthy, run the stack last-to-first, calling each
                block with ``rev=True``.

        Returns:
            The ``(x1, x2)`` pair produced by the final block applied.
        """
        if rev:
            # Inverse pass: undo the blocks in the opposite order.
            for block in reversed(self.inv_blocks):
                x1, x2 = block(x1, x2, rev=True)
            return x1, x2

        # Forward pass: no ``rev`` kwarg, so each block uses its own default.
        for block in self.inv_blocks:
            x1, x2 = block(x1, x2)
        return x1, x2