PeteBleackley committed
Commit 8f1745b
1 Parent(s): df051eb

GlobalAttentionPoolingHead layer

.ipynb_checkpoints/Model visualisation-checkpoint.ipynb DELETED
@@ -1,6 +0,0 @@
- {
-  "cells": [],
-  "metadata": {},
-  "nbformat": 4,
-  "nbformat_minor": 5
- }
qarac/models/layers/GlobalAttentionPoolingHead.py ADDED
@@ -0,0 +1,48 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Tue Sep 5 07:32:55 2023
+
+ @author: peter
+ """
+
+ import keras
+ import tensorflow
+
+ class GlobalAttentionPoolingHead(keras.layers.Layer):
+     """Pools a ragged batch of token vectors into a single vector per sample."""
+
+     def __init__(self):
+         super(GlobalAttentionPoolingHead,self).__init__()
+         self.global_projection = None
+         self.local_projection = None
+
+     def build(self,input_shape):
+         # Create the global and local projection matrices.
+         width = input_shape[-1]
+         self.global_projection = self.add_weight('global projection',shape=(width,width))
+         self.local_projection = self.add_weight('local projection',shape=(width,width))
+         self.built = True
+
+     @tensorflow.function
+     def project(self,X):
+         return tensorflow.tensordot(X,self.local_projection,axes=1)
+
+     def attention_function(self,gp):
+         @tensorflow.function
+         def inner(lp):
+             return tensorflow.tensordot(lp,gp,axes=1)
+         return inner
+
+     def call(self,X,training=None):
+         # Global summary: project the summed token vectors and normalise.
+         gp = tensorflow.linalg.l2_normalize(tensorflow.tensordot(tensorflow.reduce_sum(X,
+                                                                                        axis=1),
+                                                                  self.global_projection,
+                                                                  axes=1),
+                                             axis=1)
+         # Local view: project and normalise each token vector.
+         lp = tensorflow.linalg.l2_normalize(tensorflow.ragged.map_flat_values(self.project,
+                                                                               X),
+                                             axis=2)
+         # Attention weights from the local and global projections,
+         # used to weight the sum over tokens.
+         attention = tensorflow.ragged.map_flat_values(self.attention_function(gp),
+                                                       lp)
+         return tensorflow.reduce_sum(attention * X,
+                                      axis=1)
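
The new head scores each token's locally projected vector against a projected, L2-normalised summary of its whole sequence, and uses those scores to weight a sum over the tokens. As a rough illustration of that idea only (not part of this commit, and using plain dense tensors instead of the ragged tensors the layer expects), a minimal sketch could look like this:

import tensorflow

def global_attention_pool(X, local_projection, global_projection):
    # X: dense (batch, tokens, width) embeddings; both projections are (width, width).
    summary = tensorflow.reduce_sum(X, axis=1)                                  # (batch, width)
    gp = tensorflow.linalg.l2_normalize(summary @ global_projection, axis=-1)   # (batch, width)
    lp = tensorflow.linalg.l2_normalize(tensorflow.einsum('btw,wv->btv', X, local_projection),
                                        axis=-1)                                # (batch, tokens, width)
    attention = tensorflow.einsum('btw,bw->bt', lp, gp)                         # (batch, tokens)
    return tensorflow.einsum('bt,btw->bw', attention, X)                        # (batch, width)

pooled = global_attention_pool(tensorflow.random.normal((2, 5, 8)),
                               tensorflow.random.normal((8, 8)),
                               tensorflow.random.normal((8, 8)))
print(pooled.shape)  # (2, 8)
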
qarac/models/layers/HyenaLayer.py CHANGED
@@ -11,22 +11,33 @@ import keras_nlp
import tensorflow
import warnings

+
@tensorflow.function
def convolve(x,y):
-
-     fx = tensorflow.vectorized_map(fft, x, warn=False)
-     fy = tensorflow.vectorized_map(fft, y, warn=False)
+     # Transpose so the sequence axis is last, multiply the real FFTs,
+     # then invert the FFT and transpose back.
+     xT = tensorflow.vectorized_map(tensorflow.transpose, x)
+     yT = tensorflow.vectorized_map(tensorflow.transpose, y)
+     fx = tensorflow.vectorized_map(tensorflow.signal.rfft, xT)
+     fy = tensorflow.vectorized_map(tensorflow.signal.rfft, yT)
    fz = fx*fy
-     return tensorflow.vectorized_map(ifft,fz,warn=False)
+     zT = tensorflow.vectorized_map(tensorflow.signal.irfft, fz)
+     return tensorflow.vectorized_map(tensorflow.transpose, zT)

- @tensorflow.function
- def fft(x):
-     return tensorflow.signal.rfft(tensorflow.transpose(x))
+ # @tensorflow.function
+ # def fft(x):
+ #     return tensorflow.signal.rfft(tensorflow.transpose(x))

- @tensorflow.function
- def ifft(x):
-     return tensorflow.transpose(tensorflow.signal.irfft(x))
+ # @tensorflow.function
+ # def ifft(x):
+ #     return tensorflow.transpose(tensorflow.signal.irfft(x))

+ @tensorflow.function
+ def pad(x):
+     # Zero-pad along the sequence axis so that circular FFT convolution
+     # reproduces a linear (causal) convolution.
+     return tensorflow.concat([x,tensorflow.zeros_like(x)],0)
+
+ @tensorflow.function()
+ def truncate(args):
+     # Crop a padded sequence back to its original length.
+     (data,length) = args
+     return data[:length]

class HyenaLayer(keras.layers.Layer):
    """Keras implementation of Hyena layer. Unlike in the original paper,
@@ -77,24 +88,31 @@ class HyenaLayer(keras.layers.Layer):
                                              trainable=True)
        self.filters = self.add_weight(shape=(width,width,self.stages),
                                       trainable=True)
+         self.built = True
+
+     def compute_output_shape(self,input_shape):
+         return input_shape
+
+     @tensorflow.function
+     def project(self,x):
+         return tensorflow.tensordot(x,self.data_projection,axes=1)
+
+     @tensorflow.function
+     def generate_filters(self,t):
+         return tensorflow.tensordot(t,self.filters,axes=1)

    def call(self,X,training=None):
-         x_flat = tensorflow.tensordot(X.flat_values,
-                                       self.data_projection,
-                                       axes=1)
-         f_flat = tensorflow.tensordot(self.positional_encoding(X).flat_values,
-                                       self.filters,
-                                       axes=1)
-         x = tensorflow.RaggedTensor.from_row_lengths(x_flat,X.row_lengths())
-         f = tensorflow.RaggedTensor.from_row_lengths(f_flat,X.row_lengths())
+         # Apply the data projection and filter generation to the ragged values.
+         x = tensorflow.ragged.map_flat_values(self.project, X)
+         f = tensorflow.ragged.map_flat_values(self.generate_filters,self.positional_encoding(X))
        if self.causal:
-             concat = keras.layers.Concatenate()
-             x = concat(x,tensorflow.zeros_like(x))
-             f = concat(f,tensorflow.zeros_like(f))
+             x = tensorflow.vectorized_map(pad,x)
+             f = tensorflow.vectorized_map(pad,f)
        y = x[:,:,:,0]
        for i in tensorflow.range(self.stages):
            y = convolve(y,f[:,:,:,i])*x[:,:,:,i+1]
        if self.causal:
-             for (i,n) in enumerate(X.row_lengths()):
-                 y[i] = y[i,:n]
-         return y
+             y = tensorflow.vectorized_map(truncate,(y,X.row_lengths()))
+         return tensorflow.raw_ops.RaggedTensorToVariant(rt_nested_splits=y.row_splits,
+                                                         rt_dense_values=y.flat_values,
+                                                         batched_input=True)
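
A note on the rewritten convolve and the new pad/truncate helpers above: multiplying real FFTs computes a circular convolution, so the causal branch zero-pads each sequence to double its length before convolving and crops the result back afterwards, which recovers an ordinary linear convolution. A small stand-alone check of that equivalence (illustrative only, not part of the commit):

import numpy

n = 8
x = numpy.random.randn(n)
f = numpy.random.randn(n)

# Zero-pad to length 2n, convolve by multiplying real FFTs, then crop:
# this matches the first n samples of the linear convolution, which is
# exactly what the causal path needs.
xp = numpy.concatenate([x, numpy.zeros(n)])
fp = numpy.concatenate([f, numpy.zeros(n)])
causal = numpy.fft.irfft(numpy.fft.rfft(xp) * numpy.fft.rfft(fp), 2 * n)[:n]

print(numpy.allclose(causal, numpy.convolve(x, f)[:n]))  # True
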
scripts.py CHANGED
@@ -7,7 +7,9 @@ import qarac.corpora.BNCorpus
import qarac.corpora.Batcher
import qarac.models.qarac_base_model
import keras
+ import tensorflow

+ #tensorflow.debugging.disable_traceback_filtering()
