@@ -930,164 +930,143 @@ end
930930
931931
932932# multiply 2x2 matrices
933- function matmul2x2 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
933+ Base . @constprop :aggressive function matmul2x2 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
934934 matmul2x2! (similar (B, promote_op (matprod, T, S), 2 , 2 ), tA, tB, A, B)
935935end
936936
937- function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
938- _add:: MulAddMul = MulAddMul ())
937+ function __matmul_checks (C, A, B, sz)
939938 require_one_based_indexing (C, A, B)
940939 if C === A || B === C
941940 throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
942941 end
943- if ! (size (A) == size (B) == size (C) == ( 2 , 2 ) )
942+ if ! (size (A) == size (B) == size (C) == sz )
944943 throw (DimensionMismatch (lazy " A has size $(size(A)), B has size $(size(B)), C has size $(size(C))" ))
945944 end
945+ return nothing
946+ end
947+
948+ # separate function with the core of matmul2x2! that doesn't depend on a MulAddMul
949+ Base. @constprop :aggressive function _matmul2x2_elements (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix )
950+ __matmul_checks (C, A, B, (2 ,2 ))
951+ __matmul2x2_elements (tA, tB, A, B)
952+ end
953+ Base. @constprop :aggressive function __matmul2x2_elements (tA, A:: AbstractMatrix )
946954 @inbounds begin
947- if tA == ' N'
955+ tA_uc = uppercase (tA) # possibly unwrap a WrapperChar
956+ if tA_uc == ' N'
948957 A11 = A[1 ,1 ]; A12 = A[1 ,2 ]; A21 = A[2 ,1 ]; A22 = A[2 ,2 ]
949- elseif tA == ' T'
958+ elseif tA_uc == ' T'
950959 # TODO making these lazy could improve perf
951960 A11 = copy (transpose (A[1 ,1 ])); A12 = copy (transpose (A[2 ,1 ]))
952961 A21 = copy (transpose (A[1 ,2 ])); A22 = copy (transpose (A[2 ,2 ]))
953- elseif tA == ' C'
962+ elseif tA_uc == ' C'
954963 # TODO making these lazy could improve perf
955964 A11 = copy (A[1 ,1 ]' ); A12 = copy (A[2 ,1 ]' )
956965 A21 = copy (A[1 ,2 ]' ); A22 = copy (A[2 ,2 ]' )
957- elseif tA == ' S'
958- A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
959- A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U )
960- elseif tA == ' s'
961- A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ]))
962- A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L )
963- elseif tA == ' H'
964- A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
965- A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U )
966- else # if tA == 'h'
967- A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ]))
968- A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L )
969- end
970- if tB == ' N'
971- B11 = B[1 ,1 ]; B12 = B[1 ,2 ];
972- B21 = B[2 ,1 ]; B22 = B[2 ,2 ]
973- elseif tB == ' T'
974- # TODO making these lazy could improve perf
975- B11 = copy (transpose (B[1 ,1 ])); B12 = copy (transpose (B[2 ,1 ]))
976- B21 = copy (transpose (B[1 ,2 ])); B22 = copy (transpose (B[2 ,2 ]))
977- elseif tB == ' C'
978- # TODO making these lazy could improve perf
979- B11 = copy (B[1 ,1 ]' ); B12 = copy (B[2 ,1 ]' )
980- B21 = copy (B[1 ,2 ]' ); B22 = copy (B[2 ,2 ]' )
981- elseif tB == ' S'
982- B11 = symmetric (B[1 ,1 ], :U ); B12 = B[1 ,2 ]
983- B21 = copy (transpose (B[1 ,2 ])); B22 = symmetric (B[2 ,2 ], :U )
984- elseif tB == ' s'
985- B11 = symmetric (B[1 ,1 ], :L ); B12 = copy (transpose (B[2 ,1 ]))
986- B21 = B[2 ,1 ]; B22 = symmetric (B[2 ,2 ], :L )
987- elseif tB == ' H'
988- B11 = hermitian (B[1 ,1 ], :U ); B12 = B[1 ,2 ]
989- B21 = copy (adjoint (B[1 ,2 ])); B22 = hermitian (B[2 ,2 ], :U )
990- else # if tB == 'h'
991- B11 = hermitian (B[1 ,1 ], :L ); B12 = copy (adjoint (B[2 ,1 ]))
992- B21 = B[2 ,1 ]; B22 = hermitian (B[2 ,2 ], :L )
966+ elseif tA_uc == ' S'
967+ if isuppercase (tA) # tA == 'S'
968+ A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
969+ A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U )
970+ else
971+ A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ]))
972+ A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L )
973+ end
974+ elseif tA_uc == ' H'
975+ if isuppercase (tA) # tA == 'H'
976+ A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
977+ A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U )
978+ else # if tA == 'h'
979+ A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ]))
980+ A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L )
981+ end
993982 end
983+ end # inbounds
984+ A11, A12, A21, A22
985+ end
986+ Base. @constprop :aggressive __matmul2x2_elements (tA, tB, A, B) = __matmul2x2_elements (tA, A), __matmul2x2_elements (tB, B)
987+
988+ Base. @constprop :aggressive function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
989+ _add:: MulAddMul = MulAddMul ())
990+ (A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements (C, tA, tB, A, B)
991+ @inbounds begin
994992 _modify! (_add, A11* B11 + A12* B21, C, (1 ,1 ))
995- _modify! (_add, A11* B12 + A12* B22, C, (1 ,2 ))
996993 _modify! (_add, A21* B11 + A22* B21, C, (2 ,1 ))
994+ _modify! (_add, A11* B12 + A12* B22, C, (1 ,2 ))
997995 _modify! (_add, A21* B12 + A22* B22, C, (2 ,2 ))
998996 end # inbounds
999997 C
1000998end
1001999
10021000# Multiply 3x3 matrices
1003- function matmul3x3 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
1001+ Base . @constprop :aggressive function matmul3x3 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
10041002 matmul3x3! (similar (B, promote_op (matprod, T, S), 3 , 3 ), tA, tB, A, B)
10051003end
10061004
1007- function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
1008- _add:: MulAddMul = MulAddMul ())
1009- require_one_based_indexing (C, A, B)
1010- if C === A || B === C
1011- throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
1012- end
1013- if ! (size (A) == size (B) == size (C) == (3 ,3 ))
1014- throw (DimensionMismatch (lazy " A has size $(size(A)), B has size $(size(B)), C has size $(size(C))" ))
1015- end
1005+ # separate function with the core of matmul3x3! that doesn't depend on a MulAddMul
1006+ Base. @constprop :aggressive function _matmul3x3_elements (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix )
1007+ __matmul_checks (C, A, B, (3 ,3 ))
1008+ __matmul3x3_elements (tA, tB, A, B)
1009+ end
1010+ Base. @constprop :aggressive function __matmul3x3_elements (tA, A:: AbstractMatrix )
10161011 @inbounds begin
1017- if tA == ' N'
1012+ tA_uc = uppercase (tA) # possibly unwrap a WrapperChar
1013+ if tA_uc == ' N'
10181014 A11 = A[1 ,1 ]; A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
10191015 A21 = A[2 ,1 ]; A22 = A[2 ,2 ]; A23 = A[2 ,3 ]
10201016 A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = A[3 ,3 ]
1021- elseif tA == ' T'
1017+ elseif tA_uc == ' T'
10221018 # TODO making these lazy could improve perf
10231019 A11 = copy (transpose (A[1 ,1 ])); A12 = copy (transpose (A[2 ,1 ])); A13 = copy (transpose (A[3 ,1 ]))
10241020 A21 = copy (transpose (A[1 ,2 ])); A22 = copy (transpose (A[2 ,2 ])); A23 = copy (transpose (A[3 ,2 ]))
10251021 A31 = copy (transpose (A[1 ,3 ])); A32 = copy (transpose (A[2 ,3 ])); A33 = copy (transpose (A[3 ,3 ]))
1026- elseif tA == ' C'
1022+ elseif tA_uc == ' C'
10271023 # TODO making these lazy could improve perf
10281024 A11 = copy (A[1 ,1 ]' ); A12 = copy (A[2 ,1 ]' ); A13 = copy (A[3 ,1 ]' )
10291025 A21 = copy (A[1 ,2 ]' ); A22 = copy (A[2 ,2 ]' ); A23 = copy (A[3 ,2 ]' )
10301026 A31 = copy (A[1 ,3 ]' ); A32 = copy (A[2 ,3 ]' ); A33 = copy (A[3 ,3 ]' )
1031- elseif tA == ' S'
1032- A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1033- A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1034- A31 = copy (transpose (A[1 ,3 ])); A32 = copy (transpose (A[2 ,3 ])); A33 = symmetric (A[3 ,3 ], :U )
1035- elseif tA == ' s'
1036- A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ])); A13 = copy (transpose (A[3 ,1 ]))
1037- A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L ); A23 = copy (transpose (A[3 ,2 ]))
1038- A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = symmetric (A[3 ,3 ], :L )
1039- elseif tA == ' H'
1040- A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1041- A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1042- A31 = copy (adjoint (A[1 ,3 ])); A32 = copy (adjoint (A[2 ,3 ])); A33 = hermitian (A[3 ,3 ], :U )
1043- else # if tA == 'h'
1044- A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ])); A13 = copy (adjoint (A[3 ,1 ]))
1045- A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L ); A23 = copy (adjoint (A[3 ,2 ]))
1046- A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = hermitian (A[3 ,3 ], :L )
1027+ elseif tA_uc == ' S'
1028+ if isuppercase (tA) # tA == 'S'
1029+ A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1030+ A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1031+ A31 = copy (transpose (A[1 ,3 ])); A32 = copy (transpose (A[2 ,3 ])); A33 = symmetric (A[3 ,3 ], :U )
1032+ else
1033+ A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ])); A13 = copy (transpose (A[3 ,1 ]))
1034+ A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L ); A23 = copy (transpose (A[3 ,2 ]))
1035+ A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = symmetric (A[3 ,3 ], :L )
1036+ end
1037+ elseif tA_uc == ' H'
1038+ if isuppercase (tA) # tA == 'H'
1039+ A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1040+ A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1041+ A31 = copy (adjoint (A[1 ,3 ])); A32 = copy (adjoint (A[2 ,3 ])); A33 = hermitian (A[3 ,3 ], :U )
1042+ else # if tA == 'h'
1043+ A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ])); A13 = copy (adjoint (A[3 ,1 ]))
1044+ A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L ); A23 = copy (adjoint (A[3 ,2 ]))
1045+ A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = hermitian (A[3 ,3 ], :L )
1046+ end
10471047 end
1048+ end # inbounds
1049+ A11, A12, A13, A21, A22, A23, A31, A32, A33
1050+ end
1051+ Base. @constprop :aggressive __matmul3x3_elements (tA, tB, A, B) = __matmul3x3_elements (tA, A), __matmul3x3_elements (tB, B)
10481052
1049- if tB == ' N'
1050- B11 = B[1 ,1 ]; B12 = B[1 ,2 ]; B13 = B[1 ,3 ]
1051- B21 = B[2 ,1 ]; B22 = B[2 ,2 ]; B23 = B[2 ,3 ]
1052- B31 = B[3 ,1 ]; B32 = B[3 ,2 ]; B33 = B[3 ,3 ]
1053- elseif tB == ' T'
1054- # TODO making these lazy could improve perf
1055- B11 = copy (transpose (B[1 ,1 ])); B12 = copy (transpose (B[2 ,1 ])); B13 = copy (transpose (B[3 ,1 ]))
1056- B21 = copy (transpose (B[1 ,2 ])); B22 = copy (transpose (B[2 ,2 ])); B23 = copy (transpose (B[3 ,2 ]))
1057- B31 = copy (transpose (B[1 ,3 ])); B32 = copy (transpose (B[2 ,3 ])); B33 = copy (transpose (B[3 ,3 ]))
1058- elseif tB == ' C'
1059- # TODO making these lazy could improve perf
1060- B11 = copy (B[1 ,1 ]' ); B12 = copy (B[2 ,1 ]' ); B13 = copy (B[3 ,1 ]' )
1061- B21 = copy (B[1 ,2 ]' ); B22 = copy (B[2 ,2 ]' ); B23 = copy (B[3 ,2 ]' )
1062- B31 = copy (B[1 ,3 ]' ); B32 = copy (B[2 ,3 ]' ); B33 = copy (B[3 ,3 ]' )
1063- elseif tB == ' S'
1064- B11 = symmetric (B[1 ,1 ], :U ); B12 = B[1 ,2 ]; B13 = B[1 ,3 ]
1065- B21 = copy (transpose (B[1 ,2 ])); B22 = symmetric (B[2 ,2 ], :U ); B23 = B[2 ,3 ]
1066- B31 = copy (transpose (B[1 ,3 ])); B32 = copy (transpose (B[2 ,3 ])); B33 = symmetric (B[3 ,3 ], :U )
1067- elseif tB == ' s'
1068- B11 = symmetric (B[1 ,1 ], :L ); B12 = copy (transpose (B[2 ,1 ])); B13 = copy (transpose (B[3 ,1 ]))
1069- B21 = B[2 ,1 ]; B22 = symmetric (B[2 ,2 ], :L ); B23 = copy (transpose (B[3 ,2 ]))
1070- B31 = B[3 ,1 ]; B32 = B[3 ,2 ]; B33 = symmetric (B[3 ,3 ], :L )
1071- elseif tB == ' H'
1072- B11 = hermitian (B[1 ,1 ], :U ); B12 = B[1 ,2 ]; B13 = B[1 ,3 ]
1073- B21 = copy (adjoint (B[1 ,2 ])); B22 = hermitian (B[2 ,2 ], :U ); B23 = B[2 ,3 ]
1074- B31 = copy (adjoint (B[1 ,3 ])); B32 = copy (adjoint (B[2 ,3 ])); B33 = hermitian (B[3 ,3 ], :U )
1075- else # if tB == 'h'
1076- B11 = hermitian (B[1 ,1 ], :L ); B12 = copy (adjoint (B[2 ,1 ])); B13 = copy (adjoint (B[3 ,1 ]))
1077- B21 = B[2 ,1 ]; B22 = hermitian (B[2 ,2 ], :L ); B23 = copy (adjoint (B[3 ,2 ]))
1078- B31 = B[3 ,1 ]; B32 = B[3 ,2 ]; B33 = hermitian (B[3 ,3 ], :L )
1079- end
1053+ Base. @constprop :aggressive function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
1054+ _add:: MulAddMul = MulAddMul ())
10801055
1081- _modify! (_add, A11* B11 + A12* B21 + A13* B31, C, (1 ,1 ))
1082- _modify! (_add, A11* B12 + A12* B22 + A13* B32, C, (1 ,2 ))
1083- _modify! (_add, A11* B13 + A12* B23 + A13* B33, C, (1 ,3 ))
1056+ (A11, A12, A13, A21, A22, A23, A31, A32, A33),
1057+ (B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements (C, tA, tB, A, B)
10841058
1059+ @inbounds begin
1060+ _modify! (_add, A11* B11 + A12* B21 + A13* B31, C, (1 ,1 ))
10851061 _modify! (_add, A21* B11 + A22* B21 + A23* B31, C, (2 ,1 ))
1086- _modify! (_add, A21* B12 + A22* B22 + A23* B32, C, (2 ,2 ))
1087- _modify! (_add, A21* B13 + A22* B23 + A23* B33, C, (2 ,3 ))
1088-
10891062 _modify! (_add, A31* B11 + A32* B21 + A33* B31, C, (3 ,1 ))
1063+
1064+ _modify! (_add, A11* B12 + A12* B22 + A13* B32, C, (1 ,2 ))
1065+ _modify! (_add, A21* B12 + A22* B22 + A23* B32, C, (2 ,2 ))
10901066 _modify! (_add, A31* B12 + A32* B22 + A33* B32, C, (3 ,2 ))
1067+
1068+ _modify! (_add, A11* B13 + A12* B23 + A13* B33, C, (1 ,3 ))
1069+ _modify! (_add, A21* B13 + A22* B23 + A23* B33, C, (2 ,3 ))
10911070 _modify! (_add, A31* B13 + A32* B23 + A33* B33, C, (3 ,3 ))
10921071 end # inbounds
10931072 C
0 commit comments