Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fused adler #6

Merged
merged 2 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified silesia-small.tar.gz
Binary file not shown.
146 changes: 127 additions & 19 deletions src/adler32.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem::MaybeUninit;

pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 {
#[cfg(target_arch = "x86_64")]
if std::is_x86_feature_detected!("avx2") {
Expand All @@ -7,6 +9,25 @@ pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 {
adler32_rust(start_checksum, data)
}

pub fn adler32_fold_copy(start_checksum: u32, dst: &mut [MaybeUninit<u8>], src: &[u8]) -> u32 {
debug_assert!(dst.len() >= src.len(), "{} < {}", dst.len(), src.len());

#[cfg(target_arch = "x86_64")]
if std::is_x86_feature_detected!("avx2") {
return avx2::adler32_fold_copy_avx2(start_checksum, dst, src);
}

let adler = adler32_rust(start_checksum, src);
dst[..src.len()].copy_from_slice(slice_to_uninit(src));
adler
}

// when stable, use MaybeUninit::write_slice
fn slice_to_uninit(slice: &[u8]) -> &[MaybeUninit<u8>] {
// safety: &[T] and &[MaybeUninit<T>] have the same layout
unsafe { &*(slice as *const [u8] as *const [MaybeUninit<u8>]) }
}

// inefficient but correct, useful for testing
#[cfg(test)]
fn naive_adler32(start_checksum: u32, data: &[u8]) -> u32 {
Expand Down Expand Up @@ -124,6 +145,25 @@ fn adler32_len_16(mut adler: u32, buf: &[u8], mut sum2: u32) -> u32 {
adler | (sum2 << 16)
}

fn adler32_copy_len_16(
mut adler: u32,
dst: &mut [MaybeUninit<u8>],
src: &[u8],
mut sum2: u32,
) -> u32 {
for (source, destination) in src.iter().zip(dst.iter_mut()) {
let v = *source;
*destination = MaybeUninit::new(v);
adler += v as u32;
sum2 += adler;
}

adler %= BASE;
sum2 %= BASE; /* only added so many BASE's */
/* return recombined sums */
adler | (sum2 << 16)
}

fn adler32_len_64(mut adler: u32, buf: &[u8], mut sum2: u32) -> u32 {
const N: usize = if UNROLL_MORE { 16 } else { 8 };
let mut it = buf.chunks_exact(N);
Expand All @@ -145,8 +185,8 @@ mod avx2 {
use std::arch::x86_64::{
__m256i, _mm256_add_epi32, _mm256_castsi256_si128, _mm256_extracti128_si256,
_mm256_loadu_si256, _mm256_madd_epi16, _mm256_maddubs_epi16, _mm256_permutevar8x32_epi32,
_mm256_sad_epu8, _mm256_slli_epi32, _mm256_zextsi128_si256, _mm_add_epi32,
_mm_cvtsi128_si32, _mm_cvtsi32_si128, _mm_shuffle_epi32, _mm_unpackhi_epi64,
_mm256_sad_epu8, _mm256_slli_epi32, _mm256_storeu_si256, _mm256_zextsi128_si256,
_mm_add_epi32, _mm_cvtsi128_si32, _mm_cvtsi32_si128, _mm_shuffle_epi32, _mm_unpackhi_epi64,
};

const fn __m256i_literal(bytes: [u8; 32]) -> __m256i {
Expand Down Expand Up @@ -207,38 +247,74 @@ mod avx2 {
(array_slice, remainder)
}

pub fn adler32_avx2(adler: u32, buf: &[u8]) -> u32 {
if buf.is_empty() {
pub fn adler32_avx2(adler: u32, src: &[u8]) -> u32 {
adler32_avx2_help::<false>(adler, &mut [], src)
}

pub fn adler32_fold_copy_avx2(adler: u32, dst: &mut [MaybeUninit<u8>], src: &[u8]) -> u32 {
adler32_avx2_help::<true>(adler, dst, src)
}

fn adler32_avx2_help<const COPY: bool>(
adler: u32,
mut dst: &mut [MaybeUninit<u8>],
src: &[u8],
) -> u32 {
if src.is_empty() {
return adler;
}

let mut adler1 = (adler >> 16) & 0xffff;
let mut adler0 = adler & 0xffff;

if buf.len() < 16 {
return adler32_len_16(adler0, buf, adler1);
} else if buf.len() < 32 {
return adler32_len_64(adler0, buf, adler1);
if src.len() < 16 {
// use COPY const generic for this branch
if COPY {
return adler32_copy_len_16(adler0, dst, src, adler1);
} else {
return adler32_len_16(adler0, src, adler1);
}
} else if src.len() < 32 {
// use COPY const generic for this branch
if COPY {
return adler32_copy_len_16(adler0, dst, src, adler1);
} else {
return adler32_len_64(adler0, src, adler1);
}
}

// use largest step possible (without causing overflow)
const N: usize = (NMAX - (NMAX % 32)) as usize;
let (chunks, remainder) = slice_as_chunks::<_, N>(buf);
let (chunks, remainder) = slice_as_chunks::<_, N>(src);
for chunk in chunks {
(adler0, adler1) = unsafe { helper_32_bytes(adler0, adler1, chunk) };
(adler0, adler1) = unsafe { helper_32_bytes::<COPY>(adler0, adler1, dst, chunk) };
if COPY {
dst = &mut dst[N..];
}
}

// then take steps of 32 bytes
let (chunks, remainder) = slice_as_chunks::<_, 32>(remainder);
for chunk in chunks {
(adler0, adler1) = unsafe { helper_32_bytes(adler0, adler1, chunk) };
(adler0, adler1) = unsafe { helper_32_bytes::<COPY>(adler0, adler1, dst, chunk) };
if COPY {
dst = &mut dst[32..];
}
}

if !remainder.is_empty() {
if remainder.len() < 16 {
return adler32_len_16(adler0, remainder, adler1);
if COPY {
return adler32_copy_len_16(adler0, dst, remainder, adler1);
} else {
return adler32_len_16(adler0, remainder, adler1);
}
} else if remainder.len() < 32 {
return adler32_len_64(adler0, remainder, adler1);
if COPY {
return adler32_copy_len_16(adler0, dst, remainder, adler1);
} else {
return adler32_len_64(adler0, remainder, adler1);
}
} else {
unreachable!()
}
Expand All @@ -248,21 +324,33 @@ mod avx2 {
}

#[inline(always)]
unsafe fn helper_32_bytes(mut adler0: u32, mut adler1: u32, buf: &[u8]) -> (u32, u32) {
debug_assert_eq!(buf.len() % 32, 0);
unsafe fn helper_32_bytes<const COPY: bool>(
mut adler0: u32,
mut adler1: u32,
dst: &mut [MaybeUninit<u8>],
src: &[u8],
) -> (u32, u32) {
debug_assert_eq!(src.len() % 32, 0);

let mut vs1 = _mm256_zextsi128_si256(_mm_cvtsi32_si128(adler0 as i32));
let mut vs2 = _mm256_zextsi128_si256(_mm_cvtsi32_si128(adler1 as i32));

let mut vs1_0 = vs1;
let mut vs3 = ZERO;

for chunk in buf.chunks_exact(32) {
let vbuf = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
let mut out_chunks = dst.chunks_exact_mut(32);

let vs1_sad = _mm256_sad_epu8(vbuf, ZERO); // Sum of abs diff, resulting in 2 x int32's
for in_chunk in src.chunks_exact(32) {
let vbuf = _mm256_loadu_si256(in_chunk.as_ptr() as *const __m256i);

// TODO copy?
if COPY {
// println!("simd copy {:?}", in_chunk);
let out_chunk = out_chunks.next().unwrap();
_mm256_storeu_si256(out_chunk.as_mut_ptr() as *mut __m256i, vbuf);
// out_chunk.copy_from_slice(slice_to_uninit(in_chunk))
}

let vs1_sad = _mm256_sad_epu8(vbuf, ZERO); // Sum of abs diff, resulting in 2 x int32's

vs1 = _mm256_add_epi32(vs1, vs1_sad);
vs3 = _mm256_add_epi32(vs3, vs1_0);
Expand Down Expand Up @@ -302,6 +390,26 @@ mod avx2 {
assert_eq!(naive_adler32(1, &vec[..i]), adler32_avx2(1, &vec[..i]));
}
}

#[cfg(test)]
// TODO: This could use `MaybeUninit::slice_assume_init` when it is stable.
unsafe fn slice_assume_init(slice: &[MaybeUninit<u8>]) -> &[u8] {
&*(slice as *const [MaybeUninit<u8>] as *const [u8])
}

#[test]
fn fold_copy_copies() {
let src: Vec<_> = (0..128).map(|x| x as u8).collect();
let mut dst = [MaybeUninit::new(0); 128];

for (i, _) in src.iter().enumerate() {
dst.fill(MaybeUninit::new(0));

adler32_fold_copy_avx2(1, &mut dst[..i], &src[..i]);

assert_eq!(&src[..i], unsafe { slice_assume_init(&dst[..i]) })
}
}
}

#[cfg(test)]
Expand Down
22 changes: 8 additions & 14 deletions src/window.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::adler32::adler32;
use crate::adler32::{adler32, adler32_fold_copy};
use std::mem::MaybeUninit;

// translation guide:
Expand Down Expand Up @@ -89,10 +89,7 @@ impl<'a> Window<'a> {

if checksum != 0 {
checksum = adler32(checksum, non_window_slice);

checksum = adler32(checksum, window_slice);
self.buf
.copy_from_slice(unsafe { slice_to_uninit(window_slice) });
checksum = adler32_fold_copy(checksum, self.buf, window_slice);
} else {
self.buf
.copy_from_slice(unsafe { slice_to_uninit(window_slice) });
Expand All @@ -107,23 +104,20 @@ impl<'a> Window<'a> {
// written to the start of the window.
let (end_part, start_part) = slice.split_at(dist);

let end_part = unsafe { slice_to_uninit(end_part) };
let start_part = unsafe { slice_to_uninit(start_part) };

if checksum != 0 {
// TODO fuse memcpy and adler
checksum = adler32(checksum, slice);
self.buf[self.next..][..end_part.len()].copy_from_slice(end_part);
let dst = &mut self.buf[self.next..][..end_part.len()];
checksum = adler32_fold_copy(checksum, dst, end_part);
} else {
let end_part = unsafe { slice_to_uninit(end_part) };
self.buf[self.next..][..end_part.len()].copy_from_slice(end_part);
}

if !start_part.is_empty() {
if checksum != 0 {
// TODO fuse memcpy and adler
checksum = adler32(checksum, slice);
self.buf[..start_part.len()].copy_from_slice(start_part);
let dst = &mut self.buf[..start_part.len()];
checksum = adler32_fold_copy(checksum, dst, start_part);
} else {
let start_part = unsafe { slice_to_uninit(start_part) };
self.buf[..start_part.len()].copy_from_slice(start_part);
}

Expand Down