Skip to content

Commit

Permalink
arm64: satd: Complete 8 bpc NEON implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
barrbrain committed Nov 7, 2023
1 parent b498b44 commit a387c7d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 153 deletions.
76 changes: 69 additions & 7 deletions src/arm/64/satd.S
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ function satd4x16_neon, export=1
#undef dst_stride
endfunc

.macro load_rows n0, n1, n2, src, dst, src_stride, dst_stride, should_add=1
.macro load_rows n0, n1, n2, src, dst, src_stride, dst_stride
ldr d\n0, [\src]
ldr d\n1, [\dst]

Expand All @@ -600,10 +600,8 @@ endfunc

usubl v\n1\().8h, v\n1\().8b, v\n2\().8b

.if \should_add != 0
add \src, \src, \src_stride, lsl 1
add \dst, \dst, \dst_stride, lsl 1
.endif
.endm

.macro HADAMARD_8X8
Expand Down Expand Up @@ -693,9 +691,6 @@ endfunc
// stage 3 sum
add v0.4s, v0.4s, v4.4s
addv s0, v0.4s

fmov w0, s0
normalize_8
.endm

function satd8x8_neon, export=1
Expand All @@ -704,18 +699,85 @@ function satd8x8_neon, export=1
#define dst x2
#define dst_stride x3

#define subtotal w9
#define total w10
#define w_ext x11
#define w_bak w11
#define width w12
#define height w13

mov height, 8
mov width, 8
sxtw w_ext, width
mov total, wzr

L(satd_8x8):
load_rows 0, 1, 2, src, dst, src_stride, dst_stride
load_rows 4, 5, 6, src, dst, src_stride, dst_stride
load_rows 16, 17, 20, src, dst, src_stride, dst_stride
load_rows 18, 19, 22, src, dst, src_stride, dst_stride, 0
load_rows 18, 19, 22, src, dst, src_stride, dst_stride

HADAMARD_8X8

SUM_HADAMARD_8X8
fmov subtotal, s0
add total, subtotal, total

sub src, src, src_stride, lsl 3
sub dst, dst, dst_stride, lsl 3
add src, src, #8
add dst, dst, #8
subs width, width, #8
bne L(satd_8x8)

sub src, src, w_ext
sub dst, dst, w_ext
add src, src, src_stride, lsl 3
add dst, dst, dst_stride, lsl 3
subs height, height, #8
mov width, w_bak
bne L(satd_8x8)

mov w0, total
normalize_8
ret

#undef src
#undef src_stride
#undef dst
#undef dst_stride

#undef w_ext
#undef w_bak
#undef subtotal
#undef total
#undef height
#undef width
endfunc

.macro satd_x8up width, height
function satd\width\()x\height\()_neon, export=1
mov w13, \height
mov w12, \width
sxtw x11, w12
mov w10, wzr
b L(satd_8x8)
endfunc
.endm

satd_x8up 8, 16
satd_x8up 8, 32
satd_x8up 16, 8
satd_x8up 16, 16
satd_x8up 16, 32
satd_x8up 16, 64
satd_x8up 32, 8
satd_x8up 32, 16
satd_x8up 32, 32
satd_x8up 32, 64
satd_x8up 64, 16
satd_x8up 64, 32
satd_x8up 64, 64
satd_x8up 64, 128
satd_x8up 128, 64
satd_x8up 128, 128
163 changes: 17 additions & 146 deletions src/asm/aarch64/dist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,23 @@ declare_asm_dist_fn![
(rav1e_satd4x16_neon, u8),
(rav1e_satd8x4_neon, u8),
(rav1e_satd8x8_neon, u8),
(rav1e_satd16x4_neon, u8)
(rav1e_satd8x16_neon, u8),
(rav1e_satd8x32_neon, u8),
(rav1e_satd16x4_neon, u8),
(rav1e_satd16x8_neon, u8),
(rav1e_satd16x16_neon, u8),
(rav1e_satd16x32_neon, u8),
(rav1e_satd16x64_neon, u8),
(rav1e_satd32x8_neon, u8),
(rav1e_satd32x16_neon, u8),
(rav1e_satd32x32_neon, u8),
(rav1e_satd32x64_neon, u8),
(rav1e_satd64x16_neon, u8),
(rav1e_satd64x32_neon, u8),
(rav1e_satd64x64_neon, u8),
(rav1e_satd64x128_neon, u8),
(rav1e_satd128x64_neon, u8),
(rav1e_satd128x128_neon, u8)
];

// BlockSize::BLOCK_SIZES.next_power_of_two();
Expand Down Expand Up @@ -109,151 +125,6 @@ pub fn get_sad<T: Pixel>(
dist
}

macro_rules! impl_satd_fn {
($(($name: ident, $T: ident, $LOG_W:expr, $log_h:expr)),+) => (
$(
#[no_mangle]
unsafe extern fn $name (
src: *const $T, src_stride: isize, dst: *const $T, dst_stride: isize
) -> u32 {
rav1e_satd8wx8h_neon::<$LOG_W>(src, src_stride, dst, dst_stride, $log_h)
}
)+
)
}

impl_satd_fn![
(rav1e_satd8x16_neon, u8, 0, 1),
(rav1e_satd8x32_neon, u8, 0, 2),
(rav1e_satd16x8_neon, u8, 1, 0),
(rav1e_satd16x16_neon, u8, 1, 1),
(rav1e_satd16x32_neon, u8, 1, 2),
(rav1e_satd16x64_neon, u8, 1, 3),
(rav1e_satd32x8_neon, u8, 2, 0),
(rav1e_satd32x16_neon, u8, 2, 1),
(rav1e_satd32x32_neon, u8, 2, 2),
(rav1e_satd32x64_neon, u8, 2, 3),
(rav1e_satd64x16_neon, u8, 3, 1),
(rav1e_satd64x32_neon, u8, 3, 2),
(rav1e_satd64x64_neon, u8, 3, 3),
(rav1e_satd64x128_neon, u8, 3, 4),
(rav1e_satd128x64_neon, u8, 4, 3),
(rav1e_satd128x128_neon, u8, 4, 4)
];

unsafe fn rav1e_satd8wx8h_neon<const LOG_W: usize>(
mut src: *const u8, src_stride: isize, mut dst: *const u8,
dst_stride: isize, log_h: usize,
) -> u32 {
let mut sum = 0;
for _ in 0..(1 << log_h) {
let (mut src_off, mut dst_off) = (src, dst);
for _ in 0..(1 << LOG_W) {
sum +=
rav1e_satd8x8_internal_neon(src_off, src_stride, dst_off, dst_stride);
src_off = src_off.add(8);
dst_off = dst_off.add(8);
}
src = src.offset(src_stride << 3);
dst = dst.offset(dst_stride << 3);
}
(sum + 4) >> 3
}

unsafe fn rav1e_satd8x8_internal_neon(
src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
) -> u32 {
use core::arch::aarch64::*;

let load_row = |src: *const u8, dst: *const u8| -> int16x8_t {
vreinterpretq_s16_u16(vsubl_u8(vld1_u8(src), vld1_u8(dst)))
};
let butterfly = |a: int16x8_t, b: int16x8_t| -> int16x8x2_t {
int16x8x2_t(vaddq_s16(a, b), vsubq_s16(a, b))
};
let zip1 = |v: int16x8x2_t| -> int16x8x2_t {
int16x8x2_t(vzip1q_s16(v.0, v.1), vzip2q_s16(v.0, v.1))
};
let zip2 = |v: int16x8x2_t| -> int16x8x2_t {
let v =
int32x4x2_t(vreinterpretq_s32_s16(v.0), vreinterpretq_s32_s16(v.1));
let v = int32x4x2_t(vzip1q_s32(v.0, v.1), vzip2q_s32(v.0, v.1));
int16x8x2_t(vreinterpretq_s16_s32(v.0), vreinterpretq_s16_s32(v.1))
};
let zip4 = |v: int16x8x2_t| -> int16x8x2_t {
let v =
int64x2x2_t(vreinterpretq_s64_s16(v.0), vreinterpretq_s64_s16(v.1));
let v = int64x2x2_t(vzip1q_s64(v.0, v.1), vzip2q_s64(v.0, v.1));
int16x8x2_t(vreinterpretq_s16_s64(v.0), vreinterpretq_s16_s64(v.1))
};

let (src_stride2, dst_stride2) = (src_stride << 1, dst_stride << 1);
let int16x8x2_t(r0, r1) = zip1(butterfly(
load_row(src, dst),
load_row(src.offset(src_stride), dst.offset(dst_stride)),
));
let (src, dst) = (src.offset(src_stride2), dst.offset(dst_stride2));
let int16x8x2_t(r2, r3) = zip1(butterfly(
load_row(src, dst),
load_row(src.offset(src_stride), dst.offset(dst_stride)),
));
let (src, dst) = (src.offset(src_stride2), dst.offset(dst_stride2));
let int16x8x2_t(r4, r5) = zip1(butterfly(
load_row(src, dst),
load_row(src.offset(src_stride), dst.offset(dst_stride)),
));
let (src, dst) = (src.offset(src_stride2), dst.offset(dst_stride2));
let int16x8x2_t(r6, r7) = zip1(butterfly(
load_row(src, dst),
load_row(src.offset(src_stride), dst.offset(dst_stride)),
));

let int16x8x2_t(r0, r2) = zip2(butterfly(r0, r2));
let int16x8x2_t(r1, r3) = zip2(butterfly(r1, r3));
let int16x8x2_t(r4, r6) = zip2(butterfly(r4, r6));
let int16x8x2_t(r5, r7) = zip2(butterfly(r5, r7));

let int16x8x2_t(r0, r4) = zip4(butterfly(r0, r4));
let int16x8x2_t(r1, r5) = zip4(butterfly(r1, r5));
let int16x8x2_t(r2, r6) = zip4(butterfly(r2, r6));
let int16x8x2_t(r3, r7) = zip4(butterfly(r3, r7));

let int16x8x2_t(r0, r1) = butterfly(r0, r1);
let int16x8x2_t(r2, r3) = butterfly(r2, r3);
let int16x8x2_t(r4, r5) = butterfly(r4, r5);
let int16x8x2_t(r6, r7) = butterfly(r6, r7);

let int16x8x2_t(r0, r2) = butterfly(r0, r2);
let int16x8x2_t(r1, r3) = butterfly(r1, r3);
let int16x8x2_t(r4, r6) = butterfly(r4, r6);
let int16x8x2_t(r5, r7) = butterfly(r5, r7);

let int16x8x2_t(r0, r4) = butterfly(r0, r4);
let int16x8x2_t(r1, r5) = butterfly(r1, r5);
let int16x8x2_t(r2, r6) = butterfly(r2, r6);
let int16x8x2_t(r3, r7) = butterfly(r3, r7);

let r0 = vabsq_s16(r0);
let r1 = vabsq_s16(r1);
let r2 = vabsq_s16(r2);
let r3 = vabsq_s16(r3);
let r4 = vabsq_s16(r4);
let r5 = vabsq_s16(r5);
let r6 = vabsq_s16(r6);
let r7 = vabsq_s16(r7);

let (t0, t1) = (vmovl_s16(vget_low_s16(r0)), vmovl_s16(vget_high_s16(r0)));
let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r1)), vaddw_high_s16(t1, r1));
let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r2)), vaddw_high_s16(t1, r2));
let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r3)), vaddw_high_s16(t1, r3));
let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r4)), vaddw_high_s16(t1, r4));
let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r5)), vaddw_high_s16(t1, r5));
let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r6)), vaddw_high_s16(t1, r6));
let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r7)), vaddw_high_s16(t1, r7));

vaddvq_s32(vaddq_s32(t0, t1)) as u32
}

#[inline(always)]
#[allow(clippy::let_and_return)]
pub fn get_satd<T: Pixel>(
Expand Down

0 comments on commit a387c7d

Please sign in to comment.