Skip to content

Commit 05340d2

Browse files
Merge pull request #2038 from fredrik-johansson/matmul
Linear algebra tuning for nfloat + cmpabs
2 parents 91f0ece + 99bd846 commit 05340d2

File tree

8 files changed

+539
-19
lines changed

8 files changed

+539
-19
lines changed

doc/source/nfloat.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ Matrix functions
324324

325325
Different implementations of matrix multiplication.
326326

327+
.. function:: int nfloat_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
328+
int nfloat_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
329+
int nfloat_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx)
330+
327331
Internal functions
328332
-------------------------------------------------------------------------------
329333

@@ -417,3 +421,6 @@ real pairs.
417421
int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx)
418422
int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
419423
int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
424+
int nfloat_complex_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
425+
int nfloat_complex_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
426+
int nfloat_complex_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx)

src/gr/acb.c

Lines changed: 95 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -987,21 +987,108 @@ _gr_acb_cmp(int * res, const acb_t x, const acb_t y, const gr_ctx_t ctx)
987987
}
988988
}
989989

990+
int
991+
_gr_arb_cmpabs(int * res, const arb_t x, const arb_t y, const gr_ctx_t ctx);
992+
990993
int
991994
_gr_acb_cmpabs(int * res, const acb_t x, const acb_t y, const gr_ctx_t ctx)
992995
{
993-
acb_t t, u;
996+
if (arb_is_zero(acb_imagref(x)) && arb_is_zero(acb_imagref(y)))
997+
{
998+
arb_srcptr a = acb_realref(x);
999+
arb_srcptr c = acb_realref(y);
9941000

995-
*t = *x;
996-
*u = *y;
1001+
/* OK; ignores the context object */
1002+
return _gr_arb_cmpabs(res, a, c, ctx);
1003+
}
1004+
else
1005+
{
1006+
slong prec = ACB_CTX_PREC(ctx);
1007+
int status = GR_SUCCESS;
9971008

998-
if (arf_sgn(arb_midref(acb_realref(t))) < 0)
999-
ARF_NEG(arb_midref(acb_realref(t)));
1009+
arb_srcptr a = acb_realref(x);
1010+
arb_srcptr b = acb_imagref(x);
1011+
arb_srcptr c = acb_realref(y);
1012+
arb_srcptr d = acb_imagref(y);
1013+
1014+
mag_t xlo, xhi, ylo, yhi, t;
1015+
1016+
mag_init(xlo);
1017+
mag_init(xhi);
1018+
mag_init(ylo);
1019+
mag_init(yhi);
1020+
mag_init(t);
1021+
1022+
arb_get_mag_lower(xlo, a);
1023+
arb_get_mag_lower(t, b);
1024+
mag_mul_lower(xlo, xlo, xlo);
1025+
mag_mul_lower(t, t, t);
1026+
mag_add_lower(xlo, xlo, t);
1027+
1028+
arb_get_mag_lower(ylo, c);
1029+
arb_get_mag_lower(t, d);
1030+
mag_mul_lower(ylo, ylo, ylo);
1031+
mag_mul_lower(t, t, t);
1032+
mag_add_lower(ylo, ylo, t);
1033+
1034+
arb_get_mag(xhi, a);
1035+
arb_get_mag(t, b);
1036+
mag_mul(xhi, xhi, xhi);
1037+
mag_mul(t, t, t);
1038+
mag_add(xhi, xhi, t);
1039+
1040+
arb_get_mag(yhi, c);
1041+
arb_get_mag(t, d);
1042+
mag_mul(yhi, yhi, yhi);
1043+
mag_mul(t, t, t);
1044+
mag_add(yhi, yhi, t);
1045+
1046+
if (mag_cmp(xhi, ylo) < 0)
1047+
{
1048+
*res = -1;
1049+
status = GR_SUCCESS;
1050+
}
1051+
else if (mag_cmp(xlo, yhi) > 0)
1052+
{
1053+
*res = 1;
1054+
status = GR_SUCCESS;
1055+
}
1056+
else
1057+
{
1058+
arb_t t, u;
10001059

1001-
if (arf_sgn(arb_midref(acb_realref(u))) < 0)
1002-
ARF_NEG(arb_midref(acb_realref(u)));
1060+
arb_init(t);
1061+
arb_init(u);
10031062

1004-
return _gr_acb_cmp(res, t, u, ctx);
1063+
arb_mul(t, a, a, prec);
1064+
arb_addmul(t, b, b, prec);
1065+
arb_mul(u, c, c, prec);
1066+
arb_addmul(u, d, d, prec);
1067+
1068+
if ((arb_is_exact(t) && arb_is_exact(u)) || !arb_overlaps(t, u))
1069+
{
1070+
*res = arf_cmp(arb_midref(t), arb_midref(u));
1071+
status = GR_SUCCESS;
1072+
}
1073+
else
1074+
{
1075+
/* todo: worth it to do an exact computation? */
1076+
*res = 0;
1077+
status = GR_UNABLE;
1078+
}
1079+
1080+
arb_clear(t);
1081+
arb_clear(u);
1082+
}
1083+
1084+
mag_clear(xlo);
1085+
mag_clear(xhi);
1086+
mag_clear(ylo);
1087+
mag_clear(yhi);
1088+
mag_clear(t);
1089+
1090+
return status;
1091+
}
10051092
}
10061093

10071094
int

src/gr/acf.c

Lines changed: 134 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,13 +666,144 @@ _gr_acf_cmp(int * res, const acf_t x, const acf_t y, const gr_ctx_t ctx)
666666
return GR_SUCCESS;
667667
}
668668

669+
/* ignores ctx, so we can pass in the acf context */
670+
int
671+
_gr_arf_cmpabs(int * res, const arf_t x, const arf_t y, const gr_ctx_t ctx);
672+
673+
#include "double_extras.h"
674+
669675
int
670676
_gr_acf_cmpabs(int * res, const acf_t x, const acf_t y, const gr_ctx_t ctx)
671677
{
672-
if (!arf_is_zero(acf_imagref(x)) || !arf_is_zero(acf_imagref(y)))
673-
return GR_UNABLE;
678+
arf_srcptr a = acf_realref(x);
679+
arf_srcptr b = acf_imagref(x);
680+
arf_srcptr c = acf_realref(y);
681+
arf_srcptr d = acf_imagref(y);
682+
683+
if (arf_is_zero(b))
684+
{
685+
if (arf_is_zero(d))
686+
return _gr_arf_cmpabs(res, a, c, ctx);
687+
if (arf_is_zero(c))
688+
return _gr_arf_cmpabs(res, a, d, ctx);
689+
if (arf_is_zero(a))
690+
{
691+
*res = -1;
692+
return GR_SUCCESS;
693+
}
694+
}
695+
696+
if (arf_is_zero(a))
697+
{
698+
if (arf_is_zero(d))
699+
return _gr_arf_cmpabs(res, b, c, ctx);
700+
if (arf_is_zero(c))
701+
return _gr_arf_cmpabs(res, b, d, ctx);
702+
}
703+
704+
if (arf_is_zero(c) && arf_is_zero(d))
705+
{
706+
*res = 1;
707+
return GR_SUCCESS;
708+
}
709+
710+
if (ARF_IS_LAGOM(a) && ARF_IS_LAGOM(b) && ARF_IS_LAGOM(c) && ARF_IS_LAGOM(d))
711+
{
712+
slong aexp, bexp, cexp, dexp, xexp, yexp, exp;
713+
714+
aexp = arf_is_zero(a) ? WORD_MIN : ARF_EXP(a);
715+
bexp = arf_is_zero(b) ? WORD_MIN : ARF_EXP(b);
716+
cexp = arf_is_zero(c) ? WORD_MIN : ARF_EXP(c);
717+
dexp = arf_is_zero(d) ? WORD_MIN : ARF_EXP(d);
718+
719+
/* 0.5 * 2^xexp <= |x| < sqrt(2) * 2^xexp */
720+
xexp = FLINT_MAX(aexp, bexp);
721+
/* 0.5 * 2^yexp <= |y| < sqrt(2) * 2^yexp */
722+
yexp = FLINT_MAX(cexp, dexp);
723+
724+
if (xexp + 2 < yexp)
725+
{
726+
*res = -1;
727+
return GR_SUCCESS;
728+
}
729+
730+
if (xexp > yexp + 2)
731+
{
732+
*res = 1;
733+
return GR_SUCCESS;
734+
}
735+
736+
exp = FLINT_MAX(xexp, yexp);
737+
738+
double tt, xx = 0.0, yy = 0.0;
739+
nn_srcptr xp;
740+
slong xn;
741+
742+
if (aexp >= exp - 53)
743+
{
744+
ARF_GET_MPN_READONLY(xp, xn, a);
745+
tt = d_mul_2exp_inrange(xp[xn - 1], aexp - exp - FLINT_BITS);
746+
xx += tt * tt;
747+
}
748+
749+
if (bexp >= exp - 53)
750+
{
751+
ARF_GET_MPN_READONLY(xp, xn, b);
752+
tt = d_mul_2exp_inrange(xp[xn - 1], bexp - exp - FLINT_BITS);
753+
xx += tt * tt;
754+
}
755+
756+
if (cexp >= exp - 53)
757+
{
758+
ARF_GET_MPN_READONLY(xp, xn, c);
759+
tt = d_mul_2exp_inrange(xp[xn - 1], cexp - exp - FLINT_BITS);
760+
yy += tt * tt;
761+
}
762+
763+
if (dexp >= exp - 53)
764+
{
765+
ARF_GET_MPN_READONLY(xp, xn, d);
766+
tt = d_mul_2exp_inrange(xp[xn - 1], dexp - exp - FLINT_BITS);
767+
yy += tt * tt;
768+
}
769+
770+
if (xx < yy * 0.999999)
771+
{
772+
*res = -1;
773+
return GR_SUCCESS;
774+
}
775+
776+
if (xx * 0.999999 > yy)
777+
{
778+
*res = 1;
779+
return GR_SUCCESS;
780+
}
781+
}
782+
783+
arf_struct s[5];
784+
785+
arf_init(s + 0);
786+
arf_init(s + 1);
787+
arf_init(s + 2);
788+
arf_init(s + 3);
789+
arf_init(s + 4);
790+
791+
arf_mul(s + 0, a, a, ARF_PREC_EXACT, ARF_RND_DOWN);
792+
arf_mul(s + 1, b, b, ARF_PREC_EXACT, ARF_RND_DOWN);
793+
arf_mul(s + 2, c, c, ARF_PREC_EXACT, ARF_RND_DOWN);
794+
arf_mul(s + 3, d, d, ARF_PREC_EXACT, ARF_RND_DOWN);
795+
arf_neg(s + 2, s + 2);
796+
arf_neg(s + 3, s + 3);
797+
arf_sum(s + 4, s, 4, 30, ARF_RND_DOWN);
798+
799+
*res = arf_sgn(s + 4);
800+
801+
arf_clear(s + 0);
802+
arf_clear(s + 1);
803+
arf_clear(s + 2);
804+
arf_clear(s + 3);
805+
arf_clear(s + 4);
674806

675-
*res = arf_cmpabs(acf_realref(x), acf_realref(y));
676807
return GR_SUCCESS;
677808
}
678809

src/gr/test_ring.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,7 +2904,7 @@ gr_test_ordered_ring_cmpabs(gr_ctx_t R, flint_rand_t state, int test_flags)
29042904
status = GR_TEST_FAIL;
29052905
}
29062906

2907-
if (status & GR_DOMAIN && !(status & GR_UNABLE))
2907+
if (gr_ctx_is_ordered_ring(R) == T_TRUE && (status & GR_DOMAIN && !(status & GR_UNABLE)))
29082908
{
29092909
status = GR_TEST_FAIL;
29102910
}
@@ -4315,10 +4315,9 @@ gr_test_ring(gr_ctx_t R, slong iters, int test_flags)
43154315
gr_test_iter(R, state, "pow: ui/si/fmpz/fmpq", gr_test_pow_type_variants, iters, test_flags & (~GR_TEST_ALWAYS_ABLE));
43164316

43174317
if (gr_ctx_is_ordered_ring(R) == T_TRUE)
4318-
{
43194318
gr_test_iter(R, state, "ordered_ring_cmp", gr_test_ordered_ring_cmp, iters, test_flags);
4320-
gr_test_iter(R, state, "ordered_ring_cmpabs", gr_test_ordered_ring_cmpabs, iters, test_flags);
4321-
}
4319+
4320+
gr_test_iter(R, state, "ordered_ring_cmpabs", gr_test_ordered_ring_cmpabs, iters, test_flags);
43224321

43234322
gr_test_iter(R, state, "numerator_denominator", gr_test_numerator_denominator, iters, test_flags);
43244323
gr_test_iter(R, state, "complex_parts", gr_test_complex_parts, iters, test_flags);

src/nfloat.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,11 @@ int nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ct
458458
int nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx);
459459
int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
460460

461+
int nfloat_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
462+
int nfloat_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
463+
int nfloat_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx);
464+
465+
461466
/* Complex numbers */
462467
/* Note: we use the same context data for real and complex rings
463468
(only which_ring and sizeof_elem differ). This allows us to call
@@ -569,6 +574,10 @@ int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B,
569574
int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
570575
int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
571576

577+
int nfloat_complex_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
578+
int nfloat_complex_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
579+
int nfloat_complex_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx);
580+
572581
#ifdef __cplusplus
573582
}
574583
#endif

0 commit comments

Comments
 (0)