Spaces: Build error

Commit 8f1745b · GlobalAttentionPoolingHead layer
PeteBleackley committed · 1 Parent(s): df051eb
.ipynb_checkpoints/Model visualisation-checkpoint.ipynb DELETED
@@ -1,6 +0,0 @@
-{
- "cells": [],
- "metadata": {},
- "nbformat": 4,
- "nbformat_minor": 5
-}
qarac/models/layers/GlobalAttentionPoolingHead.py ADDED
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Sep 5 07:32:55 2023
+
+@author: peter
+"""
+
+import keras
+import tensorflow
+
+class GlobalAttentionPoolingHead(keras.layers.Layer):
+
+    def __init__(self):
+        super(GlobalAttentionPoolingHead,self).__init__()
+        self.global_projection = None
+        self.local_projection = None
+
+
+    def build(self,input_shape):
+        width = input_shape[-1]
+        self.global_projection = self.add_weight('global projection',shape=(width,width))
+        self.local_projection = self.add_weight('local projection',shape=(width,width))
+        self.build=True
+
+    @tensorflow.function
+    def project(self,X):
+        return tensorflow.tensordot(X,self.local_projection,axes=1)
+
+    def attention_function(self,gp):
+        @tensorflow.function
+        def inner(lp):
+            return tensorflow.tensordot(lp,gp,axes=1)
+        return inner
+
+    def call(self,X,training=None):
+        gp = tensorflow.linalg.l2_normalize(tensorflow.tensordot([tensorflow.reduce_sum(X,
+                                                                                        axis=1),
+                                                                  self.global_projection],
+                                                                 axes=1),
+                                            axis=1)
+        lp = tensorflow.linalg.l2_normalize(tensorflow.ragged.map_flat_values(self.project,
+                                                                              X),
+                                            axis=2)
+        attention = tensorflow.ragged.map_flat_values(self.attention_function(gp),
+                                                      lp)
+        return tensorflow.reduce_sum(attention *X,
+                                     axis=1)
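The new head appears to implement attention pooling over a ragged batch of token vectors: sum each document's tokens and project the sum globally, project each token locally, use the dot product of the two normalised projections as a per-token attention weight, and return the attention-weighted sum. Two details in the added file look likely to contribute to the build error flagged above: build() assigns self.build = True, shadowing the build method instead of setting the built flag that Keras checks, and the tensordot in call() wraps its operands in a single list, so tensordot never receives its second argument. The following is a minimal standalone sketch of that intended computation, not the committed layer: the identity projections, the ragged example X, and the gather/segment_sum bookkeeping are illustrative stand-ins.

import tensorflow

width = 4
# Hypothetical ragged batch: two "documents" of 3 and 5 token vectors
X = tensorflow.ragged.constant([[[1.0]*width]*3, [[0.5]*width]*5], ragged_rank=1)
global_projection = tensorflow.eye(width)   # stand-in for the trained weight
local_projection = tensorflow.eye(width)    # stand-in for the trained weight

# One normalised global vector per document (operands passed separately to tensordot)
gp = tensorflow.linalg.l2_normalize(
    tensorflow.tensordot(tensorflow.reduce_sum(X, axis=1), global_projection, axes=1),
    axis=1)
# One normalised local vector per token
lp = tensorflow.ragged.map_flat_values(
    lambda v: tensorflow.linalg.l2_normalize(
        tensorflow.tensordot(v, local_projection, axes=1), axis=1),
    X)
# Dot each token's local vector with its own document's global vector
gp_per_token = tensorflow.gather(gp, X.value_rowids())
attention = tensorflow.reduce_sum(lp.flat_values * gp_per_token, axis=1)
# Attention-weighted sum of the original token vectors, per document
pooled = tensorflow.math.segment_sum(attention[:, None] * X.flat_values,
                                     X.value_rowids())
print(pooled.shape)    # (2, 4): one pooled vector per document

Inside the layer itself, the same two fixes would presumably amount to setting self.built = True in build() and calling tensorflow.tensordot(tensorflow.reduce_sum(X, axis=1), self.global_projection, axes=1) in call().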
qarac/models/layers/HyenaLayer.py CHANGED
@@ -11,22 +11,33 @@ import keras_nlp
 import tensorflow
 import warnings
 
+
 @tensorflow.function
 def convolve(x,y):
-
-
-
+    xT = tensorflow.vectorized_map(tensorflow.transpose, x)
+    yT = tensorflow.vectorized_map(tensorflow.transpose, y)
+    fx = tensorflow.vectorized_map(tensorflow.signal.rfft, xT)
+    fy = tensorflow.vectorized_map(tensorflow.signal.rfft, yT)
     fz = fx*fy
-
+    zT = tensorflow.vectorized_map(tensorflow.signal.irfft, fz)
+    return tensorflow.vectorized_map(tensorflow.transpose,zT)
 
-@tensorflow.function
-def fft(x):
-
+# @tensorflow.function
+# def fft(x):
+#     return tensorflow.signal.rfft(tensorflow.transpose(x))
 
-@tensorflow.function
-def ifft(x):
-
+# @tensorflow.function
+# def ifft(x):
+#     return tensorflow.transpose(tensorflow.signal.irfft(x))
 
+@tensorflow.function
+def pad(x):
+    return tensorflow.concat([x,tensorflow.zeros_like(x)],0)
+
+@tensorflow.function()
+def truncate(args):
+    (data,length)=args
+    return data[:length]
 
 class HyenaLayer(keras.layers.Layer):
     """Keras implementation of Hyena layer. Unlike in the original paper,
@@ -77,24 +88,31 @@
                                        trainable=True)
         self.filters = self.add_weight(shape=(width,width,self.stages),
                                        trainable=True)
+        self.built = True
+
+    def conpute_output_shape(self,input_shape):
+        return input_shape
+
+    @tensorflow.function
+    def project(self,x):
+        return tensorflow.tensordot(x,self.data_projection,axes=1)
+
+    @tensorflow.function
+    def generate_filters(self,t):
+        return tensorflow.tensordot(t, self.filters,axes=1)
 
     def call(self,X,training=None):
-
-
-
-        f_flat = tensorflow.tensordot(self.positional_encoding(X).flat_values,
-                                      self.filters,
-                                      axes=1)
-        x = tensorflow.RaggedTensor.from_row_lengths(x_flat,X.row_lengths())
-        f = tensorflow.RaggedTensor.from_row_lengths(f_flat,X.row_lengths())
+
+        x = tensorflow.ragged.map_flat_values(self.project, X)
+        f = tensorflow.ragged.map_flat_values(self.generate_filters,self.positional_encoding(X))
         if self.causal:
-
-
-            f = concat(f,tensorflow.zeros_like(f))
+            x = tensorflow.vectorize_map(pad,x)
+            f = tensorflow.vectorize_map(pad,f)
         y = x[:,:,:,0]
         for i in tensorflow.range(self.stages):
            y = convolve(y,f[:,:,:,i])*x[:,:,:,i+1]
         if self.causal:
-
-
-
+            y = tensorflow.vectorized_map(truncate,(y,X.row_lengths()))
+        return tensorflow.raw_ops.RaggedTensorToVariant(rt_nested_splits=y.row_splits,
+                                                        rt_dense_values=y.flat_values,
+                                                        batched_input=True)
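The rewritten convolve helper now performs the FFT-based convolution itself: it transposes each example so the sequence axis is last, multiplies the rfft of the two operands, and transforms back with irfft, which gives a circular convolution along the sequence axis; pad and truncate presumably exist so the causal branch can zero-pad to double length before convolving and cut each sequence back to its own length afterwards. Note that conpute_output_shape and tensorflow.vectorize_map in the added lines appear to be typos for compute_output_shape and tensorflow.vectorized_map. Below is a standalone sketch of the convolve helper on dense inputs, assuming shape (batch, sequence, width) with an even sequence length so that rfft/irfft round-trip to the original length; the random inputs are purely illustrative.

import tensorflow

@tensorflow.function
def convolve(x, y):
    # Move the sequence axis last, convolve each feature channel in the
    # frequency domain, then move the sequence axis back
    xT = tensorflow.vectorized_map(tensorflow.transpose, x)
    yT = tensorflow.vectorized_map(tensorflow.transpose, y)
    fx = tensorflow.vectorized_map(tensorflow.signal.rfft, xT)
    fy = tensorflow.vectorized_map(tensorflow.signal.rfft, yT)
    fz = fx*fy
    zT = tensorflow.vectorized_map(tensorflow.signal.irfft, fz)
    return tensorflow.vectorized_map(tensorflow.transpose, zT)

# Illustrative dense inputs: batch of 2, sequence length 8, width 4
signal = tensorflow.random.normal((2, 8, 4))
filters = tensorflow.random.normal((2, 8, 4))
out = convolve(signal, filters)
print(out.shape)    # (2, 8, 4): circular convolution along the sequence axis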
scripts.py CHANGED
@@ -7,7 +7,9 @@ import qarac.corpora.BNCorpus
 import qarac.corpora.Batcher
 import qarac.models.qarac_base_model
 import keras
+import tensorflow
 
+#tensorflow.debugging.disable_traceback_filtering()
 
 
 
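The commented-out line added to scripts.py, if uncommented, turns off TensorFlow's traceback filtering so exceptions show the framework's internal stack frames as well, which can help when chasing errors like the build failure flagged at the top of this page. A minimal sketch of the toggle:

import tensorflow

# Show full tracebacks (including TensorFlow-internal frames) while debugging,
# then restore the default filtering afterwards
tensorflow.debugging.disable_traceback_filtering()
# ... run the failing code ...
tensorflow.debugging.enable_traceback_filtering()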