@@ -12,7 +12,7 @@
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic_derivative
-  use nf_optimizers, only: sgd
+  use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
 
   implicit none
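The import now pulls in `optimizer_base_type` alongside `sgd`, so that `train` can accept any optimizer that extends the base type. As a rough sketch of what `nf_optimizers` might declare (inferred from this diff, not copied from the repository; placing `learning_rate` on the base type is an assumption, since the diff only shows it accessed after narrowing to `sgd`):

    module nf_optimizers
      ! Sketch of the declarations implied by this diff; the real
      ! module likely defines more components and optimizers.
      implicit none

      private
      public :: optimizer_base_type, sgd

      ! Abstract parent type that train() now accepts.
      type, abstract :: optimizer_base_type
        real :: learning_rate = 0.01
      end type optimizer_base_type

      ! Stochastic gradient descent, the one concrete optimizer
      ! handled by the select type block further down.
      type, extends(optimizer_base_type) :: sgd
      end type sgd

    end module nf_optimizers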
@@ -426,7 +426,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
     real, intent(in) :: output_data(:,:)
     integer, intent(in) :: batch_size
     integer, intent(in) :: epochs
-    type(sgd), intent(in) :: optimizer
+    class(optimizer_base_type), intent(in) :: optimizer
 
     real :: pos
     integer :: dataset_size
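The dummy argument changes from the concrete `type(sgd)` to the polymorphic `class(optimizer_base_type)`, so a caller may pass any extension of the base type. A hypothetical call site (the `nf` module, the `network`/`input`/`dense` constructors, and the data shapes are assumptions; only the `train` signature comes from this diff):

    program sgd_train_example
      use nf, only: dense, input, network
      use nf_optimizers, only: sgd
      implicit none

      type(network) :: net
      real :: x(3, 100), y(2, 100)

      call random_number(x)
      call random_number(y)

      ! Toy 3-input, 2-output network.
      net = network([input(3), dense(2)])

      ! Any extension of optimizer_base_type may be passed; sgd is
      ! built here with Fortran's default structure constructor.
      call net % train(x, y, batch_size=10, epochs=5, &
                       optimizer=sgd(learning_rate=0.1))

    end program sgd_train_example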
@@ -439,26 +439,31 @@ module subroutine train(self, input_data, output_data, batch_size, &
     epoch_loop: do n = 1, epochs
       batch_loop: do i = 1, dataset_size / batch_size
 
-        ! Pull a random mini-batch from the dataset
-        call random_number(pos)
-        batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
-        batch_end = batch_start + batch_size - 1
-
-        ! FIXME shuffle in a way that doesn't require co_broadcast
-        call co_broadcast(batch_start, 1)
-        call co_broadcast(batch_end, 1)
-
-        ! Distribute the batch in nearly equal pieces to all images
-        indices = tile_indices(batch_size)
-        istart = indices(1) + batch_start - 1
-        iend = indices(2) + batch_start - 1
-
-        do concurrent(j = istart:iend)
-          call self % forward(input_data(:,j))
-          call self % backward(output_data(:,j))
-        end do
-
-        call self % update(optimizer % learning_rate / batch_size)
+        ! Pull a random mini-batch from the dataset
+        call random_number(pos)
+        batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
+        batch_end = batch_start + batch_size - 1
+
+        ! FIXME shuffle in a way that doesn't require co_broadcast
+        call co_broadcast(batch_start, 1)
+        call co_broadcast(batch_end, 1)
+
+        ! Distribute the batch in nearly equal pieces to all images
+        indices = tile_indices(batch_size)
+        istart = indices(1) + batch_start - 1
+        iend = indices(2) + batch_start - 1
+
+        do concurrent(j = istart:iend)
+          call self % forward(input_data(:,j))
+          call self % backward(output_data(:,j))
+        end do
+
+        select type (optimizer)
+          type is (sgd)
+            call self % update(optimizer % learning_rate / batch_size)
+          class default
+            error stop 'Unsupported optimizer'
+        end select
 
       end do batch_loop
     end do epoch_loop
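The single `update` call is replaced by a `select type` construct, which dispatches on the dynamic type of the polymorphic `optimizer` argument; `class default` traps optimizers not yet wired into `train`. A minimal, self-contained illustration of the same dispatch pattern, using stand-in types rather than the real `nf_optimizers` ones:

    program select_type_demo
      ! Stand-in types, not the real nf_optimizers ones.
      implicit none

      type :: base
        real :: learning_rate = 0.01
      end type base

      type, extends(base) :: sgd_like
      end type sgd_like

      call update_step(sgd_like(learning_rate=0.1))

    contains

      subroutine update_step(optimizer)
        ! The dummy is polymorphic, as in train() above; select type
        ! recovers the dynamic type so concrete components are usable.
        class(base), intent(in) :: optimizer
        select type (optimizer)
          type is (sgd_like)
            print *, 'SGD step with learning rate', optimizer % learning_rate
          class default
            error stop 'Unsupported optimizer'
        end select
      end subroutine update_step

    end program select_type_demo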