@@ -20,22 +20,18 @@ module nf_cross_attention_layer
   end type cross_attention_layer
 
   interface cross_attention_layer
-    module function cross_attention_layer_cons(n_heads) result(res)
-      !! This function returns the `cross_attention_layer` instance.
-      integer, intent(in) :: sequence_length, model_dimension, n_heads
-      type(cross_attention_layer) :: res
-    end function cross_attention_layer_cons
+    module procedure cross_attention_layer_cons
   end interface cross_attention_layer
 
 contains
-  module function cross_attention_layer_cons(n_heads) result(res)
+  function cross_attention_layer_cons(n_heads) result(res)
     !! This function returns the `cross_attention_layer` instance.
     integer, intent(in) :: n_heads
     type(cross_attention_layer) :: res
     res % n_heads = n_heads
   end function cross_attention_layer_cons
 
-  pure module subroutine backward(self, input, gradient)
+  pure subroutine backward(self, input, gradient)
     !! Cross Attention Back propagation
     class(cross_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
@@ -46,7 +42,7 @@ pure module subroutine backward(self, input, gradient)
     self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
   end subroutine backward
 
-  pure module subroutine forward(self, input)
+  pure subroutine forward(self, input)
     !! Cross Attention Forward propagation
     !! Input Shape (kind, sequence_length, model_dimension)
     !! where kind is 1 for Query and 2 for Key-Value
@@ -56,7 +52,7 @@ pure module subroutine forward(self, input)
     call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
   end subroutine forward
 
-  module subroutine init(self, input_shape)
+  subroutine init(self, input_shape)
     class(cross_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)
 
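
For reviewers, a minimal usage sketch of the layer after this refactor. Only the constructor, init, and forward signatures come from the diff; the demo program name, the concrete shape values, and the assumption that init takes input_shape = [sequence_length, model_dimension] are illustrative.

    ! Hypothetical demo; shape values and the init argument convention are assumptions.
    program cross_attention_demo
      use nf_cross_attention_layer, only: cross_attention_layer
      implicit none

      type(cross_attention_layer) :: attn
      ! Per the forward docstring: (kind, sequence_length, model_dimension),
      ! where kind 1 carries the Query and kind 2 the shared Key-Value input.
      real :: input(2, 3, 4)

      attn = cross_attention_layer(n_heads=2)  ! resolves via the module procedure interface
      call attn % init([3, 4])                 ! assumed [sequence_length, model_dimension]
      call random_number(input)
      call attn % forward(input)               ! Q = input(1,:,:), K = V = input(2,:,:)
    end program cross_attention_demo

Design note: the change replaces the submodule-style interface body (module function ... end function inside the interface block) with a module procedure statement and plain procedures under contains, eliminating the duplicated interface that had drifted out of sync (its dummy-argument list declared sequence_length and model_dimension that the function never received).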