from dataclasses import dataclass, field
from typing import Any, List, Optional

import torch.nn as nn
from jaxtyping import Float
from torch import Tensor

from sf3d.models.network import get_activation
from sf3d.models.utils import BaseModule


@dataclass
class HeadSpec:
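    """Specification of a single prediction head: its name, output size, MLP depth, and output post-processing."""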
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    output_bias: float = 0.0
    add_to_decoder_features: bool = False
    # Optional target shape for the head output (applied via reshape in forward).
    shape: Optional[List[int]] = None


class MultiHeadEstimator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
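        """Hyperparameters for the shared conv trunk and the per-head MLPs."""
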
        triplane_features: int = 1024

        n_layers: int = 2
        hidden_features: int = 512
        activation: str = "relu"

        # Spatial pooling mode; one of "max" or "mean".
        pool: str = "max"

        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
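        # Shared trunk: the three triplane views are folded into the channel
        # axis (3 * triplane_features) and downsampled by strided 3x3 convs.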
        layers = []
        cur_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            layers.append(
                nn.Conv2d(
                    cur_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            layers.append(self.make_activation(self.cfg.activation))

            cur_features = self.cfg.hidden_features

        self.layers = nn.Sequential(*layers)

        assert len(self.cfg.heads) > 0, "MultiHeadEstimator requires at least one head"
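        # Each head is a small MLP applied to the pooled trunk features.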
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            head_layers += [
                nn.Linear(
                    self.cfg.hidden_features,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(f"Unsupported activation: {activation}")

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> dict[str, Any]:
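        # Fold the three triplane views into the channel dim: (B, 3 * F, Ht, Wt).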
        x = self.layers(
            triplane.reshape(
                triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
            )
        )

        # Pool over the remaining spatial dims to a global (B, hidden) feature.
        if self.cfg.pool == "max":
            x = x.amax(dim=[-2, -1])
        elif self.cfg.pool == "mean":
            x = x.mean(dim=[-2, -1])
        else:
            raise NotImplementedError(f"Unsupported pool mode: {self.cfg.pool}")

        # Run every head, apply its bias and output activation; heads flagged
        # with add_to_decoder_features are prefixed with "decoder_" so callers
        # can route them into the decoder input.
        out = {
            ("decoder_" if head.add_to_decoder_features else "")
            + head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.output_bias
            )
            for head in self.cfg.heads
        }
        # Reshape heads that declare a target shape; the shape replaces the
        # full (B, out_channels) shape, so it must account for the batch dim
        # (e.g. by starting with -1).
        for head in self.cfg.heads:
            if head.shape:
                head_name = (
                    "decoder_" if head.add_to_decoder_features else ""
                ) + head.name
                out[head_name] = out[head_name].reshape(*head.shape)

        return out
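

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module. Assumptions: BaseModule
    # accepts a plain dict config (parsed into Config before configure() runs),
    # and get_activation understands the "sigmoid" name used below; adjust both
    # to the actual sf3d APIs if they differ. The head name is illustrative.
    import torch

    estimator = MultiHeadEstimator(
        {
            "triplane_features": 64,
            "n_layers": 2,
            "hidden_features": 128,
            "heads": [
                {
                    "name": "roughness",
                    "out_channels": 1,
                    "n_hidden_layers": 1,
                    "output_activation": "sigmoid",
                },
            ],
        }
    )
    # Triplane input: batch of 2, three planes, 64 features each, 32x32 spatial.
    triplane = torch.randn(2, 3, 64, 32, 32)
    out = estimator(triplane)
    print({name: tuple(t.shape) for name, t in out.items()})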