geekyrakshit commited on
Commit
ffad30c
1 Parent(s): 021f172

added mixed precision

Browse files
enhance_me/mirnet/mirnet.py CHANGED
@@ -5,7 +5,7 @@ from typing import List
5
  from datetime import datetime
6
 
7
  from tensorflow import keras
8
- from tensorflow.keras import optimizers, models
9
 
10
  from wandb.keras import WandbCallback
11
 
@@ -59,12 +59,16 @@ class MIRNet:
59
 
60
  def build_model(
61
  self,
 
62
  num_recursive_residual_groups: int = 3,
63
  num_multi_scale_residual_blocks: int = 2,
64
  channels: int = 64,
65
  learning_rate: float = 1e-4,
66
  epsilon: float = 1e-3,
67
  ):
 
 
 
68
  self.model = build_mirnet_model(
69
  num_rrg=num_recursive_residual_groups,
70
  num_mrb=num_multi_scale_residual_blocks,
 
5
  from datetime import datetime
6
 
7
  from tensorflow import keras
8
+ from tensorflow.keras import optimizers, models, mixed_precision
9
 
10
  from wandb.keras import WandbCallback
11
 
 
59
 
60
  def build_model(
61
  self,
62
+ use_mixed_precision: bool = False,
63
  num_recursive_residual_groups: int = 3,
64
  num_multi_scale_residual_blocks: int = 2,
65
  channels: int = 64,
66
  learning_rate: float = 1e-4,
67
  epsilon: float = 1e-3,
68
  ):
69
+ if use_mixed_precision:
70
+ policy = mixed_precision.Policy("mixed_float16")
71
+ mixed_precision.set_global_policy(policy)
72
  self.model = build_mirnet_model(
73
  num_rrg=num_recursive_residual_groups,
74
  num_mrb=num_multi_scale_residual_blocks,
notebooks/enhance_me_train.ipynb CHANGED
@@ -22,7 +22,7 @@
22
  },
23
  "outputs": [],
24
  "source": [
25
- "!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
26
  "!pip install -qqq wandb streamlit"
27
  ]
28
  },
 
22
  },
23
  "outputs": [],
24
  "source": [
25
+ "!git clone https://github.com/soumik12345/enhance-me\n",
26
  "!pip install -qqq wandb streamlit"
27
  ]
28
  },