Skip to content
This repository has been archived by the owner on May 25, 2024. It is now read-only.

Commit

Permalink
fix: using arm neon in arm64 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Nambers committed May 8, 2024
1 parent 5b1bdd6 commit 5b3e396
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 6 deletions.
36 changes: 36 additions & 0 deletions src/str-utils.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "str-utils.h"
#ifdef __aarch64__
// NEON stand-in for SSE2 _mm_movemask_epi8 on one 128-bit register: bit i of
// the result is the most significant bit of byte i of `input`.
// Technique taken from
// https://stackoverflow.com/questions/11870910/sse-mm-movemask-epi8-equivalent-method-for-arm-neon
// Not verified on big-endian targets (big-endian NEON is very rare).
// Note: the 16 mask bits are returned through int16_t, so the value is
// negative whenever byte 15 has its top bit set.
int16_t _movemmask(uint8x16_t input)
{
    // Reduce every byte to just its sign bit (each byte becomes 0 or 1).
    const uint8x16_t sign_bits = vshrq_n_u8(input, 7);

    // View as 16-bit lanes and shift-right-accumulate by 7: the odd byte's bit
    // lands at bit 1 next to the even byte's bit 0, so the low byte of every
    // lane carries 2 mask bits (the high byte is leftover garbage).
    // vsri would also work here; the original note says it is slightly slower
    // on aarch64.
    const uint16x8_t lanes16 = vreinterpretq_u16_u8(sign_bits);
    const uint16x8_t two_bits = vsraq_n_u16(lanes16, lanes16, 7);

    // Same trick on 32-bit lanes (shift 14): low byte now carries 4 mask bits.
    const uint32x4_t lanes32 = vreinterpretq_u32_u16(two_bits);
    const uint32x4_t four_bits = vsraq_n_u32(lanes32, lanes32, 14);

    // And on 64-bit lanes (shift 28): low byte now carries 8 mask bits.
    const uint64x2_t lanes64 = vreinterpretq_u64_u32(four_bits);
    const uint64x2_t eight_bits = vsraq_n_u64(lanes64, lanes64, 28);

    // Byte 0 holds the mask for input bytes 0..7, byte 8 the mask for 8..15.
    const uint8x16_t packed = vreinterpretq_u8_u64(eight_bits);
    const int low_half = vgetq_lane_u8(packed, 0);
    const int high_half = vgetq_lane_u8(packed, 8);
    return low_half | (high_half << 8);
}
/**
 * Build a 32-bit byte-sign mask from a pair of 16-byte NEON vectors.
 *
 * Bit i (0..15) of the result is the top bit of byte i of v.val[0];
 * bit 16 + i is the top bit of byte i of v.val[1].
 *
 * @param v two 16-byte vectors; v.val[0] supplies the low half of the mask
 * @return the combined 32-bit mask
 */
uint32_t movemask(uint8x16x2_t v) {
    // _movemmask returns int16_t, so narrow to uint16_t before widening:
    // otherwise a set bit 15 sign-extends and floods the upper 16 bits.
    const uint32_t low = (uint16_t)_movemmask(v.val[0]);
    const uint32_t high = (uint16_t)_movemmask(v.val[1]);
    // The second vector belongs in bits 16..31, hence a LEFT shift.
    return low | (high << 16);
}
#endif
24 changes: 24 additions & 0 deletions src/str-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef STR_UTILS_H
#define STR_UTILS_H

#ifndef __aarch64__
#include <immintrin.h> // AVX

// 32-bit mask of the bytes of `a` that are greater than the matching byte of
// `b` under a signed compare AND are negative as signed bytes (top bit set).
// The expansion references `m256_zero`, which this header does not define; it
// must be visible wherever the macro is used.
#define mm256_greater8u_mask(a, b) _mm256_movemask_epi8(_mm256_and_si256(_mm256_cmpgt_epi8(a, b), _mm256_cmpgt_epi8(m256_zero, a)))

#else
#include <arm_neon.h>

// Small AVX2-on-NEON shim. A 256-bit register is modelled as two 128-bit NEON
// registers: val[0] holds bytes 0..15 and val[1] holds bytes 16..31.
#define __m256i uint8x16x2_t
#define __mmask32 uint32_t

// Collapses a register pair into a 32-bit mask; defined in str-utils.c.
uint32_t movemask(uint8x16x2_t vector);

// Load 32 contiguous bytes in memory order (vld2q_u8 would de-interleave the
// even and odd bytes, which breaks every position-dependent use of the mask).
#define _mm256_loadu_si256(c) vld1q_u8_x2((const uint8_t *)(c))

// Broadcasts. Both halves are built directly: vtrnq_u8 is a lane transpose,
// not a "pair" constructor, and would split a 16-bit pattern into its low
// bytes and its high bytes.
#define _mm256_set1_epi8(v) ((uint8x16x2_t){{vdupq_n_u8(v), vdupq_n_u8(v)}})
#define _mm256_set1_epi16(v) ((uint8x16x2_t){{vreinterpretq_u8_u16(vdupq_n_u16(v)), vreinterpretq_u8_u16(vdupq_n_u16(v))}})

// Byte-wise equality: each result byte is 0xFF where the inputs match, else 0.
#define _mm256_cmpeq_epi8(v, v2) ((uint8x16x2_t){{vceqq_u8((v).val[0], (v2).val[0]), vceqq_u8((v).val[1], (v2).val[1])}})

#define _mm256_movemask_epi8 movemask

// 32-bit mask of the bytes of `a` that are greater than the matching byte of
// `b`, using NEON's unsigned byte compare.
#define mm256_greater8u_mask(a, b) movemask((uint8x16x2_t){{vcgtq_u8((a).val[0], (b).val[0]), vcgtq_u8((a).val[1], (b).val[1])}})

#endif

#endif //STR_UTILS_H
11 changes: 5 additions & 6 deletions src/str.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

#include "simdutf_wrapper.h"

// #include <emmintrin.h> // SSE2
#include <immintrin.h> // AVX
#include "str-utils.h"
#include <stdint.h>
// Non-zero when the value decoded from the two-byte UTF-8 style pair (a, b) --
// low 5 bits of `a` followed by low 6 bits of `b` -- is above 0xFF, i.e. it
// does not fit in Latin-1. Arguments are parenthesized so that expressions
// such as `x | y` can be passed safely.
#define CHECK_NOT_LATIN1_2BYTES(a, b) (((((a) & 0x1F) << 6) | ((b) & 0x3F)) > 0xFF)

// https://web.archive.org/web/20151229003112/http://blogs.msdn.com/b/jeuge/archive/2005/06/08/hakmem-bit-count.aspx
// input must be uint
int BitCount(unsigned int u) {
unsigned int uCount = u - ((u >> 1) & 033333333333) - ((u >> 2) & 011111111111);
Expand All @@ -26,6 +26,7 @@ bool count_skipped(const char *buf, size_t max_len, size_t *skipped, size_t *len

for (; i + 32 < max_len; i += 32) {
__m256i batch = _mm256_loadu_si256((__m256i *) (buf + i));
// for avx512
// __mmask32 escape_result = _mm256_cmpeq_epi8_mask(batch, escape_mask);
// __mmask32 end_result = _mm256_cmpeq_epi8_mask(batch, end_mask);
// __mmask32 u_result = _mm256_cmpeq_epi8_mask(batch, u_mask);
Expand Down Expand Up @@ -235,14 +236,12 @@ int get_utf8_kind(const unsigned char *buf, size_t len) {
int kind = 1;
for (i = 0; i + 32 <= len; i += 32) {
__m256i in = _mm256_loadu_si256((const void *) (buf + i));
__m256i cond = _mm256_and_si256(_mm256_cmpgt_epi8(in, min_4bytes), _mm256_cmpgt_epi8(m256_zero, in));
if (_mm256_movemask_epi8(cond) != 0){
if (mm256_greater8u_mask(in, min_4bytes) != 0) {
// it is 4 bytes
return 4;
}
// if not all bytes are utf8 1bytes sequence in this batch
cond = _mm256_and_si256(_mm256_cmpgt_epi8(in, max_onebyte), _mm256_cmpgt_epi8(m256_zero, in));
if (_mm256_movemask_epi8(cond)) {
if (mm256_greater8u_mask(in, max_onebyte)) {
for (int j = 0; j < 32; j++) {
if (buf[i + j] & 0b10000000) {
if (buf[i + j] & 0b01000000) {
Expand Down

0 comments on commit 5b3e396

Please sign in to comment.