ahsanMah commited on
Commit
cb54c64
·
1 Parent(s): c2b030e

switching to conditional gaussian

Browse files
Files changed (1) hide show
  1. flowutils.py +63 -7
flowutils.py CHANGED
@@ -5,6 +5,61 @@ import numpy as np
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange, repeat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def build_flows(
@@ -26,13 +81,14 @@ def build_flows(
26
 
27
  # Set base distribution
28
 
29
- # context_encoder = nn.Sequential([
30
- # nn.Linear(context_size, context_size),
31
- # nn.SiLU(),
32
- # nn.Linear(context_size, context_size)
33
- # ])
 
34
 
35
- q0 = nf.distributions.DiagGaussian(latent_size, trainable=True)
36
 
37
  # Construct flow model
38
  model = nf.ConditionalNormalizingFlow(q0, flows)
@@ -239,7 +295,7 @@ class PatchFlow(torch.nn.Module):
239
  context=context_vector,
240
  )
241
 
242
- loss = -torch.mean(flow_model.flow.q0.log_prob(z) + ldj)
243
  loss *= n_patches
244
 
245
  if train:
 
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange, repeat
8
+ from normflows.distributions import BaseDistribution
9
+
10
+
11
+ class ConditionalDiagGaussian(BaseDistribution):
12
+ """
13
+ Conditional multivariate Gaussian distribution with diagonal
14
+ covariance matrix, parameters are obtained by a context encoder,
15
+ context meaning the variable to condition on
16
+ """
17
+
18
+ def __init__(self, shape, context_encoder):
19
+ """Constructor
20
+
21
+ Args:
22
+ shape: Tuple with shape of data, if int shape has one dimension
23
+ context_encoder: Computes mean and log of the standard deviation
24
+ of the Gaussian, mean is the first half of the last dimension
25
+ of the encoder output, log of the standard deviation the second
26
+ half
27
+ """
28
+ super().__init__()
29
+ if isinstance(shape, int):
30
+ shape = (shape,)
31
+ if isinstance(shape, list):
32
+ shape = tuple(shape)
33
+ self.shape = shape
34
+ self.n_dim = len(shape)
35
+ self.d = np.prod(shape)
36
+ self.context_encoder = context_encoder
37
+
38
+ def forward(self, num_samples=1, context=None):
39
+ encoder_output = self.context_encoder(context)
40
+ split_ind = encoder_output.shape[-1] // 2
41
+ mean = encoder_output[..., :split_ind]
42
+ log_scale = encoder_output[..., split_ind:]
43
+ eps = torch.randn(
44
+ (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device
45
+ )
46
+ z = mean + torch.exp(log_scale) * eps
47
+ log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
48
+ log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1))
49
+ )
50
+ return z, log_p
51
+
52
+ def log_prob(self, z, context=None):
53
+ encoder_output = self.context_encoder(context)
54
+ split_ind = encoder_output.shape[-1] // 2
55
+ mean = encoder_output[..., :split_ind]
56
+ log_scale = encoder_output[..., split_ind:]
57
+ log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
58
+ log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2),
59
+ list(range(1, self.n_dim + 1)),
60
+ )
61
+ return log_p
62
+
63
 
64
 
65
  def build_flows(
 
81
 
82
  # Set base distribution
83
 
84
+ context_encoder = nn.Sequential(
85
+ nn.Linear(context_size, context_size),
86
+ nn.SiLU(),
87
+ # output mean and scales for K=latent_size dimensions
88
+ nn.Linear(context_size, latent_size * 2)
89
+ )
90
 
91
+ q0 = ConditionalDiagGaussian(latent_size, context_encoder)
92
 
93
  # Construct flow model
94
  model = nf.ConditionalNormalizingFlow(q0, flows)
 
295
  context=context_vector,
296
  )
297
 
298
+ loss = -torch.mean(flow_model.flow.q0.log_prob(z, context_vector) + ldj)
299
  loss *= n_patches
300
 
301
  if train: