Bokanovskii commited on
Commit
6be2a43
·
1 Parent(s): a2982e9

Upload shred_model.py

Browse files
Files changed (1) hide show
  1. shred_model.py +109 -0
shred_model.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow.keras as keras
2
+ import tensorflow as tf
3
+
4
+ import PIL.Image
5
+ import PIL.ImageOps
6
+
7
+ import numpy as np
8
+
9
+ IMG_SIZE = [256,256]
10
+
11
+ def prepare_image(path):
12
+ # Load the image with PIL
13
+ img = PIL.Image.open(path)
14
+ img, rotated = exif_transpose(img)
15
+ img = img.resize(IMG_SIZE)
16
+ return np.expand_dims(np.asarray(img), axis=0)
17
+
18
+ # def prepare_model(checkpoint_folder_path):
19
+ # base_model = keras.applications.EfficientNetB7(
20
+ # weights='imagenet',
21
+ # include_top=False,
22
+ # input_shape=tuple(IMG_SIZE + [3])
23
+ # )
24
+ # base_model.trainable = True
25
+
26
+ # model = keras.Sequential()
27
+ # model.add(keras.Input(shape=tuple(IMG_SIZE + [3])))
28
+ # model.add(keras.layers.RandomFlip("horizontal"))
29
+ # model.add(keras.layers.RandomRotation(0.1))
30
+ # model.add(base_model)
31
+ # model.add(keras.layers.GlobalMaxPooling2D())
32
+ # model.add(keras.layers.Dense(1, activation='sigmoid'))
33
+
34
+ # model.compile(optimizer=keras.optimizers.Adam(1e-5), # Low learning rate
35
+ # loss=keras.losses.BinaryCrossentropy(from_logits=False),
36
+ # metrics=[keras.metrics.BinaryAccuracy(), 'Precision', 'Recall',
37
+ # tf.keras.metrics.SpecificityAtSensitivity(.9)],)
38
+ # model.load_weights(checkpoint_folder_path)
39
+ # return model
40
+
41
+ def prepare_EfficientNet_model(base_trainable=False, fine_tuning=False):
42
+ base_model = keras.applications.EfficientNetB7(
43
+ weights="imagenet",
44
+ include_top=False,
45
+ input_shape=tuple(IMG_SIZE + [3])
46
+ )
47
+ base_model.trainable = False
48
+
49
+ model = keras.Sequential()
50
+ model.add(keras.Input(shape=tuple(IMG_SIZE + [3])))
51
+ model.add(keras.layers.RandomFlip("horizontal"))
52
+ model.add(keras.layers.RandomRotation(0.1))
53
+ model.add(base_model)
54
+ model.add(keras.layers.GlobalMaxPooling2D())
55
+ model.add(keras.layers.Dense(1, activation='sigmoid'))
56
+
57
+ if not fine_tuning:
58
+ if not base_trainable:
59
+ base_model.trainable = False
60
+ model.compile(optimizer=keras.optimizers.Adam(),
61
+ loss=keras.losses.BinaryCrossentropy(from_logits=False),
62
+ metrics=[keras.metrics.BinaryAccuracy(), 'Precision', 'Recall'],)
63
+ else:
64
+ base_model.trainable = True
65
+ model.compile(optimizer=keras.optimizers.Adam(1e-5), # Low learning rate
66
+ loss=keras.losses.BinaryCrossentropy(from_logits=False),
67
+ metrics=[keras.metrics.BinaryAccuracy(), 'Precision', 'Recall',
68
+ tf.keras.metrics.SpecificityAtSensitivity(.9)],)
69
+ return model
70
+
71
+ def exif_transpose(img):
72
+ if not img:
73
+ return img
74
+
75
+ exif_orientation_tag = 274
76
+
77
+ # Check for EXIF data (only present on some files)
78
+ if hasattr(img, "_getexif") and isinstance(img._getexif(), dict) and exif_orientation_tag in img._getexif():
79
+ exif_data = img._getexif()
80
+ orientation = exif_data[exif_orientation_tag]
81
+
82
+ # Handle EXIF Orientation
83
+ if orientation == 1:
84
+ # Normal image - nothing to do!
85
+ pass
86
+ elif orientation == 2:
87
+ # Mirrored left to right
88
+ img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
89
+ elif orientation == 3:
90
+ # Rotated 180 degrees
91
+ img = img.rotate(180)
92
+ elif orientation == 4:
93
+ # Mirrored top to bottom
94
+ img = img.rotate(180).transpose(PIL.Image.FLIP_LEFT_RIGHT)
95
+ elif orientation == 5:
96
+ # Mirrored along top-left diagonal
97
+ img = img.rotate(-90, expand=True).transpose(PIL.Image.FLIP_LEFT_RIGHT)
98
+ elif orientation == 6:
99
+ # Rotated 90 degrees
100
+ img = img.rotate(-90, expand=True)
101
+ elif orientation == 7:
102
+ # Mirrored along top-right diagonal
103
+ img = img.rotate(90, expand=True).transpose(PIL.Image.FLIP_LEFT_RIGHT)
104
+ elif orientation == 8:
105
+ # Rotated 270 degrees
106
+ img = img.rotate(90, expand=True)
107
+ return img, True
108
+ return img, False
109
+