diff --git a/src/arm/64/satd.S b/src/arm/64/satd.S index 86b99b9b4c..8bcb11b9d0 100644 --- a/src/arm/64/satd.S +++ b/src/arm/64/satd.S @@ -589,7 +589,7 @@ function satd4x16_neon, export=1 #undef dst_stride endfunc -.macro load_rows n0, n1, n2, src, dst, src_stride, dst_stride, should_add=1 +.macro load_rows n0, n1, n2, src, dst, src_stride, dst_stride ldr d\n0, [\src] ldr d\n1, [\dst] @@ -600,10 +600,8 @@ endfunc usubl v\n1\().8h, v\n1\().8b, v\n2\().8b -.if \should_add != 0 add \src, \src, \src_stride, lsl 1 add \dst, \dst, \dst_stride, lsl 1 -.endif .endm .macro HADAMARD_8X8 @@ -693,9 +691,6 @@ endfunc // stage 3 sum add v0.4s, v0.4s, v4.4s addv s0, v0.4s - - fmov w0, s0 - normalize_8 .endm function satd8x8_neon, export=1 @@ -704,18 +699,85 @@ function satd8x8_neon, export=1 #define dst x2 #define dst_stride x3 + #define subtotal w9 + #define total w10 + #define w_ext x11 + #define w_bak w11 + #define width w12 + #define height w13 + + mov height, 8 + mov width, 8 + sxtw w_ext, width + mov total, wzr + +L(satd_8x8): load_rows 0, 1, 2, src, dst, src_stride, dst_stride load_rows 4, 5, 6, src, dst, src_stride, dst_stride load_rows 16, 17, 20, src, dst, src_stride, dst_stride - load_rows 18, 19, 22, src, dst, src_stride, dst_stride, 0 + load_rows 18, 19, 22, src, dst, src_stride, dst_stride HADAMARD_8X8 SUM_HADAMARD_8X8 + fmov subtotal, s0 + add total, subtotal, total + + sub src, src, src_stride, lsl 3 + sub dst, dst, dst_stride, lsl 3 + add src, src, #8 + add dst, dst, #8 + subs width, width, #8 + bne L(satd_8x8) + + sub src, src, w_ext + sub dst, dst, w_ext + add src, src, src_stride, lsl 3 + add dst, dst, dst_stride, lsl 3 + subs height, height, #8 + mov width, w_bak + bne L(satd_8x8) + + mov w0, total + normalize_8 ret #undef src #undef src_stride #undef dst #undef dst_stride + + #undef w_ext + #undef w_bak + #undef subtotal + #undef total + #undef height + #undef width +endfunc + +.macro satd_x8up width, height +function satd\width\()x\height\()_neon, export=1 + mov w13, \height + mov w12, \width + sxtw x11, w12 + mov w10, wzr + b L(satd_8x8) endfunc +.endm + +satd_x8up 8, 16 +satd_x8up 8, 32 +satd_x8up 16, 8 +satd_x8up 16, 16 +satd_x8up 16, 32 +satd_x8up 16, 64 +satd_x8up 32, 8 +satd_x8up 32, 16 +satd_x8up 32, 32 +satd_x8up 32, 64 +satd_x8up 64, 16 +satd_x8up 64, 32 +satd_x8up 64, 64 +satd_x8up 64, 128 +satd_x8up 128, 64 +satd_x8up 128, 128 diff --git a/src/asm/aarch64/dist.rs b/src/asm/aarch64/dist.rs index 2fcc40e686..dfd32c0da9 100644 --- a/src/asm/aarch64/dist.rs +++ b/src/asm/aarch64/dist.rs @@ -61,7 +61,23 @@ declare_asm_dist_fn![ (rav1e_satd4x16_neon, u8), (rav1e_satd8x4_neon, u8), (rav1e_satd8x8_neon, u8), - (rav1e_satd16x4_neon, u8) + (rav1e_satd8x16_neon, u8), + (rav1e_satd8x32_neon, u8), + (rav1e_satd16x4_neon, u8), + (rav1e_satd16x8_neon, u8), + (rav1e_satd16x16_neon, u8), + (rav1e_satd16x32_neon, u8), + (rav1e_satd16x64_neon, u8), + (rav1e_satd32x8_neon, u8), + (rav1e_satd32x16_neon, u8), + (rav1e_satd32x32_neon, u8), + (rav1e_satd32x64_neon, u8), + (rav1e_satd64x16_neon, u8), + (rav1e_satd64x32_neon, u8), + (rav1e_satd64x64_neon, u8), + (rav1e_satd64x128_neon, u8), + (rav1e_satd128x64_neon, u8), + (rav1e_satd128x128_neon, u8) ]; // BlockSize::BLOCK_SIZES.next_power_of_two(); @@ -109,151 +125,6 @@ pub fn get_sad( dist } -macro_rules! impl_satd_fn { - ($(($name: ident, $T: ident, $LOG_W:expr, $log_h:expr)),+) => ( - $( - #[no_mangle] - unsafe extern fn $name ( - src: *const $T, src_stride: isize, dst: *const $T, dst_stride: isize - ) -> u32 { - rav1e_satd8wx8h_neon::<$LOG_W>(src, src_stride, dst, dst_stride, $log_h) - } - )+ - ) -} - -impl_satd_fn![ - (rav1e_satd8x16_neon, u8, 0, 1), - (rav1e_satd8x32_neon, u8, 0, 2), - (rav1e_satd16x8_neon, u8, 1, 0), - (rav1e_satd16x16_neon, u8, 1, 1), - (rav1e_satd16x32_neon, u8, 1, 2), - (rav1e_satd16x64_neon, u8, 1, 3), - (rav1e_satd32x8_neon, u8, 2, 0), - (rav1e_satd32x16_neon, u8, 2, 1), - (rav1e_satd32x32_neon, u8, 2, 2), - (rav1e_satd32x64_neon, u8, 2, 3), - (rav1e_satd64x16_neon, u8, 3, 1), - (rav1e_satd64x32_neon, u8, 3, 2), - (rav1e_satd64x64_neon, u8, 3, 3), - (rav1e_satd64x128_neon, u8, 3, 4), - (rav1e_satd128x64_neon, u8, 4, 3), - (rav1e_satd128x128_neon, u8, 4, 4) -]; - -unsafe fn rav1e_satd8wx8h_neon( - mut src: *const u8, src_stride: isize, mut dst: *const u8, - dst_stride: isize, log_h: usize, -) -> u32 { - let mut sum = 0; - for _ in 0..(1 << log_h) { - let (mut src_off, mut dst_off) = (src, dst); - for _ in 0..(1 << LOG_W) { - sum += - rav1e_satd8x8_internal_neon(src_off, src_stride, dst_off, dst_stride); - src_off = src_off.add(8); - dst_off = dst_off.add(8); - } - src = src.offset(src_stride << 3); - dst = dst.offset(dst_stride << 3); - } - (sum + 4) >> 3 -} - -unsafe fn rav1e_satd8x8_internal_neon( - src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize, -) -> u32 { - use core::arch::aarch64::*; - - let load_row = |src: *const u8, dst: *const u8| -> int16x8_t { - vreinterpretq_s16_u16(vsubl_u8(vld1_u8(src), vld1_u8(dst))) - }; - let butterfly = |a: int16x8_t, b: int16x8_t| -> int16x8x2_t { - int16x8x2_t(vaddq_s16(a, b), vsubq_s16(a, b)) - }; - let zip1 = |v: int16x8x2_t| -> int16x8x2_t { - int16x8x2_t(vzip1q_s16(v.0, v.1), vzip2q_s16(v.0, v.1)) - }; - let zip2 = |v: int16x8x2_t| -> int16x8x2_t { - let v = - int32x4x2_t(vreinterpretq_s32_s16(v.0), vreinterpretq_s32_s16(v.1)); - let v = int32x4x2_t(vzip1q_s32(v.0, v.1), vzip2q_s32(v.0, v.1)); - int16x8x2_t(vreinterpretq_s16_s32(v.0), vreinterpretq_s16_s32(v.1)) - }; - let zip4 = |v: int16x8x2_t| -> int16x8x2_t { - let v = - int64x2x2_t(vreinterpretq_s64_s16(v.0), vreinterpretq_s64_s16(v.1)); - let v = int64x2x2_t(vzip1q_s64(v.0, v.1), vzip2q_s64(v.0, v.1)); - int16x8x2_t(vreinterpretq_s16_s64(v.0), vreinterpretq_s16_s64(v.1)) - }; - - let (src_stride2, dst_stride2) = (src_stride << 1, dst_stride << 1); - let int16x8x2_t(r0, r1) = zip1(butterfly( - load_row(src, dst), - load_row(src.offset(src_stride), dst.offset(dst_stride)), - )); - let (src, dst) = (src.offset(src_stride2), dst.offset(dst_stride2)); - let int16x8x2_t(r2, r3) = zip1(butterfly( - load_row(src, dst), - load_row(src.offset(src_stride), dst.offset(dst_stride)), - )); - let (src, dst) = (src.offset(src_stride2), dst.offset(dst_stride2)); - let int16x8x2_t(r4, r5) = zip1(butterfly( - load_row(src, dst), - load_row(src.offset(src_stride), dst.offset(dst_stride)), - )); - let (src, dst) = (src.offset(src_stride2), dst.offset(dst_stride2)); - let int16x8x2_t(r6, r7) = zip1(butterfly( - load_row(src, dst), - load_row(src.offset(src_stride), dst.offset(dst_stride)), - )); - - let int16x8x2_t(r0, r2) = zip2(butterfly(r0, r2)); - let int16x8x2_t(r1, r3) = zip2(butterfly(r1, r3)); - let int16x8x2_t(r4, r6) = zip2(butterfly(r4, r6)); - let int16x8x2_t(r5, r7) = zip2(butterfly(r5, r7)); - - let int16x8x2_t(r0, r4) = zip4(butterfly(r0, r4)); - let int16x8x2_t(r1, r5) = zip4(butterfly(r1, r5)); - let int16x8x2_t(r2, r6) = zip4(butterfly(r2, r6)); - let int16x8x2_t(r3, r7) = zip4(butterfly(r3, r7)); - - let int16x8x2_t(r0, r1) = butterfly(r0, r1); - let int16x8x2_t(r2, r3) = butterfly(r2, r3); - let int16x8x2_t(r4, r5) = butterfly(r4, r5); - let int16x8x2_t(r6, r7) = butterfly(r6, r7); - - let int16x8x2_t(r0, r2) = butterfly(r0, r2); - let int16x8x2_t(r1, r3) = butterfly(r1, r3); - let int16x8x2_t(r4, r6) = butterfly(r4, r6); - let int16x8x2_t(r5, r7) = butterfly(r5, r7); - - let int16x8x2_t(r0, r4) = butterfly(r0, r4); - let int16x8x2_t(r1, r5) = butterfly(r1, r5); - let int16x8x2_t(r2, r6) = butterfly(r2, r6); - let int16x8x2_t(r3, r7) = butterfly(r3, r7); - - let r0 = vabsq_s16(r0); - let r1 = vabsq_s16(r1); - let r2 = vabsq_s16(r2); - let r3 = vabsq_s16(r3); - let r4 = vabsq_s16(r4); - let r5 = vabsq_s16(r5); - let r6 = vabsq_s16(r6); - let r7 = vabsq_s16(r7); - - let (t0, t1) = (vmovl_s16(vget_low_s16(r0)), vmovl_s16(vget_high_s16(r0))); - let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r1)), vaddw_high_s16(t1, r1)); - let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r2)), vaddw_high_s16(t1, r2)); - let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r3)), vaddw_high_s16(t1, r3)); - let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r4)), vaddw_high_s16(t1, r4)); - let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r5)), vaddw_high_s16(t1, r5)); - let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r6)), vaddw_high_s16(t1, r6)); - let (t0, t1) = (vaddw_s16(t0, vget_low_s16(r7)), vaddw_high_s16(t1, r7)); - - vaddvq_s32(vaddq_s32(t0, t1)) as u32 -} - #[inline(always)] #[allow(clippy::let_and_return)] pub fn get_satd(