Skip to content

Commit c96a3b6

Browse files
committed
impr(opt): Lib/math isqrt: use static table over loop
1 parent e9f14d1 commit c96a3b6

File tree

1 file changed

+78
-4
lines changed

1 file changed

+78
-4
lines changed

src/pylib/Lib/n_math.nim

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,17 +622,87 @@ func factorial*(x: Natural): int =
622622
expM sqrt
623623

624624

625+
const approximate_isqrt_tab: array[192, uint8] = [
  128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
  144, 144, 145, 146, 147, 148, 149, 150, 151, 151, 152, 153, 154, 155, 156, 156,
  157, 158, 159, 160, 160, 161, 162, 163, 164, 164, 165, 166, 167, 167, 168, 169,
  170, 170, 171, 172, 173, 173, 174, 175, 176, 176, 177, 178, 179, 179, 180, 181,
  181, 182, 183, 183, 184, 185, 186, 186, 187, 188, 188, 189, 190, 190, 191, 192,
  192, 193, 194, 194, 195, 196, 196, 197, 198, 198, 199, 200, 200, 201, 201, 202,
  203, 203, 204, 205, 205, 206, 206, 207, 208, 208, 209, 210, 210, 211, 211, 212,
  213, 213, 214, 214, 215, 216, 216, 217, 217, 218, 219, 219, 220, 220, 221, 221,
  222, 223, 223, 224, 224, 225, 225, 226, 227, 227, 228, 228, 229, 229, 230, 230,
  231, 232, 232, 233, 233, 234, 234, 235, 235, 236, 237, 237, 238, 238, 239, 239,
  240, 240, 241, 241, 242, 242, 243, 243, 244, 244, 245, 246, 246, 247, 247, 248,
  248, 249, 249, 250, 250, 251, 251, 252, 252, 253, 253, 254, 254, 255, 255, 255,
] ##[Lookup table of approximate square roots for 16-bit integers.

For every `n` with `2**14 <= n < 2**16`, the entry

    a = approximate_isqrt_tab[(n shr 8) - 64]

approximates the square root of `n`, in the sense that
`(a - 1)**2 < n < (a + 1)**2`.

Only the high byte of `n` selects the entry, which is why the table has
`256 - 64 = 192` slots.

The entries can be regenerated in Python with:

    [min(round(sqrt(256*n + 128)), 255) for n in range(64, 256)]

or in Nim with:

```Nim
import std/sugar
import std/math
collect:
  for n in 64..255:
    uint8 min(round(sqrt(float(256*n + 128))), 255)
```
]##
663+
664+
func approximate_isqrt(n: uint64): uint32 {.inline.} =
  ##[Table-seeded estimate of the square root of a large 64-bit integer.

  `n` must satisfy `2**62 <= n < 2**64`. The returned `a` then satisfies
  `(a - 1)**2 < n < (a + 1)**2`.
  ]##
  assert n in (1u64 shl 62) .. uint64.high  # 2**62 ..< 2**64

  # The top byte of `n` is in 64..255 for valid input; subtracting 64 turns it
  # into a slot of `approximate_isqrt_tab`.
  let slot = (n shr 56) - 64
  assert slot in
    (approximate_isqrt_tab.low.uint64 .. approximate_isqrt_tab.high.uint64)

  # The assertion above already guards the index, so the bound check is
  # switched off for the lookup itself.
  {.push boundChecks: off.}
  let seed = uint32 approximate_isqrt_tab[cast[int](slot)]
  {.pop.}

  # Two widening steps: shift the current estimate up and add the quotient of
  # a correspondingly shifted `n` by that estimate.
  let widened = (seed shl 7) + cast[uint32](n shr 41) div seed
  result = (widened shl 15) + cast[uint32]((n shr 17) div widened)
678+
625679
func isqrtPositive*(n: Positive): int {.inline.} =
  ## EXT: integer square root restricted to `Positive` input.
  ##
  ## For reference, the remaining cases behave as follows in Python:
  ## - `isqrt(0)` == 0
  ## - `isqrt(-<positive int>)` raises ValueError
  let bitCount = n.bit_lengthUsingBitops()
  assert bitCount > 0

  # `c` is half the index of the highest set bit, rounded down.
  let c = (bitCount - 1) div 2
  assert c <= 31

  let
    # Shifting left by the even amount `2 * shift` moves the operand into
    # 2**62 ..< 2**64, the range `approximate_isqrt` requires; shifting the
    # estimate right by `shift` undoes that scaling on the root.
    shift = 31 - int(c)
    m = uint64(n)
    estimate = approximate_isqrt(m shl (2 * shift)) shr shift

  # The estimate can overshoot; step down by one when its square exceeds `n`.
  result =
    if uint64(estimate) * estimate > m: int(estimate - 1)
    else: int(estimate)
698+
699+
#[slow impl:
700+
let c = (n.bit_lengthUsingBitops() - 1) div 2
631701
var
632702
a = 1
633703
d = 0
634704
if c != 0:
635-
for s in countdown(c.bit_length() - 1, 0):
705+
for s in countdown(c.bit_lengthUsingBitops() - 1, 0):
636706
# Loop invariant: (a-1)**2 < (n >> 2*(c - d)) < (a+1)**2
637707
let e = d
638708
d = c shr s
@@ -641,6 +711,7 @@ func isqrtPositive*(n: Positive): int{.inline.} =
641711
result = a
642712
if (a*a > n):
643713
result.dec
714+
]#
644715

645716
func isqrt*(n: Natural): int{.raises: [].} =
646717
runnableExamples:
@@ -679,3 +750,6 @@ func isqrt*[T: SomeFloat](x: T): int{.raises: [].} =
679750
isqrt i
680751

681752

753+
when isMainModule:
  # Placeholder driver: iterates the one-element range holding
  # `high(int) - 1` and does nothing with the value.
  const probe = high(int) - 1
  for _ in probe .. probe:
    discard

0 commit comments

Comments
 (0)