Delete loss
Browse files- loss/__init__.py +0 -0
- loss/__pycache__/__init__.cpython-39.pyc +0 -0
- loss/__pycache__/loss.cpython-39.pyc +0 -0
- loss/loss.py +0 -55
loss/__init__.py
DELETED
File without changes
|
loss/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (124 Bytes)
|
|
loss/__pycache__/loss.cpython-39.pyc
DELETED
Binary file (2.16 kB)
|
|
loss/loss.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
|
4 |
-
class Loss_VAE(nn.Module):
|
5 |
-
def __init__(self):
|
6 |
-
super().__init__()
|
7 |
-
self.mse = nn.MSELoss(reduction='sum')
|
8 |
-
|
9 |
-
def forward(self, recon_x, x, mu, log_var):
|
10 |
-
mse = self.mse(recon_x, x)
|
11 |
-
kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
|
12 |
-
loss = mse + kld
|
13 |
-
return loss
|
14 |
-
|
15 |
-
|
16 |
-
def DiceScore(
|
17 |
-
y_pred: torch.Tensor,
|
18 |
-
y: torch.Tensor,
|
19 |
-
include_background: bool = True,
|
20 |
-
) -> torch.Tensor:
|
21 |
-
"""Computes Dice score metric from full size Tensor and collects average.
|
22 |
-
Args:
|
23 |
-
y_pred: input data to compute, typical segmentation model output.
|
24 |
-
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
|
25 |
-
should be binarized.
|
26 |
-
y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
|
27 |
-
The values should be binarized.
|
28 |
-
include_background: whether to skip Dice computation on the first channel of
|
29 |
-
the predicted output. Defaults to True.
|
30 |
-
Returns:
|
31 |
-
Dice scores per batch and per class, (shape [batch_size, num_classes]).
|
32 |
-
Raises:
|
33 |
-
ValueError: when `y_pred` and `y` have different shapes.
|
34 |
-
"""
|
35 |
-
|
36 |
-
y = y.float()
|
37 |
-
y_pred = y_pred.float()
|
38 |
-
|
39 |
-
if y.shape != y_pred.shape:
|
40 |
-
raise ValueError("y_pred and y should have same shapes.")
|
41 |
-
|
42 |
-
# reducing only spatial dimensions (not batch nor channels)
|
43 |
-
n_len = len(y_pred.shape)
|
44 |
-
reduce_axis = list(range(2, n_len))
|
45 |
-
intersection = torch.sum(y * y_pred, dim=reduce_axis)
|
46 |
-
|
47 |
-
y_o = torch.sum(y, reduce_axis)
|
48 |
-
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
|
49 |
-
denominator = y_o + y_pred_o
|
50 |
-
|
51 |
-
return torch.where(
|
52 |
-
denominator > 0,
|
53 |
-
(2.0 * intersection) / denominator,
|
54 |
-
torch.tensor(float("1"), device=y_o.device),
|
55 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|