lauracabayol committed
Commit 7b08294 · 1 Parent(s): dabc757

remove DA target redshift

Files changed (1)
  1. temps/temps.py +7 -15
temps/temps.py CHANGED
@@ -31,17 +31,15 @@ class Temps_module():
 
 
 
-    def _get_dataloaders(self, input_data, target_data, input_data_DA, target_data_DA, val_fraction=0.1):
+    def _get_dataloaders(self, input_data, target_data, input_data_DA, val_fraction=0.1):
         input_data = torch.Tensor(input_data)
         target_data = torch.Tensor(target_data)
         if input_data_DA is not None:
             input_data_DA = torch.Tensor(input_data_DA)
-            target_data_DA = torch.Tensor(target_data_DA)
         else:
             input_data_DA = input_data.clone()
-            target_data_DA = target_data.clone()
 
-        dataset = TensorDataset(input_data, input_data_DA, target_data, target_data_DA)
+        dataset = TensorDataset(input_data, input_data_DA, target_data)
         trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
         loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
         loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
@@ -76,7 +74,6 @@ class Temps_module():
     def train(self,input_data,
               input_data_DA,
               target_data,
-              target_data_DA,
               nepochs=10,
               step_size = 100,
              val_fraction=0.1,
@@ -85,7 +82,7 @@ class Temps_module():
         self.modelZ = self.modelZ.train()
         self.modelF = self.modelF.train()
 
-        loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA, target_data_DA, val_fraction=0.1)
+        loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA, val_fraction=0.1)
         optimizerZ = optim.Adam(self.modelZ.parameters(), lr=lr, weight_decay=weight_decay)
         optimizerF = optim.Adam(self.modelF.parameters(), lr=lr, weight_decay=weight_decay)
 
@@ -98,7 +95,7 @@ class Temps_module():
         self.loss_train, self.loss_validation = [],[]
 
         for epoch in range(nepochs):
-            for input_data, input_data_da, target_data, target_data_DA in loader_train:
+            for input_data, input_data_da, target_data in loader_train:
                 _loss_train, _loss_validation = [],[]
 
                 input_data = input_data.to(self.device)
@@ -106,8 +103,8 @@ class Temps_module():
 
                 if self.da:
                     input_data_da = input_data_da.to(self.device)
-                    target_data_DA = target_data_DA.to(self.device)
 
+
                 optimizerF.zero_grad()
                 optimizerZ.zero_grad()
 
@@ -120,12 +117,7 @@ class Temps_module():
                 sig = torch.exp(logsig)
 
                 lossZ = self._loss_function(mu, sig, logmix_coeff, target_data)
-
-                #mu, logsig, logmix_coeff = self.modelZ(features_DA)
-                #logsig = torch.clamp(logsig,-6,2)
-                #sig = torch.exp(logsig)
-
-                #lossZ_DA = self._loss_function(mu, sig, logmix_coeff, target_data_DA)
+
 
                 if self.da:
                     lossDA = maximum_mean_discrepancy(features, features_DA, kernel_type='rbf')
@@ -145,7 +137,7 @@ class Temps_module():
 
             self.loss_train.append(np.mean(_loss_train))
 
-            for input_data, _, target_data, _ in loader_val:
+
+            for input_data, _, target_data in loader_val:
 
                 input_data = input_data.to(self.device)
                 target_data = target_data.to(self.device)
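
For readers skimming the diff: after this commit the domain-adaptation (DA) sample is batched without a redshift target, so the TensorDataset holds three tensors instead of four. A minimal standalone sketch of the resulting dataloader logic (hypothetical function name, with batch_size standing in for self.batch_size, and an exact train/val split rather than the repository's rounding) is:

    # Illustrative sketch only, not the repository's method.
    import torch
    from torch.utils.data import TensorDataset, DataLoader, random_split

    def get_dataloaders_sketch(input_data, target_data, input_data_DA=None,
                               val_fraction=0.1, batch_size=100):
        input_data = torch.Tensor(input_data)
        target_data = torch.Tensor(target_data)
        # The DA sample only needs photometric features; no redshift labels.
        # Note that TensorDataset requires all tensors to share the same length,
        # so input_data_DA must have as many rows as input_data.
        if input_data_DA is not None:
            input_data_DA = torch.Tensor(input_data_DA)
        else:
            input_data_DA = input_data.clone()

        # Three tensors instead of four: target_data_DA is gone.
        dataset = TensorDataset(input_data, input_data_DA, target_data)
        n_val = int(len(dataset) * val_fraction)
        train_set, val_set = random_split(dataset, [len(dataset) - n_val, n_val])
        loader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        loader_val = DataLoader(val_set, batch_size=64, shuffle=True)
        return loader_train, loader_val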
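
In the training loop only the MMD term remains for the DA branch; it compares the two feature distributions, so no redshifts are needed for the DA sample, which is why the commented-out lossZ_DA block could be dropped. The repository's own maximum_mean_discrepancy is used there; the generic RBF-kernel sketch below is a stand-in shown only to make the label-free nature of that term explicit, not the project's implementation:

    import torch

    def rbf_mmd_sketch(x, y, sigma=1.0):
        # Biased MMD^2 estimate with a Gaussian (RBF) kernel; illustrative only,
        # not the project's maximum_mean_discrepancy implementation.
        def rbf(a, b):
            d2 = torch.cdist(a, b) ** 2          # pairwise squared distances
            return torch.exp(-d2 / (2 * sigma ** 2))
        return rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean()

    # Hypothetical use inside one training step: features come from the labelled
    # source batch, features_DA from the unlabelled DA batch.
    # lossDA = rbf_mmd_sketch(features, features_DA)
    # loss = lossZ + lossDA   # lossZ still uses target_data from the source set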