Spaces:
Runtime error
Runtime error
Commit
·
0dfe33d
1
Parent(s):
7b9f0e4
Add source code
Browse files- .gitattributes +5 -0
- app.py +35 -0
- egs/video1.mp4 +3 -0
- egs/video2.mp4 +3 -0
- egs/video3.mp4 +3 -0
- egs/video4.mp4 +3 -0
- egs/video5.mp4 +0 -0
- src/__init__.py +0 -0
- src/audiovisual_stream.py +39 -0
- src/auditory_stream.py +147 -0
- src/core.py +43 -0
- src/model +3 -0
- src/visual_stream.py +145 -0
.gitattributes
CHANGED
@@ -32,3 +32,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
egs/video1.mp4 filter=lfs diff=lfs merge=lfs -text
|
36 |
+
egs/video2.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
egs/video3.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
egs/video4.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
src/model filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
|
5 |
+
from src.core import load_model, predict_traits
|
6 |
+
|
7 |
+
TRAIT_NAMES = [
|
8 |
+
"Extraversion",
|
9 |
+
"Agreeableness",
|
10 |
+
"Conscientiousness",
|
11 |
+
"Neurotisicm",
|
12 |
+
"Openness",
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
def get_traits(video):
|
17 |
+
model = load_model()
|
18 |
+
# if webcam_video:
|
19 |
+
# trait_values = predict_traits(webcam_video, model)
|
20 |
+
# else:
|
21 |
+
trait_values = predict_traits(video, model)
|
22 |
+
return {k: float(v) for k, v in zip(TRAIT_NAMES, trait_values)}
|
23 |
+
|
24 |
+
|
25 |
+
demo = gr.Interface(
|
26 |
+
get_traits,
|
27 |
+
inputs=gr.Video(label="Video", include_audio=True),
|
28 |
+
outputs=gr.Label(num_top_classes=5, label="Results"),
|
29 |
+
title="Personality Traits Prediction [Prototype]",
|
30 |
+
description="Predicts the 5 psychological traits using an introduction video",
|
31 |
+
thumbnail="https://cdn-icons-png.flaticon.com/512/3392/3392044.png",
|
32 |
+
examples="egs",
|
33 |
+
cache_examples=True,
|
34 |
+
)
|
35 |
+
demo.launch()
|
egs/video1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:956627f2dcb5559054718462157e436c5be0bc70ff32ae4194f334e32feafcee
|
3 |
+
size 2969523
|
egs/video2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8e67ba771391717777022cf604dc1bef79e20ded23a7055f54f33325399f5296
|
3 |
+
size 2311239
|
egs/video3.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7408594f52bb5edf658ed32c559c6e05526e21ec3d6ff7e0b69be03f8221784a
|
3 |
+
size 2579751
|
egs/video4.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:410f03f6e0ecb6b170ed65b47e520f7e071c19bb7e0f11f1c57808d095a2298f
|
3 |
+
size 2790813
|
egs/video5.mp4
ADDED
Binary file (62.9 kB). View file
|
|
src/__init__.py
ADDED
File without changes
|
src/audiovisual_stream.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import auditory_stream, visual_stream
|
2 |
+
import chainer
|
3 |
+
|
4 |
+
|
5 |
+
class ResNet18(chainer.Chain):
|
6 |
+
def __init__(self):
|
7 |
+
super(ResNet18, self).__init__(
|
8 |
+
aud=auditory_stream.ResNet18(),
|
9 |
+
vis=visual_stream.ResNet18(),
|
10 |
+
fc=chainer.links.Linear(512, 5, initialW=chainer.initializers.HeNormal()),
|
11 |
+
)
|
12 |
+
|
13 |
+
def __call__(self, x):
|
14 |
+
h = [
|
15 |
+
self.aud(chainer.Variable(chainer.cuda.to_cpu(x[0]))),
|
16 |
+
chainer.functions.expand_dims(
|
17 |
+
chainer.functions.sum(
|
18 |
+
self.vis(chainer.Variable(chainer.cuda.to_cpu(x[1][:256]))), 0
|
19 |
+
),
|
20 |
+
0,
|
21 |
+
),
|
22 |
+
]
|
23 |
+
|
24 |
+
for i in range(256, x[1].shape[0], 256):
|
25 |
+
h[1] += chainer.functions.expand_dims(
|
26 |
+
chainer.functions.sum(
|
27 |
+
self.vis(chainer.Variable(chainer.cuda.to_cpu(x[1][i : i + 256]))),
|
28 |
+
0,
|
29 |
+
),
|
30 |
+
0,
|
31 |
+
)
|
32 |
+
|
33 |
+
h[1] /= x[1].shape[0]
|
34 |
+
|
35 |
+
return chainer.cuda.to_cpu(
|
36 |
+
(
|
37 |
+
(chainer.functions.tanh(self.fc(chainer.functions.concat(h))) + 1) / 2
|
38 |
+
).data[0]
|
39 |
+
)
|
src/auditory_stream.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chainer
|
2 |
+
|
3 |
+
|
4 |
+
class ConvolutionBlock(chainer.Chain):
|
5 |
+
def __init__(self, in_channels, out_channels):
|
6 |
+
super(ConvolutionBlock, self).__init__(
|
7 |
+
conv=chainer.links.Convolution2D(
|
8 |
+
in_channels,
|
9 |
+
out_channels,
|
10 |
+
(1, 49),
|
11 |
+
(1, 4),
|
12 |
+
(0, 24),
|
13 |
+
initialW=chainer.initializers.HeNormal(),
|
14 |
+
),
|
15 |
+
bn_conv=chainer.links.BatchNormalization(out_channels),
|
16 |
+
)
|
17 |
+
|
18 |
+
def __call__(self, x):
|
19 |
+
# Set Train to False.
|
20 |
+
chainer.config.train = False
|
21 |
+
|
22 |
+
h = self.conv(x)
|
23 |
+
h = self.bn_conv(h)
|
24 |
+
y = chainer.functions.relu(h)
|
25 |
+
|
26 |
+
return y
|
27 |
+
|
28 |
+
|
29 |
+
class ResidualBlock(chainer.Chain):
|
30 |
+
def __init__(self, in_channels, out_channels):
|
31 |
+
super(ResidualBlock, self).__init__(
|
32 |
+
res_branch2a=chainer.links.Convolution2D(
|
33 |
+
in_channels,
|
34 |
+
out_channels,
|
35 |
+
(1, 9),
|
36 |
+
pad=(0, 4),
|
37 |
+
initialW=chainer.initializers.HeNormal(),
|
38 |
+
),
|
39 |
+
bn_branch2a=chainer.links.BatchNormalization(out_channels),
|
40 |
+
res_branch2b=chainer.links.Convolution2D(
|
41 |
+
out_channels,
|
42 |
+
out_channels,
|
43 |
+
(1, 9),
|
44 |
+
pad=(0, 4),
|
45 |
+
initialW=chainer.initializers.HeNormal(),
|
46 |
+
),
|
47 |
+
bn_branch2b=chainer.links.BatchNormalization(out_channels),
|
48 |
+
)
|
49 |
+
|
50 |
+
def __call__(self, x):
|
51 |
+
chainer.config.train = False
|
52 |
+
|
53 |
+
h = self.res_branch2a(x)
|
54 |
+
h = self.bn_branch2a(h)
|
55 |
+
h = chainer.functions.relu(h)
|
56 |
+
h = self.res_branch2b(h)
|
57 |
+
h = self.bn_branch2b(h)
|
58 |
+
h = x + h
|
59 |
+
y = chainer.functions.relu(h)
|
60 |
+
|
61 |
+
return y
|
62 |
+
|
63 |
+
|
64 |
+
class ResidualBlockA:
|
65 |
+
def __init__(self):
|
66 |
+
pass
|
67 |
+
|
68 |
+
def __call__(self):
|
69 |
+
pass
|
70 |
+
|
71 |
+
|
72 |
+
class ResidualBlockB(chainer.Chain):
|
73 |
+
def __init__(self, in_channels, out_channels):
|
74 |
+
super(ResidualBlockB, self).__init__(
|
75 |
+
res_branch1=chainer.links.Convolution2D(
|
76 |
+
in_channels,
|
77 |
+
out_channels,
|
78 |
+
(1, 1),
|
79 |
+
(1, 4),
|
80 |
+
initialW=chainer.initializers.HeNormal(),
|
81 |
+
),
|
82 |
+
bn_branch1=chainer.links.BatchNormalization(out_channels),
|
83 |
+
res_branch2a=chainer.links.Convolution2D(
|
84 |
+
in_channels,
|
85 |
+
out_channels,
|
86 |
+
(1, 9),
|
87 |
+
(1, 4),
|
88 |
+
(0, 4),
|
89 |
+
initialW=chainer.initializers.HeNormal(),
|
90 |
+
),
|
91 |
+
bn_branch2a=chainer.links.BatchNormalization(out_channels),
|
92 |
+
res_branch2b=chainer.links.Convolution2D(
|
93 |
+
out_channels,
|
94 |
+
out_channels,
|
95 |
+
(1, 9),
|
96 |
+
pad=(0, 4),
|
97 |
+
initialW=chainer.initializers.HeNormal(),
|
98 |
+
),
|
99 |
+
bn_branch2b=chainer.links.BatchNormalization(out_channels),
|
100 |
+
)
|
101 |
+
|
102 |
+
def __call__(self, x):
|
103 |
+
chainer.config.train = False
|
104 |
+
|
105 |
+
temp = self.res_branch1(x)
|
106 |
+
temp = self.bn_branch1(temp)
|
107 |
+
h = self.res_branch2a(x)
|
108 |
+
h = self.bn_branch2a(h)
|
109 |
+
h = chainer.functions.relu(h)
|
110 |
+
h = self.res_branch2b(h)
|
111 |
+
h = self.bn_branch2b(h)
|
112 |
+
h = temp + h
|
113 |
+
y = chainer.functions.relu(h)
|
114 |
+
|
115 |
+
return y
|
116 |
+
|
117 |
+
|
118 |
+
class ResNet18(chainer.Chain):
|
119 |
+
def __init__(self):
|
120 |
+
super(ResNet18, self).__init__(
|
121 |
+
conv1_relu=ConvolutionBlock(1, 32),
|
122 |
+
res2a_relu=ResidualBlock(32, 32),
|
123 |
+
res2b_relu=ResidualBlock(32, 32),
|
124 |
+
res3a_relu=ResidualBlockB(32, 64),
|
125 |
+
res3b_relu=ResidualBlock(64, 64),
|
126 |
+
res4a_relu=ResidualBlockB(64, 128),
|
127 |
+
res4b_relu=ResidualBlock(128, 128),
|
128 |
+
res5a_relu=ResidualBlockB(128, 256),
|
129 |
+
res5b_relu=ResidualBlock(256, 256),
|
130 |
+
)
|
131 |
+
|
132 |
+
def __call__(self, x):
|
133 |
+
chainer.config.train = False
|
134 |
+
|
135 |
+
h = self.conv1_relu(x)
|
136 |
+
h = chainer.functions.max_pooling_2d(h, (1, 9), (1, 4), (0, 4))
|
137 |
+
h = self.res2a_relu(h)
|
138 |
+
h = self.res2b_relu(h)
|
139 |
+
h = self.res3a_relu(h)
|
140 |
+
h = self.res3b_relu(h)
|
141 |
+
h = self.res4a_relu(h)
|
142 |
+
h = self.res4b_relu(h)
|
143 |
+
h = self.res5a_relu(h)
|
144 |
+
h = self.res5b_relu(h)
|
145 |
+
y = chainer.functions.average_pooling_2d(h, h.data.shape[2:])
|
146 |
+
|
147 |
+
return y
|
src/core.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import audiovisual_stream
|
2 |
+
import chainer.serializers
|
3 |
+
import librosa
|
4 |
+
import numpy
|
5 |
+
import skvideo.io
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
FRAMES_LIMIT = 25
|
9 |
+
|
10 |
+
|
11 |
+
def load_audio(data):
|
12 |
+
return librosa.load(data, 16000)[0][None, None, None, :]
|
13 |
+
|
14 |
+
|
15 |
+
def load_model():
|
16 |
+
model = audiovisual_stream.ResNet18().to_cpu()
|
17 |
+
chainer.serializers.load_npz("src/model", model)
|
18 |
+
return model
|
19 |
+
|
20 |
+
|
21 |
+
def predict_traits(data, model):
|
22 |
+
video_features = skvideo.io.vreader(data, num_frames=27)
|
23 |
+
# video_features = skvideo.io.vreader(data)
|
24 |
+
|
25 |
+
audio_features = load_audio(data)
|
26 |
+
|
27 |
+
x = []
|
28 |
+
predictions = []
|
29 |
+
|
30 |
+
frame_count = 0
|
31 |
+
for frame in video_features:
|
32 |
+
x.append(numpy.rollaxis(frame, 2))
|
33 |
+
|
34 |
+
frame_count += 1
|
35 |
+
|
36 |
+
if frame_count == FRAMES_LIMIT:
|
37 |
+
x = [audio_features, numpy.array(x, "float32")]
|
38 |
+
predictions.append(model(x))
|
39 |
+
|
40 |
+
frame_count = 0
|
41 |
+
x = []
|
42 |
+
|
43 |
+
return np.mean(np.asarray(predictions), axis=0)
|
src/model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a4794bfeae59d64c7ae9f78404c11d4517cf6da78b313301297dca4bc148deac
|
3 |
+
size 19218612
|
src/visual_stream.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chainer
|
2 |
+
|
3 |
+
class ConvolutionBlock(chainer.Chain):
|
4 |
+
def __init__(self, in_channels, out_channels):
|
5 |
+
super(ConvolutionBlock, self).__init__(
|
6 |
+
conv=chainer.links.Convolution2D(
|
7 |
+
in_channels,
|
8 |
+
out_channels,
|
9 |
+
7,
|
10 |
+
2,
|
11 |
+
3,
|
12 |
+
initialW=chainer.initializers.HeNormal(),
|
13 |
+
),
|
14 |
+
bn_conv=chainer.links.BatchNormalization(out_channels),
|
15 |
+
)
|
16 |
+
|
17 |
+
def __call__(self, x):
|
18 |
+
chainer.config.train = False
|
19 |
+
|
20 |
+
h = self.conv(x)
|
21 |
+
h = self.bn_conv(h)
|
22 |
+
y = chainer.functions.relu(h)
|
23 |
+
|
24 |
+
return y
|
25 |
+
|
26 |
+
|
27 |
+
class ResidualBlock(chainer.Chain):
|
28 |
+
def __init__(self, in_channels, out_channels):
|
29 |
+
super(ResidualBlock, self).__init__(
|
30 |
+
res_branch2a=chainer.links.Convolution2D(
|
31 |
+
in_channels,
|
32 |
+
out_channels,
|
33 |
+
3,
|
34 |
+
pad=1,
|
35 |
+
initialW=chainer.initializers.HeNormal(),
|
36 |
+
),
|
37 |
+
bn_branch2a=chainer.links.BatchNormalization(out_channels),
|
38 |
+
res_branch2b=chainer.links.Convolution2D(
|
39 |
+
out_channels,
|
40 |
+
out_channels,
|
41 |
+
3,
|
42 |
+
pad=1,
|
43 |
+
initialW=chainer.initializers.HeNormal(),
|
44 |
+
),
|
45 |
+
bn_branch2b=chainer.links.BatchNormalization(out_channels),
|
46 |
+
)
|
47 |
+
|
48 |
+
def __call__(self, x):
|
49 |
+
chainer.config.train = False
|
50 |
+
|
51 |
+
h = self.res_branch2a(x)
|
52 |
+
h = self.bn_branch2a(h)
|
53 |
+
h = chainer.functions.relu(h)
|
54 |
+
h = self.res_branch2b(h)
|
55 |
+
h = self.bn_branch2b(h)
|
56 |
+
h = x + h
|
57 |
+
y = chainer.functions.relu(h)
|
58 |
+
|
59 |
+
return y
|
60 |
+
|
61 |
+
|
62 |
+
class ResidualBlockA:
|
63 |
+
def __init__(self):
|
64 |
+
pass
|
65 |
+
|
66 |
+
def __call__(self):
|
67 |
+
pass
|
68 |
+
|
69 |
+
|
70 |
+
class ResidualBlockB(chainer.Chain):
|
71 |
+
def __init__(self, in_channels, out_channels):
|
72 |
+
super(ResidualBlockB, self).__init__(
|
73 |
+
res_branch1=chainer.links.Convolution2D(
|
74 |
+
in_channels,
|
75 |
+
out_channels,
|
76 |
+
1,
|
77 |
+
2,
|
78 |
+
initialW=chainer.initializers.HeNormal(),
|
79 |
+
),
|
80 |
+
bn_branch1=chainer.links.BatchNormalization(out_channels),
|
81 |
+
res_branch2a=chainer.links.Convolution2D(
|
82 |
+
in_channels,
|
83 |
+
out_channels,
|
84 |
+
3,
|
85 |
+
2,
|
86 |
+
1,
|
87 |
+
initialW=chainer.initializers.HeNormal(),
|
88 |
+
),
|
89 |
+
bn_branch2a=chainer.links.BatchNormalization(out_channels),
|
90 |
+
res_branch2b=chainer.links.Convolution2D(
|
91 |
+
out_channels,
|
92 |
+
out_channels,
|
93 |
+
3,
|
94 |
+
pad=1,
|
95 |
+
initialW=chainer.initializers.HeNormal(),
|
96 |
+
),
|
97 |
+
bn_branch2b=chainer.links.BatchNormalization(out_channels),
|
98 |
+
)
|
99 |
+
|
100 |
+
def __call__(self, x):
|
101 |
+
chainer.config.train = False
|
102 |
+
|
103 |
+
temp = self.res_branch1(x)
|
104 |
+
temp = self.bn_branch1(temp)
|
105 |
+
h = self.res_branch2a(x)
|
106 |
+
h = self.bn_branch2a(h)
|
107 |
+
h = chainer.functions.relu(h)
|
108 |
+
h = self.res_branch2b(h)
|
109 |
+
h = self.bn_branch2b(h)
|
110 |
+
h = temp + h
|
111 |
+
y = chainer.functions.relu(h)
|
112 |
+
|
113 |
+
return y
|
114 |
+
|
115 |
+
|
116 |
+
class ResNet18(chainer.Chain):
|
117 |
+
def __init__(self):
|
118 |
+
super(ResNet18, self).__init__(
|
119 |
+
conv1_relu=ConvolutionBlock(3, 32),
|
120 |
+
res2a_relu=ResidualBlock(32, 32),
|
121 |
+
res2b_relu=ResidualBlock(32, 32),
|
122 |
+
res3a_relu=ResidualBlockB(32, 64),
|
123 |
+
res3b_relu=ResidualBlock(64, 64),
|
124 |
+
res4a_relu=ResidualBlockB(64, 128),
|
125 |
+
res4b_relu=ResidualBlock(128, 128),
|
126 |
+
res5a_relu=ResidualBlockB(128, 256),
|
127 |
+
res5b_relu=ResidualBlock(256, 256),
|
128 |
+
)
|
129 |
+
|
130 |
+
def __call__(self, x):
|
131 |
+
chainer.config.train = False
|
132 |
+
|
133 |
+
h = self.conv1_relu(x)
|
134 |
+
h = chainer.functions.max_pooling_2d(h, 3, 2, 1)
|
135 |
+
h = self.res2a_relu(h)
|
136 |
+
h = self.res2b_relu(h)
|
137 |
+
h = self.res3a_relu(h)
|
138 |
+
h = self.res3b_relu(h)
|
139 |
+
h = self.res4a_relu(h)
|
140 |
+
h = self.res4b_relu(h)
|
141 |
+
h = self.res5a_relu(h)
|
142 |
+
h = self.res5b_relu(h)
|
143 |
+
y = chainer.functions.average_pooling_2d(h, h.data.shape[2:])
|
144 |
+
|
145 |
+
return y
|