Skip to content

Commit ea73f9c

Browse files
authored
Optional SIMD str(c)spn (#597)
Continuing #580, implements `strspn` and `strcspn`. This one follows the same general structure as #586, #592 and #594, but uses a somewhat more complicated algorithm, described [here](http://0x80.pl/notesen/2018-10-18-simd-byte-lookup.html). I used the Geoff Langdale alternative implementation (the tweet has since disappeared), which is correctly described there but has a subtle bug in the implementation: WojciechMula/simd-byte-lookup#2. Since the complexity needed for `__wasm_v128_bitmap256_t` is shared by both `strspn` and `strcspn`, I moved the implementation to a common file, used when SIMD is enabled. The tests follow a similar structure to the previous ones, and cover the bug, which I found through fuzzing.
1 parent 75836f0 commit ea73f9c

File tree

5 files changed

+316
-0
lines changed

5 files changed

+316
-0
lines changed

libc-top-half/musl/src/string/strcspn.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#if !defined(__wasm_simd128__) || !defined(__wasilibc_simd_string) || \
2+
__clang_major__ == 19 || __clang_major__ == 20
3+
// The SIMD implementation is in strspn_simd.c
4+
15
#include <string.h>
26

37
#define BITOP(a,b,op) \
@@ -15,3 +19,5 @@ size_t strcspn(const char *s, const char *c)
1519
for (; *s && !BITOP(byteset, *(unsigned char *)s, &); s++);
1620
return s-a;
1721
}
22+
23+
#endif

libc-top-half/musl/src/string/strspn.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#if !defined(__wasm_simd128__) || !defined(__wasilibc_simd_string) || \
2+
__clang_major__ == 19 || __clang_major__ == 20
3+
// The SIMD implementation is in strspn_simd.c
4+
15
#include <string.h>
26

37
#define BITOP(a,b,op) \
@@ -18,3 +22,5 @@ size_t strspn(const char *s, const char *c)
1822
for (; *s && BITOP(byteset, *(unsigned char *)s, &); s++);
1923
return s-a;
2024
}
25+
26+
#endif
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#if defined(__wasm_simd128__) && defined(__wasilibc_simd_string)
2+
// Skip Clang 19 and Clang 20 which have a bug (llvm/llvm-project#146574)
3+
// which results in an ICE when inline assembly is used with a vector result.
4+
#if __clang_major__ != 19 && __clang_major__ != 20
5+
6+
#include <stdint.h>
7+
#include <string.h>
8+
#include <wasm_simd128.h>
9+
10+
#if !defined(__wasm_relaxed_simd__) || !defined(__RELAXED_FN_ATTRS)
11+
#define wasm_i8x16_relaxed_swizzle wasm_i8x16_swizzle
12+
#endif
13+
14+
// SIMDized check which bytes are in a set (Geoff Langdale)
15+
// http://0x80.pl/notesen/2018-10-18-simd-byte-lookup.html
16+
17+
// This is the same algorithm as truffle from Hyperscan:
18+
// https://github.com/intel/hyperscan/blob/v5.4.2/src/nfa/truffle.c#L64-L81
19+
// https://github.com/intel/hyperscan/blob/v5.4.2/src/nfa/trufflecompile.cpp
20+
21+
typedef struct {
22+
__u8x16 lo;
23+
__u8x16 hi;
24+
} __wasm_v128_bitmap256_t;
25+
26+
__attribute__((always_inline))
27+
static void __wasm_v128_setbit(__wasm_v128_bitmap256_t *bitmap, uint8_t i) {
28+
uint8_t hi_nibble = i >> 4;
29+
uint8_t lo_nibble = i & 0xf;
30+
bitmap->lo[lo_nibble] |= (uint8_t)(1u << (hi_nibble - 0));
31+
bitmap->hi[lo_nibble] |= (uint8_t)(1u << (hi_nibble - 8));
32+
}
33+
34+
__attribute__((always_inline))
35+
static v128_t __wasm_v128_chkbits(__wasm_v128_bitmap256_t bitmap, v128_t v) {
36+
v128_t hi_nibbles = wasm_u8x16_shr(v, 4);
37+
v128_t bitmask_lookup = wasm_u64x2_const_splat(0x8040201008040201);
38+
v128_t bitmask = wasm_i8x16_relaxed_swizzle(bitmask_lookup, hi_nibbles);
39+
40+
v128_t indices_0_7 = v & wasm_u8x16_const_splat(0x8f);
41+
v128_t indices_8_15 = indices_0_7 ^ wasm_u8x16_const_splat(0x80);
42+
43+
v128_t row_0_7 = wasm_i8x16_swizzle((v128_t)bitmap.lo, indices_0_7);
44+
v128_t row_8_15 = wasm_i8x16_swizzle((v128_t)bitmap.hi, indices_8_15);
45+
46+
v128_t bitsets = row_0_7 | row_8_15;
47+
return bitsets & bitmask;
48+
}
49+
50+
size_t strspn(const char *s, const char *c)
51+
{
52+
// Note that reading before/after the allocation of a pointer is UB in
53+
// C, so inline assembly is used to generate the exact machine
54+
// instruction we want with opaque semantics to the compiler to avoid
55+
// the UB.
56+
uintptr_t align = (uintptr_t)s % sizeof(v128_t);
57+
uintptr_t addr = (uintptr_t)s - align;
58+
59+
if (!c[0]) return 0;
60+
if (!c[1]) {
61+
v128_t vc = wasm_i8x16_splat(*c);
62+
for (;;) {
63+
v128_t v;
64+
__asm__(
65+
"local.get %1\n"
66+
"v128.load 0\n"
67+
"local.set %0\n"
68+
: "=r"(v)
69+
: "r"(addr)
70+
: "memory");
71+
v128_t cmp = wasm_i8x16_eq(v, vc);
72+
// Bitmask is slow on AArch64, all_true is much faster.
73+
if (!wasm_i8x16_all_true(cmp)) {
74+
// Clear the bits corresponding to align (little-endian)
75+
// so we can count trailing zeros.
76+
int mask = (uint16_t)~wasm_i8x16_bitmask(cmp) >> align << align;
77+
// At least one bit will be set, unless align cleared them.
78+
// Knowing this helps the compiler if it unrolls the loop.
79+
__builtin_assume(mask || align);
80+
// If the mask became zero because of align,
81+
// it's as if we didn't find anything.
82+
if (mask) {
83+
// Find the offset of the first one bit (little-endian).
84+
return addr - (uintptr_t)s + __builtin_ctz(mask);
85+
}
86+
}
87+
align = 0;
88+
addr += sizeof(v128_t);
89+
}
90+
}
91+
92+
__wasm_v128_bitmap256_t bitmap = {};
93+
94+
for (; *c; c++) {
95+
// Terminator IS NOT on the bitmap.
96+
__wasm_v128_setbit(&bitmap, (uint8_t)*c);
97+
}
98+
99+
for (;;) {
100+
v128_t v;
101+
__asm__(
102+
"local.get %1\n"
103+
"v128.load 0\n"
104+
"local.set %0\n"
105+
: "=r"(v)
106+
: "r"(addr)
107+
: "memory");
108+
v128_t found = __wasm_v128_chkbits(bitmap, v);
109+
// Bitmask is slow on AArch64, all_true is much faster.
110+
if (!wasm_i8x16_all_true(found)) {
111+
v128_t cmp = wasm_i8x16_eq(found, (v128_t){});
112+
// Clear the bits corresponding to align (little-endian)
113+
// so we can count trailing zeros.
114+
int mask = wasm_i8x16_bitmask(cmp) >> align << align;
115+
// At least one bit will be set, unless align cleared them.
116+
// Knowing this helps the compiler if it unrolls the loop.
117+
__builtin_assume(mask || align);
118+
// If the mask became zero because of align,
119+
// it's as if we didn't find anything.
120+
if (mask) {
121+
// Find the offset of the first one bit (little-endian).
122+
return addr - (uintptr_t)s + __builtin_ctz(mask);
123+
}
124+
}
125+
align = 0;
126+
addr += sizeof(v128_t);
127+
}
128+
}
129+
130+
size_t strcspn(const char *s, const char *c)
131+
{
132+
if (!c[0] || !c[1]) return __strchrnul(s, *c) - s;
133+
134+
// Note that reading before/after the allocation of a pointer is UB in
135+
// C, so inline assembly is used to generate the exact machine
136+
// instruction we want with opaque semantics to the compiler to avoid
137+
// the UB.
138+
uintptr_t align = (uintptr_t)s % sizeof(v128_t);
139+
uintptr_t addr = (uintptr_t)s - align;
140+
141+
__wasm_v128_bitmap256_t bitmap = {};
142+
143+
do {
144+
// Terminator IS on the bitmap.
145+
__wasm_v128_setbit(&bitmap, (uint8_t)*c);
146+
} while (*c++);
147+
148+
for (;;) {
149+
v128_t v;
150+
__asm__(
151+
"local.get %1\n"
152+
"v128.load 0\n"
153+
"local.set %0\n"
154+
: "=r"(v)
155+
: "r"(addr)
156+
: "memory");
157+
v128_t found = __wasm_v128_chkbits(bitmap, v);
158+
// Bitmask is slow on AArch64, any_true is much faster.
159+
if (wasm_v128_any_true(found)) {
160+
v128_t cmp = wasm_i8x16_eq(found, (v128_t){});
161+
// Clear the bits corresponding to align (little-endian)
162+
// so we can count trailing zeros.
163+
int mask = (uint16_t)~wasm_i8x16_bitmask(cmp) >> align << align;
164+
// At least one bit will be set, unless align cleared them.
165+
// Knowing this helps the compiler if it unrolls the loop.
166+
__builtin_assume(mask || align);
167+
// If the mask became zero because of align,
168+
// it's as if we didn't find anything.
169+
if (mask) {
170+
// Find the offset of the first one bit (little-endian).
171+
return addr - (uintptr_t)s + __builtin_ctz(mask);
172+
}
173+
}
174+
align = 0;
175+
addr += sizeof(v128_t);
176+
}
177+
}
178+
179+
#endif
180+
#endif

test/src/misc/strcspn.c

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//! add-flags.py(LDFLAGS): -Wl,--stack-first -Wl,--initial-memory=327680
2+
3+
#include <__macro_PAGESIZE.h>
4+
#include <stddef.h>
5+
#include <stdio.h>
6+
#include <string.h>
7+
8+
// Calls strcspn(ptr, set) and prints a diagnostic to stdout when the result
// differs from `want`. Prints nothing on success.
//
// ptr:  string to scan.
// set:  bytes to reject.
// want: expected return value.
void test(char *ptr, char *set, size_t want) {
  size_t got = strcspn(ptr, set);
  if (got != want) {
    // %zu is the conversion for size_t, and %p expects a void pointer.
    printf("strcspn(%p, \"%s\") = %zu, want %zu\n", (void *)ptr, set, got, want);
  }
}
14+
15+
int main(void) {
16+
char *const LIMIT = (char *)(__builtin_wasm_memory_size(0) * PAGESIZE);
17+
18+
for (ptrdiff_t length = 0; length < 64; length++) {
19+
for (ptrdiff_t alignment = 0; alignment < 24; alignment++) {
20+
for (ptrdiff_t pos = -2; pos < length + 2; pos++) {
21+
// Create a buffer with the given length, at a pointer with the given
22+
// alignment. Using the offset LIMIT - PAGESIZE - 8 means many buffers
23+
// will straddle a (Wasm, and likely OS) page boundary. Place the
24+
// character to find at every position in the buffer, including just
25+
// prior to it and after its end.
26+
char *ptr = LIMIT - PAGESIZE - 8 + alignment;
27+
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
28+
memset(ptr, 5, length);
29+
30+
// The first instance of the character is found.
31+
if (pos >= 0) ptr[pos + 2] = 7;
32+
ptr[pos] = 7;
33+
ptr[length] = 0;
34+
35+
// The character is found if it's within range.
36+
ptrdiff_t want = 0 <= pos && pos < length ? pos : length;
37+
test(ptr, "\x07", want);
38+
test(ptr, "\x07\x03", want);
39+
test(ptr, "\x07\x85", want);
40+
test(ptr, "\x87\x85", length);
41+
}
42+
}
43+
44+
// We need space for the terminator.
45+
if (length == 0) continue;
46+
47+
// Ensure we never read past the end of memory.
48+
char *ptr = LIMIT - length;
49+
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
50+
memset(ptr, 5, length);
51+
52+
ptr[length - 1] = 7;
53+
test(ptr, "\x07", length - 1);
54+
test(ptr, "\x07\x03", length - 1);
55+
56+
ptr[length - 1] = 0;
57+
test(ptr, "\x07", length - 1);
58+
test(ptr, "\x07\x03", length - 1);
59+
}
60+
61+
return 0;
62+
}

test/src/misc/strspn.c

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//! add-flags.py(LDFLAGS): -Wl,--stack-first -Wl,--initial-memory=327680
2+
3+
#include <__macro_PAGESIZE.h>
4+
#include <stddef.h>
5+
#include <stdio.h>
6+
#include <string.h>
7+
8+
// Calls strspn(ptr, set) and prints a diagnostic to stdout when the result
// differs from `want`. Prints nothing on success.
//
// ptr:  string to scan.
// set:  bytes to accept.
// want: expected return value.
void test(char *ptr, char *set, size_t want) {
  size_t got = strspn(ptr, set);
  if (got != want) {
    // %zu is the conversion for size_t, and %p expects a void pointer.
    printf("strspn(%p, \"%s\") = %zu, want %zu\n", (void *)ptr, set, got, want);
  }
}
14+
15+
int main(void) {
16+
char *const LIMIT = (char *)(__builtin_wasm_memory_size(0) * PAGESIZE);
17+
18+
for (ptrdiff_t length = 0; length < 64; length++) {
19+
for (ptrdiff_t alignment = 0; alignment < 24; alignment++) {
20+
for (ptrdiff_t pos = -2; pos < length + 2; pos++) {
21+
// Create a buffer with the given length, at a pointer with the given
22+
// alignment. Using the offset LIMIT - PAGESIZE - 8 means many buffers
23+
// will straddle a (Wasm, and likely OS) page boundary. Place the
24+
// character to find at every position in the buffer, including just
25+
// prior to it and after its end.
26+
char *ptr = LIMIT - PAGESIZE - 8 + alignment;
27+
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
28+
memset(ptr, 5, length);
29+
30+
// The first instance of the character is found.
31+
if (pos >= 0) ptr[pos + 2] = 7;
32+
ptr[pos] = 7;
33+
ptr[length] = 0;
34+
35+
// The character is found if it's within range.
36+
ptrdiff_t want = 0 <= pos && pos < length ? pos : length;
37+
test(ptr, "\x05", want);
38+
test(ptr, "\x05\x03", want);
39+
test(ptr, "\x05\x87", want);
40+
test(ptr, "\x05\x07", length);
41+
}
42+
}
43+
44+
// We need space for the terminator.
45+
if (length == 0) continue;
46+
47+
// Ensure we never read past the end of memory.
48+
char *ptr = LIMIT - length;
49+
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
50+
memset(ptr, 5, length);
51+
52+
ptr[length - 1] = 7;
53+
test(ptr, "\x05", length - 1);
54+
test(ptr, "\x05\x03", length - 1);
55+
56+
ptr[length - 1] = 0;
57+
test(ptr, "\x05", length - 1);
58+
test(ptr, "\x05\x03", length - 1);
59+
}
60+
61+
return 0;
62+
}

0 commit comments

Comments
 (0)