Skip to content

Commit 357e63d

Browse files
jvdp1 and milancurcic authored
Support of the argument stride for locally_connected2d_layer (#240)
* Support of the argument stride for locally_connected2d_layer
* fix layer constructor
* Addition of a test for stride in locally_connected2d_layer
* fix API constructor layer
* Stop at 20 epochs
---------
Co-authored-by: milancurcic <caomaco@gmail.com>
1 parent a1d2d24 commit 357e63d

File tree

6 files changed

+68
-16
lines changed

6 files changed

+68
-16
lines changed

example/cnn_mnist_1d.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ program cnn_mnist_1d
1212
real, allocatable :: validation_images(:,:), validation_labels(:)
1313
real, allocatable :: testing_images(:,:), testing_labels(:)
1414
integer :: n
15-
integer, parameter :: num_epochs = 250
15+
integer, parameter :: num_epochs = 20
1616

1717
call load_mnist(training_images, training_labels, &
1818
validation_images, validation_labels, &

src/nf/nf_layer_constructors.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ end function conv2d
160160

161161
interface locally_connected
162162

163-
module function locally_connected2d(filters, kernel_size, activation) result(res)
163+
module function locally_connected2d(filters, kernel_size, activation, stride) result(res)
164164
!! 1-d locally connected network constructor
165165
!!
166166
!! This layer is for building 1-d locally connected network.
@@ -183,6 +183,8 @@ module function locally_connected2d(filters, kernel_size, activation) result(res
183183
!! Width of the convolution window, commonly 3 or 5
184184
class(activation_function), intent(in), optional :: activation
185185
!! Activation function (default sigmoid)
186+
integer, intent(in), optional :: stride
187+
!! Size of the stride (default 1)
186188
type(layer) :: res
187189
!! Resulting layer instance
188190
end function locally_connected2d

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,14 @@ module function conv2d(filters, kernel_width, kernel_height, activation, stride)
105105

106106
end function conv2d
107107

108-
module function locally_connected2d(filters, kernel_size, activation) result(res)
108+
module function locally_connected2d(filters, kernel_size, activation, stride) result(res)
109109
integer, intent(in) :: filters
110110
integer, intent(in) :: kernel_size
111111
class(activation_function), intent(in), optional :: activation
112+
integer, intent(in), optional :: stride
112113
type(layer) :: res
113114

115+
integer :: stride_tmp
114116
class(activation_function), allocatable :: activation_tmp
115117

116118
res % name = 'locally_connected2d'
@@ -123,9 +125,18 @@ module function locally_connected2d(filters, kernel_size, activation) result(res
123125

124126
res % activation = activation_tmp % get_name()
125127

128+
if (present(stride)) then
129+
stride_tmp = stride
130+
else
131+
stride_tmp = 1
132+
endif
133+
134+
if (stride_tmp < 1) &
135+
error stop 'stride must be >= 1 in a conv1d layer'
136+
126137
allocate( &
127138
res % p, &
128-
source=locally_connected2d_layer(filters, kernel_size, activation_tmp) &
139+
source=locally_connected2d_layer(filters, kernel_size, activation_tmp, stride_tmp) &
129140
)
130141

131142
end function locally_connected2d

src/nf/nf_locally_connected2d_layer.f90

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ module nf_locally_connected2d_layer
1515
integer :: channels
1616
integer :: kernel_size
1717
integer :: filters
18+
integer :: stride
1819

1920
real, allocatable :: biases(:,:) ! filters x width
2021
real, allocatable :: kernel(:,:,:,:) ! filters x width x channels x kernel_size
@@ -40,12 +41,13 @@ module nf_locally_connected2d_layer
4041
end type locally_connected2d_layer
4142

4243
interface locally_connected2d_layer
43-
module function locally_connected2d_layer_cons(filters, kernel_size, activation) &
44+
module function locally_connected2d_layer_cons(filters, kernel_size, activation, stride) &
4445
result(res)
4546
!! `locally_connected2d_layer` constructor function
4647
integer, intent(in) :: filters
4748
integer, intent(in) :: kernel_size
4849
class(activation_function), intent(in) :: activation
50+
integer, intent(in) :: stride
4951
type(locally_connected2d_layer) :: res
5052
end function locally_connected2d_layer_cons
5153
end interface locally_connected2d_layer
@@ -91,7 +93,9 @@ end function get_num_params
9193
module subroutine get_params_ptr(self, w_ptr, b_ptr)
9294
class(locally_connected2d_layer), intent(in), target :: self
9395
real, pointer, intent(out) :: w_ptr(:)
96+
!! Pointer to the kernel weights (flattened)
9497
real, pointer, intent(out) :: b_ptr(:)
98+
!! Pointer to the biases
9599
end subroutine get_params_ptr
96100

97101
module function get_gradients(self) result(gradients)
@@ -106,7 +110,9 @@ end function get_gradients
106110
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
107111
class(locally_connected2d_layer), intent(in), target :: self
108112
real, pointer, intent(out) :: dw_ptr(:)
113+
!! Pointer to the kernel weight gradients (flattened)
109114
real, pointer, intent(out) :: db_ptr(:)
115+
!! Pointer to the bias gradients
110116
end subroutine get_gradients_ptr
111117

112118
end interface

src/nf/nf_locally_connected2d_layer_submodule.f90

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
contains
99

10-
module function locally_connected2d_layer_cons(filters, kernel_size, activation) result(res)
10+
module function locally_connected2d_layer_cons(filters, kernel_size, activation, stride) result(res)
1111
integer, intent(in) :: filters
1212
integer, intent(in) :: kernel_size
1313
class(activation_function), intent(in) :: activation
14+
integer, intent(in) :: stride
1415
type(locally_connected2d_layer) :: res
1516

1617
res % kernel_size = kernel_size
1718
res % filters = filters
1819
res % activation_name = activation % get_name()
20+
res % stride = stride
1921
allocate(res % activation, source = activation)
2022
end function locally_connected2d_layer_cons
2123

@@ -24,8 +26,11 @@ module subroutine init(self, input_shape)
2426
integer, intent(in) :: input_shape(:)
2527

2628
self % channels = input_shape(1)
27-
self % width = input_shape(2) - self % kernel_size + 1
29+
self % width = (input_shape(2) - self % kernel_size) / self % stride +1
2830

31+
if (mod(input_shape(2) - self % kernel_size , self % stride) /= 0) self % width = self % width + 1
32+
33+
! Output of shape: filters x width
2934
allocate(self % output(self % filters, self % width))
3035
self % output = 0
3136

@@ -52,14 +57,17 @@ end subroutine init
5257
pure module subroutine forward(self, input)
5358
class(locally_connected2d_layer), intent(in out) :: self
5459
real, intent(in) :: input(:,:)
60+
integer :: input_width
5561
integer :: j, n
5662
integer :: iws, iwe
63+
64+
input_width = size(input, dim=2)
5765

5866
do j = 1, self % width
59-
iws = j
60-
iwe = j + self % kernel_size - 1
67+
iws = self % stride * (j-1) + 1
68+
iwe = min(iws + self % kernel_size - 1, input_width)
6169
do n = 1, self % filters
62-
self % z(n, j) = sum(self % kernel(n, j, :, :) * input(:, iws:iwe)) + self % biases(n, j)
70+
self % z(n, j) = sum(self % kernel(n, j, :, 1:iwe-iws+1) * input(:, iws:iwe)) + self % biases(n, j)
6371
end do
6472
end do
6573
self % output = self % activation % eval(self % z)
@@ -69,12 +77,15 @@ pure module subroutine backward(self, input, gradient)
6977
class(locally_connected2d_layer), intent(in out) :: self
7078
real, intent(in) :: input(:,:)
7179
real, intent(in) :: gradient(:,:)
80+
integer :: input_width
7281
integer :: j, n, k
7382
integer :: iws, iwe
7483
real :: gdz(self % filters, self % width)
7584
real :: db_local(self % filters, self % width)
7685
real :: dw_local(self % filters, self % width, self % channels, self % kernel_size)
7786

87+
input_width = size(input, dim=2)
88+
7889
do j = 1, self % width
7990
gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j))
8091
end do
@@ -90,11 +101,11 @@ pure module subroutine backward(self, input, gradient)
90101

91102
do n = 1, self % filters
92103
do j = 1, self % width
93-
iws = j
94-
iwe = j + self % kernel_size - 1
104+
iws = self % stride * (j-1) + 1
105+
iwe = min(iws + self % kernel_size - 1, input_width)
95106
do k = 1, self % channels
96-
dw_local(n, j, k, :) = dw_local(n, j, k, :) + input(k, iws:iwe) * gdz(n, j)
97-
self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, j, k, :) * gdz(n, j)
107+
dw_local(n, j, k, 1:iwe-iws+1) = dw_local(n, j, k, 1:iwe-iws+1) + input(k, iws:iwe) * gdz(n, j)
108+
self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, j, k, 1:iwe-iws+1) * gdz(n, j)
98109
end do
99110
end do
100111
end do
@@ -131,5 +142,4 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
131142
db_ptr(1:size(self % db)) => self % db
132143
end subroutine get_gradients_ptr
133144

134-
135145
end submodule nf_locally_connected2d_layer_submodule

test/test_locally_connected2d_layer.f90

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ program test_locally_connected2d_layer
5858
select type(this_layer => input_layer % p); type is(input2d_layer)
5959
call this_layer % set(sample_input)
6060
end select
61+
deallocate(sample_input)
6162

6263
call locally_connected_1d_layer % forward(input_layer)
6364
call locally_connected_1d_layer % get_output(output)
@@ -67,11 +68,33 @@ program test_locally_connected2d_layer
6768
write(stderr, '(a)') 'locally_connected2d layer with zero input and sigmoid function must forward to all 0.5.. failed'
6869
end if
6970

71+
! Strided locally_connected_1d layer: 1 channel, input of width 17, stride = 3
72+
allocate(sample_input(1, 17))
73+
sample_input = 0
74+
75+
input_layer = input(1, 17)
76+
locally_connected_1d_layer = locally_connected(filters, kernel_size, stride = 3)
77+
call locally_connected_1d_layer % init(input_layer)
78+
79+
select type(this_layer => input_layer % p); type is(input2d_layer)
80+
call this_layer % set(sample_input)
81+
end select
82+
deallocate(sample_input)
83+
84+
call locally_connected_1d_layer % forward(input_layer)
85+
call locally_connected_1d_layer % get_output(output)
86+
87+
if (.not. all(abs(output) < tolerance)) then
88+
ok = .false.
89+
write(stderr, '(a)') 'locally_connected2d layer with zero input and sigmoid function must forward to all 0.5.. failed'
90+
end if
91+
92+
!Final
7093
if (ok) then
7194
print '(a)', 'test_locally_connected2d_layer: All tests passed.'
7295
else
7396
write(stderr, '(a)') 'test_locally_connected2d_layer: One or more tests failed.'
7497
stop 1
7598
end if
76-
99+
77100
end program test_locally_connected2d_layer

0 commit comments

Comments
 (0)