Skip to content

Optional SIMD str(c)spn #597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions libc-top-half/musl/src/string/strcspn.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#if !defined(__wasm_simd128__) || !defined(__wasilibc_simd_string) || \
__clang_major__ == 19 || __clang_major__ == 20
// The SIMD implementation is in strspn_simd.c

#include <string.h>

#define BITOP(a,b,op) \
Expand All @@ -15,3 +19,5 @@ size_t strcspn(const char *s, const char *c)
for (; *s && !BITOP(byteset, *(unsigned char *)s, &); s++);
return s-a;
}

#endif
6 changes: 6 additions & 0 deletions libc-top-half/musl/src/string/strspn.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#if !defined(__wasm_simd128__) || !defined(__wasilibc_simd_string) || \
__clang_major__ == 19 || __clang_major__ == 20
// The SIMD implementation is in strspn_simd.c

#include <string.h>

#define BITOP(a,b,op) \
Expand All @@ -18,3 +22,5 @@ size_t strspn(const char *s, const char *c)
for (; *s && BITOP(byteset, *(unsigned char *)s, &); s++);
return s-a;
}

#endif
175 changes: 175 additions & 0 deletions libc-top-half/musl/src/string/strspn_simd.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#if defined(__wasm_simd128__) && defined(__wasilibc_simd_string)
// Skip Clang 19 and Clang 20 which have a bug (llvm/llvm-project#146574)
// which results in an ICE when inline assembly is used with a vector result.
#if __clang_major__ != 19 && __clang_major__ != 20

#include <stdint.h>
#include <string.h>
#include <wasm_simd128.h>

#if !defined(__wasm_relaxed_simd__) || !defined(__RELAXED_FN_ATTRS)
#define wasm_i8x16_relaxed_swizzle wasm_i8x16_swizzle
#endif

// SIMDized check which bytes are in a set (Geoff Langdale)
// http://0x80.pl/notesen/2018-10-18-simd-byte-lookup.html

typedef struct {
__u8x16 lo;
__u8x16 hi;
} __wasm_v128_bitmap256_t;

__attribute__((always_inline))
static void __wasm_v128_setbit(__wasm_v128_bitmap256_t *bitmap, int i) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not the following?

Suggested change
static void __wasm_v128_setbit(__wasm_v128_bitmap256_t *bitmap, int i) {
static void __wasm_v128_setbit(__wasm_v128_bitmap256_t *bitmap, uint8_t i) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. It moves the cast out to the loop, best not rely on implicit char to uint8_t.

Arguably, this bit would be more correct with unsigned char and unsigned.

uint8_t hi_nibble = (uint8_t)i >> 4;
uint8_t lo_nibble = (uint8_t)i & 0xf;
bitmap->lo[lo_nibble] |= (uint8_t)((uint32_t)1 << (hi_nibble - 0));
bitmap->hi[lo_nibble] |= (uint8_t)((uint32_t)1 << (hi_nibble - 8));
Comment on lines +26 to +27
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm interested in understanding the codegen of this: so the ...[lo_nibble] |= is generating some i8x16.replace_lane but somehow also OR-ing the high nibble bits? What is emitted by LLVM here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLVM cheats and uses the stack, wasm-opt (and wasm-ctor-eval) then removes any traces of $__stack_pointer in this build:
https://github.com/ncruces/go-sqlite3/blob/b72fd5db/sqlite3/libc/libc.wat#L1646-L1726

}

__attribute__((always_inline))
static v128_t __wasm_v128_chkbits(__wasm_v128_bitmap256_t bitmap, v128_t v) {
v128_t hi_nibbles = wasm_u8x16_shr(v, 4);
v128_t bitmask_lookup = wasm_u8x16_const(1, 2, 4, 8, 16, 32, 64, 128, //
1, 2, 4, 8, 16, 32, 64, 128);
Comment on lines +33 to +34
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are both -128 in the original algorithm, no?

Copy link
Contributor Author

@ncruces ncruces Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. In two's complement, int8(-128) is equal to uint8(128). In interpreting the algorithm, I found unsigned clearer, so used that.

I guess Intel lacks the unsigned version of _mm_setr_epi8 intrinsic (which would map to the same instruction anyways…)?

What do you think is best, a comment explaining this, or wasm_i8x16_const?

v128_t bitmask = wasm_i8x16_relaxed_swizzle(bitmask_lookup, hi_nibbles);

v128_t indices_0_7 = v & wasm_u8x16_const_splat(0x8f);
v128_t indices_8_15 = indices_0_7 ^ wasm_u8x16_const_splat(0x80);

v128_t row_0_7 = wasm_i8x16_swizzle(bitmap.lo, indices_0_7);
v128_t row_8_15 = wasm_i8x16_swizzle(bitmap.hi, indices_8_15);

v128_t bitsets = row_0_7 | row_8_15;
return wasm_i8x16_eq(bitsets & bitmask, bitmask);
}

size_t strspn(const char *s, const char *c)
{
// Note that reading before/after the allocation of a pointer is UB in
// C, so inline assembly is used to generate the exact machine
// instruction we want with opaque semantics to the compiler to avoid
// the UB.
uintptr_t align = (uintptr_t)s % sizeof(v128_t);
uintptr_t addr = (uintptr_t)s - align;

if (!c[0]) return 0;
if (!c[1]) {
v128_t vc = wasm_i8x16_splat(*c);
for (;;) {
v128_t v;
__asm__(
"local.get %1\n"
"v128.load 0\n"
"local.set %0\n"
: "=r"(v)
: "r"(addr)
: "memory");
v128_t cmp = wasm_i8x16_eq(v, vc);
// Bitmask is slow on AArch64, all_true is much faster.
if (!wasm_i8x16_all_true(cmp)) {
// Clear the bits corresponding to align (little-endian)
// so we can count trailing zeros.
int mask = (uint16_t)~wasm_i8x16_bitmask(cmp) >> align << align;
// At least one bit will be set, unless align cleared them.
// Knowing this helps the compiler if it unrolls the loop.
__builtin_assume(mask || align);
// If the mask became zero because of align,
// it's as if we didn't find anything.
if (mask) {
// Find the offset of the first one bit (little-endian).
return addr - (uintptr_t)s + __builtin_ctz(mask);
}
}
align = 0;
addr += sizeof(v128_t);
}
}

__wasm_v128_bitmap256_t bitmap = {};

for (; *c; c++) {
// Terminator IS NOT on the bitmap.
__wasm_v128_setbit(&bitmap, *c);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note for future reference: I was initially a bit concerned here that we will incur startup costs too heavy for the "check a small string" use case (?). But of course it's better to loop over c once up front rather than at each character in s like the scalar version does.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scalar version does the same: it iterates over c once, building a bitmap:
https://github.com/WebAssembly/wasi-libc/blob/main/libc-top-half/musl/src/string/strspn.c

They use some "inscrutable" (but well known) macros to build a more straightforward bitmap in stack memory.
I used this function to build our weird 256-bit bitmap "directly" into a pair of v128 vectors.

}

for (;;) {
v128_t v;
__asm__(
"local.get %1\n"
"v128.load 0\n"
"local.set %0\n"
: "=r"(v)
: "r"(addr)
: "memory");
v128_t cmp = __wasm_v128_chkbits(bitmap, v);
// Bitmask is slow on AArch64, all_true is much faster.
if (!wasm_i8x16_all_true(cmp)) {
// Clear the bits corresponding to align (little-endian)
// so we can count trailing zeros.
int mask = (uint16_t)~wasm_i8x16_bitmask(cmp) >> align << align;
// At least one bit will be set, unless align cleared them.
// Knowing this helps the compiler if it unrolls the loop.
__builtin_assume(mask || align);
// If the mask became zero because of align,
// it's as if we didn't find anything.
if (mask) {
// Find the offset of the first one bit (little-endian).
return addr - (uintptr_t)s + __builtin_ctz(mask);
}
}
Comment on lines +107 to +120
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We keep on using this pattern; is the logic repetitive enough that we should define a helper function somewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but it's always subtly different.

Some times it's (1) cmp, any_true and bitmask; others (2) cmp, all_true and (uint16_t)~bitmask.

Then, strlen gets away with doing !all_true before the cmp and bitmask; memrchr does clz instead of ctz.

I also don't see how to get those two main patterns to inline without going into macros, and inlining is important. That __builtin_assume makes the compiler figure out it's good to unroll the first iteration to make the subsequent ones significantly simpler.

align = 0;
addr += sizeof(v128_t);
}
}

size_t strcspn(const char *s, const char *c)
{
if (!c[0] || !c[1]) return __strchrnul(s, *c) - s;

// Note that reading before/after the allocation of a pointer is UB in
// C, so inline assembly is used to generate the exact machine
// instruction we want with opaque semantics to the compiler to avoid
// the UB.
uintptr_t align = (uintptr_t)s % sizeof(v128_t);
uintptr_t addr = (uintptr_t)s - align;

__wasm_v128_bitmap256_t bitmap = {};

do {
// Terminator IS on the bitmap.
__wasm_v128_setbit(&bitmap, *c);
} while (*c++);

for (;;) {
v128_t v;
__asm__(
"local.get %1\n"
"v128.load 0\n"
"local.set %0\n"
: "=r"(v)
: "r"(addr)
: "memory");
v128_t cmp = __wasm_v128_chkbits(bitmap, v);
// Bitmask is slow on AArch64, any_true is much faster.
if (wasm_v128_any_true(cmp)) {
// Clear the bits corresponding to align (little-endian)
// so we can count trailing zeros.
int mask = wasm_i8x16_bitmask(cmp) >> align << align;
// At least one bit will be set, unless align cleared them.
// Knowing this helps the compiler if it unrolls the loop.
__builtin_assume(mask || align);
// If the mask became zero because of align,
// it's as if we didn't find anything.
if (mask) {
// Find the offset of the first one bit (little-endian).
return addr - (uintptr_t)s + __builtin_ctz(mask);
}
}
align = 0;
addr += sizeof(v128_t);
}
}

#endif
#endif
62 changes: 62 additions & 0 deletions test/src/misc/strcspn.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//! add-flags.py(LDFLAGS): -Wl,--stack-first -Wl,--initial-memory=327680

#include <__macro_PAGESIZE.h>
#include <stddef.h>
#include <stdio.h>
#include <string.h>

void test(char *ptr, char *set, size_t want) {
size_t got = strcspn(ptr, set);
if (got != want) {
printf("strcspn(%p, \"%s\") = %lu, want %lu\n", ptr, set, got, want);
}
}

int main(void) {
char *const LIMIT = (char *)(__builtin_wasm_memory_size(0) * PAGESIZE);

for (ptrdiff_t length = 0; length < 64; length++) {
for (ptrdiff_t alignment = 0; alignment < 24; alignment++) {
for (ptrdiff_t pos = -2; pos < length + 2; pos++) {
// Create a buffer with the given length, at a pointer with the given
// alignment. Using the offset LIMIT - PAGESIZE - 8 means many buffers
// will straddle a (Wasm, and likely OS) page boundary. Place the
// character to find at every position in the buffer, including just
// prior to it and after its end.
char *ptr = LIMIT - PAGESIZE - 8 + alignment;
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
memset(ptr, 5, length);

// The first instance of the character is found.
if (pos >= 0) ptr[pos + 2] = 7;
ptr[pos] = 7;
ptr[length] = 0;

// The character is found if it's within range.
ptrdiff_t want = 0 <= pos && pos < length ? pos : length;
test(ptr, "\x07", want);
test(ptr, "\x07\x03", want);
test(ptr, "\x07\x85", want);
test(ptr, "\x87\x85", length);
}
}

// We need space for the terminator.
if (length == 0) continue;

// Ensure we never read past the end of memory.
char *ptr = LIMIT - length;
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
memset(ptr, 5, length);

ptr[length - 1] = 7;
test(ptr, "\x07", length - 1);
test(ptr, "\x07\x03", length - 1);

ptr[length - 1] = 0;
test(ptr, "\x07", length - 1);
test(ptr, "\x07\x03", length - 1);
}

return 0;
}
62 changes: 62 additions & 0 deletions test/src/misc/strspn.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//! add-flags.py(LDFLAGS): -Wl,--stack-first -Wl,--initial-memory=327680

#include <__macro_PAGESIZE.h>
#include <stddef.h>
#include <stdio.h>
#include <string.h>

void test(char *ptr, char *set, size_t want) {
size_t got = strspn(ptr, set);
if (got != want) {
printf("strspn(%p, \"%s\") = %lu, want %lu\n", ptr, set, got, want);
}
}

int main(void) {
char *const LIMIT = (char *)(__builtin_wasm_memory_size(0) * PAGESIZE);

for (ptrdiff_t length = 0; length < 64; length++) {
for (ptrdiff_t alignment = 0; alignment < 24; alignment++) {
for (ptrdiff_t pos = -2; pos < length + 2; pos++) {
// Create a buffer with the given length, at a pointer with the given
// alignment. Using the offset LIMIT - PAGESIZE - 8 means many buffers
// will straddle a (Wasm, and likely OS) page boundary. Place the
// character to find at every position in the buffer, including just
// prior to it and after its end.
char *ptr = LIMIT - PAGESIZE - 8 + alignment;
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
memset(ptr, 5, length);

// The first instance of the character is found.
if (pos >= 0) ptr[pos + 2] = 7;
ptr[pos] = 7;
ptr[length] = 0;

// The character is found if it's within range.
ptrdiff_t want = 0 <= pos && pos < length ? pos : length;
test(ptr, "\x05", want);
test(ptr, "\x05\x03", want);
test(ptr, "\x05\x87", want);
test(ptr, "\x05\x07", length);
}
}

// We need space for the terminator.
if (length == 0) continue;

// Ensure we never read past the end of memory.
char *ptr = LIMIT - length;
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
memset(ptr, 5, length);

ptr[length - 1] = 7;
test(ptr, "\x05", length - 1);
test(ptr, "\x05\x03", length - 1);

ptr[length - 1] = 0;
test(ptr, "\x05", length - 1);
test(ptr, "\x05\x03", length - 1);
}

return 0;
}