PeteBleackley committed on
Commit
064f030
·
1 Parent(s): 646cf92

HierarchicalLogits layer and function to create base models

Browse files
src/models/layers/HierarchicalLogits.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Thu Aug 19 15:22:21 2021
5
+
6
+ @author: peter
7
+ """
8
+
9
+ import keras
10
+ import tensorflow
11
+
12
class LeafNode(keras.layers.Layer):
    """Terminal node of the logit tree: a single trainable bias.

    The layer ignores the values of its input and always returns its bias,
    a tensor of shape ``(1,)`` that the caller broadcasts as needed.
    """

    def __init__(self, **kwargs):
        """Create the leaf and its bias weight.

        Args:
            **kwargs: standard ``keras.layers.Layer`` options (``name``,
                ``dtype``, ...), forwarded to the base class.
        """
        # The base Layer must be initialised before a weight can be
        # registered or an attribute holding one can be assigned.
        super().__init__(**kwargs)
        self.bias = self.add_weight(shape=(1,),
                                    initializer='random_normal',
                                    trainable=True)

    def build(self, input_shape):
        """Mark the layer as built; no weight depends on ``input_shape``."""
        super().build(input_shape)

    def call(self, X, training=None):
        """Return the bias; ``X`` and ``training`` are unused."""
        return self.bias
23
+
24
class HierarchicalLogits(keras.layers.Layer):
    """Produce ``n`` logits from a binary tree of learned projections.

    Every internal node owns a trainable vector ``normal`` whose length is
    the size of the last input axis.  The projection of the input onto
    ``normal`` is added to all logits of the left subtree and subtracted
    from all logits of the right subtree.  The tree is split as evenly as
    possible (``n // 2`` outputs on the left, the rest on the right) and
    each leaf is a ``LeafNode`` contributing a trainable bias.
    """

    def __init__(self, n, **kwargs):
        """Build the subtree that yields ``n`` outputs.

        Args:
            n: number of logits produced by this node; must be at least 2.
            **kwargs: standard ``keras.layers.Layer`` options (``name``,
                ``dtype``, ...), forwarded to the base class.

        Raises:
            ValueError: if ``n`` is smaller than 2.  Such a node would have
                an empty left branch and the constructor would recurse
                without ever terminating.
        """
        if n < 2:
            raise ValueError(
                'HierarchicalLogits needs at least 2 outputs, got %r' % (n,))
        super().__init__(**kwargs)
        self.normal = None  # created in build(), once the input width is known
        self.n_outputs = n

        n_left = n // 2
        n_right = n - n_left
        self.left = LeafNode() if n_left == 1 else HierarchicalLogits(n_left)
        self.right = LeafNode() if n_right == 1 else HierarchicalLogits(n_right)
        self.concat = keras.layers.Concatenate()

    def build(self, input_shape):
        """Create ``normal`` and build both children for ``input_shape``."""
        self.normal = self.add_weight(shape=(input_shape[-1],),
                                      initializer='random_normal',
                                      trainable=True)
        self.left.build(input_shape)
        self.right.build(input_shape)
        # Flag this node as built so that a parent which built it by hand
        # does not trigger a second build (and duplicate weights) when the
        # node is first called.
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last axis replaced by ``n``."""
        return tuple(input_shape[:-1]) + (self.n_outputs,)

    def call(self, X, training=None):
        """Return the logits for ``X``.

        Args:
            X: tensor whose last axis has the width seen in ``build``.
            training: unused.

        Returns:
            Tensor shaped like ``X`` with the last axis of size ``n``.
        """
        # Contract the last axis of X with the normal vector, then restore a
        # trailing axis of size 1 so the projection broadcasts over every
        # logit of a subtree (and over a leaf's (1,) bias).
        y = tensorflow.expand_dims(tensorflow.tensordot(X, self.normal, 1), -1)
        # The left branch is shifted up and the right branch down by the
        # same projection.
        return self.concat([self.left(X) + y, self.right(X) - y])

    def get_config(self):
        """Return the base layer config extended with ``n``."""
        config = super().get_config()
        config['n'] = self.n_outputs
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate the layer from ``get_config`` output.

        A dict holding only ``'n'`` is also accepted.
        """
        params = dict(config)  # do not mutate the caller's dict
        n = params.pop('n')
        return cls(n, **params)
68
+
69
+
src/models/quarac_base_model.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Wed Aug 23 09:50:14 2023
5
+
6
+ @author: peter
7
+ """
8
+
9
+ import keras
10
+ import layers
11
+
12
def quarac_base_model(vocab_size, width, depth, decoder=True):
    """Assemble the base sequential model.

    The stack is: a token embedding, ``depth`` ``layers.HyenaLayer`` blocks,
    a per-timestep ``layers.HierarchicalLogits`` head and a per-timestep
    softmax.

    Args:
        vocab_size: number of tokens; used as the embedding input size and
            as the number of logits produced by the head.
        width: embedding dimension.
        depth: number of ``HyenaLayer`` blocks.
        decoder: forwarded to ``HyenaLayer`` as ``causal``.

    Returns:
        A ``keras.models.Sequential`` built from the layers above.
    """
    # NOTE(review): assumes the `layers` package exports HyenaLayer and the
    # HierarchicalLogits class (not just the module of that name) - confirm.
    stack = [keras.layers.Embedding(vocab_size, width)]
    for _ in range(depth):
        stack.append(layers.HyenaLayer(causal=decoder))
    # The head needs its output count; one logit per vocabulary entry feeds
    # the softmax that follows.
    logits_head = layers.HierarchicalLogits(vocab_size)
    stack.append(keras.layers.TimeDistributed(logits_head))
    stack.append(keras.layers.TimeDistributed(keras.layers.Softmax()))
    return keras.models.Sequential(stack)