alessandro trinca tornidor commited on
Commit
4364464
·
1 Parent(s): 540a51f

feat: use setup_logging() in routes.py, update docstrings

Browse files
Files changed (2) hide show
  1. lisa_on_cuda/LISA.py +18 -2
  2. lisa_on_cuda/routes.py +12 -5
lisa_on_cuda/LISA.py CHANGED
@@ -16,15 +16,27 @@ def dice_loss(
16
  num_masks: float,
17
  scale=1000, # 100000.0,
18
  eps=1e-6,
19
- ):
20
  """
21
- Compute the DICE loss, similar to generalized IOU for masks
 
 
 
 
 
22
  Args:
23
  inputs: A float tensor of arbitrary shape.
24
  The predictions for each example.
25
  targets: A float tensor with the same shape as inputs. Stores the binary
26
  classification label for each element in inputs
27
  (0 for the negative class and 1 for the positive class).
 
 
 
 
 
 
 
28
  """
29
  inputs = inputs.sigmoid()
30
  inputs = inputs.flatten(1, 2)
@@ -32,7 +44,9 @@ def dice_loss(
32
  numerator = 2 * (inputs / scale * targets).sum(-1)
33
  denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
34
  loss = 1 - (numerator + eps) / (denominator + eps)
 
35
  loss = loss.sum() / (num_masks + 1e-8)
 
36
  return loss
37
 
38
 
@@ -48,6 +62,8 @@ def sigmoid_ce_loss(
48
  targets: A float tensor with the same shape as inputs. Stores the binary
49
  classification label for each element in inputs
50
  (0 for the negative class and 1 for the positive class).
 
 
51
  Returns:
52
  Loss tensor
53
  """
 
16
  num_masks: float,
17
  scale=1000, # 100000.0,
18
  eps=1e-6,
19
+ ) -> torch.Tensor:
20
  """
21
+ Compute the DICE loss, similar to generalized IOU for masks.
22
+ Arguments 'num_masks', 'scale', 'eps' and return value 'loss' are undocumented in original project
23
+ https://github.com/dvlab-research/LISA
24
+ About 'num_masks': it's similar to 'avg_factor' in weight_reduce_loss() from
25
+ https://github.com/open-mmlab/mmdetection/blob/e9cae2d0787cd5c2fc6165a6061f92fa09e48fb1/mmdet/models/losses/utils.py#L30
26
+
27
  Args:
28
  inputs: A float tensor of arbitrary shape.
29
  The predictions for each example.
30
  targets: A float tensor with the same shape as inputs. Stores the binary
31
  classification label for each element in inputs
32
  (0 for the negative class and 1 for the positive class).
33
+ num_masks: Average factor when computing the mean of losses (?)
34
+ scale: weight factor applied before computing mean of losses (?)
35
+ eps: Avoid dividing by zero (?)
36
+
37
+ return:
38
+ Processed loss values.
39
+
40
  """
41
  inputs = inputs.sigmoid()
42
  inputs = inputs.flatten(1, 2)
 
44
  numerator = 2 * (inputs / scale * targets).sum(-1)
45
  denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
46
  loss = 1 - (numerator + eps) / (denominator + eps)
47
+
48
  loss = loss.sum() / (num_masks + 1e-8)
49
+
50
  return loss
51
 
52
 
 
62
  targets: A float tensor with the same shape as inputs. Stores the binary
63
  classification label for each element in inputs
64
  (0 for the negative class and 1 for the positive class).
65
+ num_masks: Average factor when computing the mean of losses (?)
66
+
67
  Returns:
68
  Loss tensor
69
  """
lisa_on_cuda/routes.py CHANGED
@@ -1,19 +1,26 @@
1
  import json
2
- import logging
 
 
 
3
  from fastapi import APIRouter
4
 
5
- from lisa_on_cuda.utils import session_logger
 
6
 
 
7
 
 
 
 
8
  router = APIRouter()
9
 
10
 
11
  @router.get("/health")
12
- @session_logger.set_uuid_logging
13
  def health() -> str:
14
  try:
15
- logging.info("health check")
16
  return json.dumps({"msg": "ok"})
17
  except Exception as e:
18
- logging.error(f"exception:{e}.")
19
  return json.dumps({"msg": "request failed"})
 
1
  import json
2
+ import os
3
+
4
+ import structlog
5
+ from dotenv import load_dotenv
6
  from fastapi import APIRouter
7
 
8
+ from samgis_core.utilities.session_logger import setup_logging
9
+
10
 
11
+ load_dotenv()
12
 
13
+ log_level = os.getenv("LOG_LEVEL", "INFO")
14
+ setup_logging(log_level=log_level)
15
+ app_logger = structlog.stdlib.get_logger()
16
  router = APIRouter()
17
 
18
 
19
  @router.get("/health")
 
20
  def health() -> str:
21
  try:
22
+ app_logger.info("health check")
23
  return json.dumps({"msg": "ok"})
24
  except Exception as e:
25
+ app_logger.error(f"exception:{e}.")
26
  return json.dumps({"msg": "request failed"})