Skip to content

Commit edd3f70

Browse files
authored
Introduce optimizer_base_type in support of different optimizers (#116)
1 parent dce3b0a commit edd3f70

File tree

3 files changed

+35
-26
lines changed

3 files changed

+35
-26
lines changed

src/nf/nf_network.f90

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ module nf_network
33
!! This module provides the network type to create new models.
44

55
use nf_layer, only: layer
6-
use nf_optimizers, only: sgd
6+
use nf_optimizers, only: optimizer_base_type
77

88
implicit none
99

@@ -193,7 +193,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
193193
!! Set to `size(input_data, dim=2)` for a batch gradient descent.
194194
integer, intent(in) :: epochs
195195
!! Number of epochs to run
196-
type(sgd), intent(in) :: optimizer
196+
class(optimizer_base_type), intent(in) :: optimizer
197197
!! Optimizer instance; any type that extends `optimizer_base_type`
198198
!! may be passed, but `train` currently supports only `sgd`.
199199
end subroutine train

src/nf/nf_network_submodule.f90

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
use nf_layer, only: layer
1313
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
1414
use nf_loss, only: quadratic_derivative
15-
use nf_optimizers, only: sgd
15+
use nf_optimizers, only: optimizer_base_type, sgd
1616
use nf_parallel, only: tile_indices
1717

1818
implicit none
@@ -426,7 +426,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
426426
real, intent(in) :: output_data(:,:)
427427
integer, intent(in) :: batch_size
428428
integer, intent(in) :: epochs
429-
type(sgd), intent(in) :: optimizer
429+
class(optimizer_base_type), intent(in) :: optimizer
430430

431431
real :: pos
432432
integer :: dataset_size
@@ -439,26 +439,31 @@ module subroutine train(self, input_data, output_data, batch_size, &
439439
epoch_loop: do n = 1, epochs
440440
batch_loop: do i = 1, dataset_size / batch_size
441441

442-
! Pull a random mini-batch from the dataset
443-
call random_number(pos)
444-
batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
445-
batch_end = batch_start + batch_size - 1
446-
447-
! FIXME shuffle in a way that doesn't require co_broadcast
448-
call co_broadcast(batch_start, 1)
449-
call co_broadcast(batch_end, 1)
450-
451-
! Distribute the batch in nearly equal pieces to all images
452-
indices = tile_indices(batch_size)
453-
istart = indices(1) + batch_start - 1
454-
iend = indices(2) + batch_start - 1
455-
456-
do concurrent(j = istart:iend)
457-
call self % forward(input_data(:,j))
458-
call self % backward(output_data(:,j))
459-
end do
460-
461-
call self % update(optimizer % learning_rate / batch_size)
442+
! Pull a random mini-batch from the dataset
443+
call random_number(pos)
444+
batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
445+
batch_end = batch_start + batch_size - 1
446+
447+
! FIXME shuffle in a way that doesn't require co_broadcast
448+
call co_broadcast(batch_start, 1)
449+
call co_broadcast(batch_end, 1)
450+
451+
! Distribute the batch in nearly equal pieces to all images
452+
indices = tile_indices(batch_size)
453+
istart = indices(1) + batch_start - 1
454+
iend = indices(2) + batch_start - 1
455+
456+
do concurrent(j = istart:iend)
457+
call self % forward(input_data(:,j))
458+
call self % backward(output_data(:,j))
459+
end do
460+
461+
select type (optimizer)
462+
type is (sgd)
463+
call self % update(optimizer % learning_rate / batch_size)
464+
class default
465+
error stop 'Unsupported optimizer'
466+
end select
462467

463468
end do batch_loop
464469
end do epoch_loop

src/nf/nf_optimizers.f90

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,13 @@ module nf_optimizers
55
implicit none
66

77
private
8-
public :: sgd
8+
public :: optimizer_base_type, sgd
99

10-
type :: sgd
10+
type, abstract :: optimizer_base_type
11+
character(:), allocatable :: name
12+
end type optimizer_base_type
13+
14+
type, extends(optimizer_base_type) :: sgd
1115
!! Stochastic Gradient Descent optimizer
1216
real :: learning_rate
1317
real :: momentum = 0 !TODO

0 commit comments

Comments
 (0)