From 4e81bb972b3936f1a7a0d3bad356e0a33613d4c5 Mon Sep 17 00:00:00 2001
From: silvanshade
Date: Thu, 4 Jan 2024 08:26:25 -0700
Subject: [PATCH] Implement compress for NEON

---
 benches/bench.rs                    |   8 +
 c/blake3_c_rust_bindings/src/lib.rs |  15 ++
 c/blake3_dispatch.c                 |  12 ++
 c/blake3_impl.h                     |  10 +
 c/blake3_neon.c                     | 310 ++++++++++++++++++++++++++--
 src/ffi_neon.rs                     |  67 +++---
 src/platform.rs                     |  12 +-
 7 files changed, 395 insertions(+), 39 deletions(-)

diff --git a/benches/bench.rs b/benches/bench.rs
index 5efb9e6..a397432 100644
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -85,6 +85,14 @@ fn bench_single_compression_avx512(b: &mut Bencher) {
     }
 }
 
+#[bench]
+#[cfg(feature = "neon")]
+fn bench_single_compression_neon(b: &mut Bencher) {
+    if let Some(platform) = Platform::neon() {
+        bench_single_compression_fn(b, platform);
+    }
+}
+
 fn bench_many_chunks_fn(b: &mut Bencher, platform: Platform) {
     let degree = platform.simd_degree();
     let mut inputs = Vec::new();
diff --git a/c/blake3_c_rust_bindings/src/lib.rs b/c/blake3_c_rust_bindings/src/lib.rs
index 41e4938..3c487a3 100644
--- a/c/blake3_c_rust_bindings/src/lib.rs
+++ b/c/blake3_c_rust_bindings/src/lib.rs
@@ -289,6 +289,21 @@ pub mod ffi {
     pub mod neon {
         extern "C" {
             // NEON low level functions
+            pub fn blake3_compress_xof_neon(
+                cv: *const u32,
+                block: *const u8,
+                block_len: u8,
+                counter: u64,
+                flags: u8,
+                out: *mut u8,
+            );
+            pub fn blake3_compress_in_place_neon(
+                cv: *mut u32,
+                block: *const u8,
+                block_len: u8,
+                counter: u64,
+                flags: u8,
+            );
             pub fn blake3_hash_many_neon(
                 inputs: *const *const u8,
                 num_inputs: usize,
diff --git a/c/blake3_dispatch.c b/c/blake3_dispatch.c
index af6c3da..644d98c 100644
--- a/c/blake3_dispatch.c
+++ b/c/blake3_dispatch.c
@@ -188,6 +188,12 @@ void blake3_compress_in_place(uint32_t cv[8],
   }
 #endif
 #endif
+
+#if BLAKE3_USE_NEON == 1
+  blake3_compress_in_place_neon(cv, block, block_len, counter, flags);
+  return;
+#endif
+
   blake3_compress_in_place_portable(cv, block, block_len, counter, flags);
 }
 
@@ -217,6 +223,12 @@ void blake3_compress_xof(const uint32_t cv[8],
   }
 #endif
 #endif
+
+#if BLAKE3_USE_NEON == 1
+  blake3_compress_xof_neon(cv, block, block_len, counter, flags, out);
+  return;
+#endif
+
   blake3_compress_xof_portable(cv, block, block_len, counter, flags, out);
 }
 
diff --git a/c/blake3_impl.h b/c/blake3_impl.h
index beab5cf..b4e6e4a 100644
--- a/c/blake3_impl.h
+++ b/c/blake3_impl.h
@@ -274,6 +274,16 @@ void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
 #endif
 
 #if BLAKE3_USE_NEON == 1
+void blake3_compress_in_place_neon(uint32_t cv[8],
+                                   const uint8_t block[BLAKE3_BLOCK_LEN],
+                                   uint8_t block_len, uint64_t counter,
+                                   uint8_t flags);
+
+void blake3_compress_xof_neon(const uint32_t cv[8],
+                              const uint8_t block[BLAKE3_BLOCK_LEN],
+                              uint8_t block_len, uint64_t counter,
+                              uint8_t flags, uint8_t out[64]);
+
 void blake3_hash_many_neon(const uint8_t *const *inputs, size_t num_inputs,
                            size_t blocks, const uint32_t key[8],
                            uint64_t counter, bool increment_counter,
diff --git a/c/blake3_neon.c b/c/blake3_neon.c
index 8a818fc..c3e1035 100644
--- a/c/blake3_neon.c
+++ b/c/blake3_neon.c
@@ -68,10 +68,263 @@ INLINE uint32x4_t rot7_128(uint32x4_t x) {
   return vsriq_n_u32(vshlq_n_u32(x, 32-7), x, 7);
 }
 
-// TODO: compress_neon
-
 // TODO: hash2_neon
 
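+// g1 and g2 below are the two halves of the G mixing function: g1 performs
+// the first add/xor/rotate sequence (rot16, then rot12) and g2 the second
+// (rot8, then rot7). Because the state lives in four row vectors, each call
+// mixes all four columns (or, between diagonalize() and undiagonalize(), all
+// four diagonals) at once.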
+INLINE void g1(uint32x4_t *row0, uint32x4_t *row1, uint32x4_t *row2,
+               uint32x4_t *row3, uint32x4_t m) {
+  *row0 = vaddq_u32(vaddq_u32(*row0, m), *row1);
+  *row3 = veorq_u32(*row3, *row0);
+  *row3 = rot16_128(*row3);
+  *row2 = vaddq_u32(*row2, *row3);
+  *row1 = veorq_u32(*row1, *row2);
+  *row1 = rot12_128(*row1);
+}
+
+INLINE void g2(uint32x4_t *row0, uint32x4_t *row1, uint32x4_t *row2,
+               uint32x4_t *row3, uint32x4_t m) {
+  *row0 = vaddq_u32(vaddq_u32(*row0, m), *row1);
+  *row3 = veorq_u32(*row3, *row0);
+  *row3 = rot8_128(*row3);
+  *row2 = vaddq_u32(*row2, *row3);
+  *row1 = veorq_u32(*row1, *row2);
+  *row1 = rot7_128(*row1);
+}
+
+INLINE void diagonalize(uint32x4_t *row0, uint32x4_t *row2, uint32x4_t *row3) {
+  *row0 = vextq_u32(*row0, *row0, 3);
+  *row3 = vextq_u32(*row3, *row3, 2);
+  *row2 = vextq_u32(*row2, *row2, 1);
+}
+
+INLINE void undiagonalize(uint32x4_t *row0, uint32x4_t *row2, uint32x4_t *row3) {
+  *row0 = vextq_u32(*row0, *row0, 1);
+  *row3 = vextq_u32(*row3, *row3, 2);
+  *row2 = vextq_u32(*row2, *row2, 3);
+}
+
+#define unpacklo_32(a, b) \
+  vzip1q_u32(a, b)
+
+#define unpackhi_32(a, b) \
+  vzip2q_u32(a, b)
+
+#define unpacklo_64(a, b) \
+  vreinterpretq_u32_u64(vzip1q_u64(vreinterpretq_u64_u32(a), vreinterpretq_u64_u32(b)))
+
+#define shuffle_128(a, m3, m2, m1, m0) \
+  (__builtin_shufflevector(a, a, m0, m1, m2, m3))
+
+#define shuffle_256(a, b, m3, m2, m1, m0) \
+  (__builtin_shufflevector(a, b, m0, m1, m2 + 4, m3 + 4))
+
+#define blend_16(a, b, mask)          \
+  (vreinterpretq_u32_u16(             \
+      __builtin_shufflevector(        \
+          vreinterpretq_u16_u32(a),   \
+          vreinterpretq_u16_u32(b),   \
+          0 + ((mask >> 0) & 1) * 8,  \
+          1 + ((mask >> 1) & 1) * 8,  \
+          2 + ((mask >> 2) & 1) * 8,  \
+          3 + ((mask >> 3) & 1) * 8,  \
+          4 + ((mask >> 4) & 1) * 8,  \
+          5 + ((mask >> 5) & 1) * 8,  \
+          6 + ((mask >> 6) & 1) * 8,  \
+          7 + ((mask >> 7) & 1) * 8   \
+      )))
+
+INLINE void compress_pre(uint32x4_t rows[4], const uint32_t cv[8],
+                         const uint8_t block[BLAKE3_BLOCK_LEN],
+                         uint8_t block_len, uint64_t counter, uint8_t flags) {
+  rows[0] = loadu_128((uint8_t *)&cv[0]);
+  rows[1] = loadu_128((uint8_t *)&cv[4]);
+  rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
+  rows[3] = set4(counter_low(counter), counter_high(counter),
+                 (uint32_t)block_len, (uint32_t)flags);
+
+  uint32x4_t m0 = loadu_128(&block[sizeof(uint32x4_t) * 0]);
+  uint32x4_t m1 = loadu_128(&block[sizeof(uint32x4_t) * 1]);
+  uint32x4_t m2 = loadu_128(&block[sizeof(uint32x4_t) * 2]);
+  uint32x4_t m3 = loadu_128(&block[sizeof(uint32x4_t) * 3]);
+
+  uint32x4_t t0, t1, t2, t3, tt;
+
+  // Round 1. The first round permutes the message words from the original
+  // input order, into the groups that get mixed in parallel.
+  t0 = shuffle_256(m0, m1, 2, 0, 2, 0); //  6  4  2  0
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
+  t1 = shuffle_256(m0, m1, 3, 1, 3, 1); //  7  5  3  1
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
+  diagonalize(&rows[0], &rows[2], &rows[3]);
+  t2 = shuffle_256(m2, m3, 2, 0, 2, 0); // 14 12 10  8
+  t2 = shuffle_128(t2, 2, 1, 0, 3);     // 12 10  8 14
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
+  t3 = shuffle_256(m2, m3, 3, 1, 3, 1); // 15 13 11  9
+  t3 = vextq_u32(t3, t3, 3);            // 13 11  9 15
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
+  undiagonalize(&rows[0], &rows[2], &rows[3]);
+  m0 = t0;
+  m1 = t1;
+  m2 = t2;
+  m3 = t3;
+
+  // Round 2. This round and all following rounds apply a fixed permutation
+  // to the message words from the round before.
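+  // That fixed permutation takes word i of the new message from word p[i] of
+  // the previous one, where p = 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9,
+  // 14, 15, 8 (MSG_SCHEDULE in blake3_impl.h is this permutation iterated
+  // once per round). The shuffles and blends below apply it in registers
+  // rather than indexing MSG_SCHEDULE the way the portable code does.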
+  t0 = shuffle_256(m0, m1, 3, 1, 1, 2);
+  t0 = vextq_u32(t0, t0, 1);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
+  t1 = shuffle_256(m2, m3, 3, 3, 2, 2);
+  tt = shuffle_128(m0, 0, 0, 3, 3);
+  t1 = blend_16(tt, t1, 0xCC);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
+  diagonalize(&rows[0], &rows[2], &rows[3]);
+  t2 = unpacklo_64(m3, m1);
+  tt = blend_16(t2, m2, 0xC0);
+  t2 = shuffle_128(tt, 1, 3, 2, 0);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
+  t3 = unpackhi_32(m1, m3);
+  tt = unpacklo_32(m2, t3);
+  t3 = shuffle_128(tt, 0, 1, 3, 2);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
+  undiagonalize(&rows[0], &rows[2], &rows[3]);
+  m0 = t0;
+  m1 = t1;
+  m2 = t2;
+  m3 = t3;
+
+  // Round 3
+  t0 = shuffle_256(m0, m1, 3, 1, 1, 2);
+  t0 = vextq_u32(t0, t0, 1);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
+  t1 = shuffle_256(m2, m3, 3, 3, 2, 2);
+  tt = shuffle_128(m0, 0, 0, 3, 3);
+  t1 = blend_16(tt, t1, 0xCC);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
+  diagonalize(&rows[0], &rows[2], &rows[3]);
+  t2 = unpacklo_64(m3, m1);
+  tt = blend_16(t2, m2, 0xC0);
+  t2 = shuffle_128(tt, 1, 3, 2, 0);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
+  t3 = unpackhi_32(m1, m3);
+  tt = unpacklo_32(m2, t3);
+  t3 = shuffle_128(tt, 0, 1, 3, 2);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
+  undiagonalize(&rows[0], &rows[2], &rows[3]);
+  m0 = t0;
+  m1 = t1;
+  m2 = t2;
+  m3 = t3;
+
+  // Round 4
+  t0 = shuffle_256(m0, m1, 3, 1, 1, 2);
+  t0 = vextq_u32(t0, t0, 1);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
+  t1 = shuffle_256(m2, m3, 3, 3, 2, 2);
+  tt = shuffle_128(m0, 0, 0, 3, 3);
+  t1 = blend_16(tt, t1, 0xCC);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
+  diagonalize(&rows[0], &rows[2], &rows[3]);
+  t2 = unpacklo_64(m3, m1);
+  tt = blend_16(t2, m2, 0xC0);
+  t2 = shuffle_128(tt, 1, 3, 2, 0);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
+  t3 = unpackhi_32(m1, m3);
+  tt = unpacklo_32(m2, t3);
+  t3 = shuffle_128(tt, 0, 1, 3, 2);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
+  undiagonalize(&rows[0], &rows[2], &rows[3]);
+  m0 = t0;
+  m1 = t1;
+  m2 = t2;
+  m3 = t3;
+
+  // Round 5
+  t0 = shuffle_256(m0, m1, 3, 1, 1, 2);
+  t0 = vextq_u32(t0, t0, 1);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
+  t1 = shuffle_256(m2, m3, 3, 3, 2, 2);
+  tt = shuffle_128(m0, 0, 0, 3, 3);
+  t1 = blend_16(tt, t1, 0xCC);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
+  diagonalize(&rows[0], &rows[2], &rows[3]);
+  t2 = unpacklo_64(m3, m1);
+  tt = blend_16(t2, m2, 0xC0);
+  t2 = shuffle_128(tt, 1, 3, 2, 0);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
+  t3 = unpackhi_32(m1, m3);
+  tt = unpacklo_32(m2, t3);
+  t3 = shuffle_128(tt, 0, 1, 3, 2);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
+  undiagonalize(&rows[0], &rows[2], &rows[3]);
+  m0 = t0;
+  m1 = t1;
+  m2 = t2;
+  m3 = t3;
+
+  // Round 6
+  t0 = shuffle_256(m0, m1, 3, 1, 1, 2);
+  t0 = vextq_u32(t0, t0, 1);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
+  t1 = shuffle_256(m2, m3, 3, 3, 2, 2);
+  tt = shuffle_128(m0, 0, 0, 3, 3);
+  t1 = blend_16(tt, t1, 0xCC);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
+  diagonalize(&rows[0], &rows[2], &rows[3]);
+  t2 = unpacklo_64(m3, m1);
+  tt = blend_16(t2, m2, 0xC0);
+  t2 = shuffle_128(tt, 1, 3, 2, 0);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
+  t3 = unpackhi_32(m1, m3);
+  tt = unpacklo_32(m2, t3);
+  t3 = shuffle_128(tt, 0, 1, 3, 2);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
+  undiagonalize(&rows[0], &rows[2], &rows[3]);
+  m0 = t0;
+  m1 = t1;
+  m2 = t2;
+  m3 = t3;
+
+  // Round 7
+  t0 = shuffle_256(m0, m1, 3, 1, 1, 2);
+  t0 = vextq_u32(t0, t0, 1);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
+  t1 = shuffle_256(m2, m3, 3, 3, 2, 2);
+  tt = shuffle_128(m0, 0, 0, 3, 3);
+  t1 = blend_16(tt, t1, 0xCC);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
+  diagonalize(&rows[0], &rows[2], &rows[3]);
+  t2 = unpacklo_64(m3, m1);
+  tt = blend_16(t2, m2, 0xC0);
+  t2 = shuffle_128(tt, 1, 3, 2, 0);
+  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
+  t3 = unpackhi_32(m1, m3);
+  tt = unpacklo_32(m2, t3);
+  t3 = shuffle_128(tt, 0, 1, 3, 2);
+  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
+  undiagonalize(&rows[0], &rows[2], &rows[3]);
+}
+
+void blake3_compress_in_place_neon(uint32_t cv[8],
+                                   const uint8_t block[BLAKE3_BLOCK_LEN],
+                                   uint8_t block_len, uint64_t counter,
+                                   uint8_t flags) {
+  uint32x4_t rows[4];
+  compress_pre(rows, cv, block, block_len, counter, flags);
+  storeu_128(veorq_u32(rows[0], rows[2]), (uint8_t *)&cv[0]);
+  storeu_128(veorq_u32(rows[1], rows[3]), (uint8_t *)&cv[4]);
+}
+
+void blake3_compress_xof_neon(const uint32_t cv[8],
+                              const uint8_t block[BLAKE3_BLOCK_LEN],
+                              uint8_t block_len, uint64_t counter,
+                              uint8_t flags, uint8_t out[64]) {
+  uint32x4_t rows[4];
+  compress_pre(rows, cv, block, block_len, counter, flags);
+  storeu_128(veorq_u32(rows[0], rows[2]), &out[0]);
+  storeu_128(veorq_u32(rows[1], rows[3]), &out[16]);
+  storeu_128(veorq_u32(rows[2], loadu_128((uint8_t *)&cv[0])), &out[32]);
+  storeu_128(veorq_u32(rows[3], loadu_128((uint8_t *)&cv[4])), &out[48]);
+}
+
 /*
  * ----------------------------------------------------------------------------
  * hash4_neon
@@ -234,6 +487,47 @@ INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
   transpose_vecs_128(&out[12]);
 }
 
+// NOTE: The version below avoids the explicit transposes by relying on the
+// interleaving from `vst4q_u32`, but it seems to make no difference, or might
+// even be a little slower.
+
+// INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
+//                                 size_t block_offset, uint32x4_t out[16]) {
+//   uint8x16x4_t l0 = vld1q_u8_x4(&inputs[0][block_offset]);
+//   uint8x16x4_t l1 = vld1q_u8_x4(&inputs[1][block_offset]);
+//   uint8x16x4_t l2 = vld1q_u8_x4(&inputs[2][block_offset]);
+//   uint8x16x4_t l3 = vld1q_u8_x4(&inputs[3][block_offset]);
+
+//   uint32x4x4_t s0 = {
+//     vreinterpretq_u32_u8(l0.val[0]),
+//     vreinterpretq_u32_u8(l1.val[0]),
+//     vreinterpretq_u32_u8(l2.val[0]),
+//     vreinterpretq_u32_u8(l3.val[0]),
+//   };
+//   uint32x4x4_t s1 = {
+//     vreinterpretq_u32_u8(l0.val[1]),
+//     vreinterpretq_u32_u8(l1.val[1]),
+//     vreinterpretq_u32_u8(l2.val[1]),
+//     vreinterpretq_u32_u8(l3.val[1]),
+//   };
+//   uint32x4x4_t s2 = {
+//     vreinterpretq_u32_u8(l0.val[2]),
+//     vreinterpretq_u32_u8(l1.val[2]),
+//     vreinterpretq_u32_u8(l2.val[2]),
+//     vreinterpretq_u32_u8(l3.val[2]),
+//   };
+//   uint32x4x4_t s3 = {
+//     vreinterpretq_u32_u8(l0.val[3]),
+//     vreinterpretq_u32_u8(l1.val[3]),
+//     vreinterpretq_u32_u8(l2.val[3]),
+//     vreinterpretq_u32_u8(l3.val[3]),
+//   };
+
+//   vst4q_u32((uint32_t *)&out[0], s0);
+//   vst4q_u32((uint32_t *)&out[4], s1);
+//   vst4q_u32((uint32_t *)&out[8], s2);
+//   vst4q_u32((uint32_t *)&out[12], s3);
+// }
+
 INLINE void load_counters4(uint64_t counter, bool increment_counter,
                            uint32x4_t *out_low, uint32x4_t *out_high) {
   uint64_t mask = (increment_counter ? ~0 : 0);
@@ -312,11 +606,6 @@ void blake3_hash4_neon(const uint8_t *const *inputs, size_t blocks,
  * ----------------------------------------------------------------------------
  */
 
-void blake3_compress_in_place_portable(uint32_t cv[8],
-                                       const uint8_t block[BLAKE3_BLOCK_LEN],
-                                       uint8_t block_len, uint64_t counter,
-                                       uint8_t flags);
-
 INLINE void hash_one_neon(const uint8_t *input, size_t blocks,
                           const uint32_t key[8], uint64_t counter,
                           uint8_t flags, uint8_t flags_start, uint8_t flags_end,
@@ -328,11 +617,8 @@ INLINE void hash_one_neon(const uint8_t *input, size_t blocks,
     if (blocks == 1) {
       block_flags |= flags_end;
     }
-    // TODO: Implement compress_neon. However note that according to
-    // https://github.com/BLAKE2/BLAKE2/commit/7965d3e6e1b4193438b8d3a656787587d2579227,
-    // compress_neon might not be any faster than compress_portable.
-    blake3_compress_in_place_portable(cv, input, BLAKE3_BLOCK_LEN, counter,
-                                      block_flags);
+    blake3_compress_in_place_neon(cv, input, BLAKE3_BLOCK_LEN, counter,
+                                  block_flags);
     input = &input[BLAKE3_BLOCK_LEN];
     blocks -= 1;
     block_flags = flags;
diff --git a/src/ffi_neon.rs b/src/ffi_neon.rs
index 54d07a4..8b400cc 100644
--- a/src/ffi_neon.rs
+++ b/src/ffi_neon.rs
@@ -1,5 +1,34 @@
 use crate::{CVWords, IncrementCounter, BLOCK_LEN, OUT_LEN};
 
+pub unsafe fn compress_in_place(
+    cv: &mut CVWords,
+    block: &[u8; BLOCK_LEN],
+    block_len: u8,
+    counter: u64,
+    flags: u8,
+) {
+    ffi::blake3_compress_in_place_neon(cv.as_mut_ptr(), block.as_ptr(), block_len, counter, flags)
+}
+
+pub unsafe fn compress_xof(
+    cv: &CVWords,
+    block: &[u8; BLOCK_LEN],
+    block_len: u8,
+    counter: u64,
+    flags: u8,
+) -> [u8; 64] {
+    let mut out = [0u8; 64];
+    ffi::blake3_compress_xof_neon(
+        cv.as_ptr(),
+        block.as_ptr(),
+        block_len,
+        counter,
+        flags,
+        out.as_mut_ptr(),
+    );
+    out
+}
+
 // Unsafe because this may only be called on platforms supporting NEON.
 pub unsafe fn hash_many<const N: usize>(
     inputs: &[&[u8; N]],
@@ -29,31 +58,23 @@ pub unsafe fn hash_many<const N: usize>(
     )
 }
 
-// blake3_neon.c normally depends on blake3_portable.c, because the NEON
-// implementation only provides 4x compression, and it relies on the portable
-// implementation for 1x compression. However, we expose the portable Rust
-// implementation here instead, to avoid linking in unnecessary code.
-#[no_mangle]
-pub extern "C" fn blake3_compress_in_place_portable(
-    cv: *mut u32,
-    block: *const u8,
-    block_len: u8,
-    counter: u64,
-    flags: u8,
-) {
-    unsafe {
-        crate::portable::compress_in_place(
-            &mut *(cv as *mut [u32; 8]),
-            &*(block as *const [u8; 64]),
-            block_len,
-            counter,
-            flags,
-        )
-    }
-}
-
 pub mod ffi {
     extern "C" {
+        pub fn blake3_compress_in_place_neon(
+            cv: *mut u32,
+            block: *const u8,
+            block_len: u8,
+            counter: u64,
+            flags: u8,
+        );
+        pub fn blake3_compress_xof_neon(
+            cv: *const u32,
+            block: *const u8,
+            block_len: u8,
+            counter: u64,
+            flags: u8,
+            out: *mut u8,
+        );
         pub fn blake3_hash_many_neon(
             inputs: *const *const u8,
             num_inputs: usize,
diff --git a/src/platform.rs b/src/platform.rs
index ef910aa..9e76428 100644
--- a/src/platform.rs
+++ b/src/platform.rs
@@ -128,9 +128,11 @@ impl Platform {
             Platform::AVX512 => unsafe {
                 crate::avx512::compress_in_place(cv, block, block_len, counter, flags)
             },
-            // No NEON compress_in_place() implementation yet.
+            // Safe because detect() checked for platform support.
             #[cfg(blake3_neon)]
-            Platform::NEON => portable::compress_in_place(cv, block, block_len, counter, flags),
+            Platform::NEON => unsafe {
+                crate::neon::compress_in_place(cv, block, block_len, counter, flags)
+            },
         }
     }
 
@@ -160,9 +162,11 @@ impl Platform {
             Platform::AVX512 => unsafe {
                 crate::avx512::compress_xof(cv, block, block_len, counter, flags)
             },
-            // No NEON compress_xof() implementation yet.
+            // Safe because detect() checked for platform support.
             #[cfg(blake3_neon)]
-            Platform::NEON => portable::compress_xof(cv, block, block_len, counter, flags),
+            Platform::NEON => unsafe {
+                crate::neon::compress_xof(cv, block, block_len, counter, flags)
+            },
         }
     }
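
A quick way to sanity-check the new functions is to compare them against the portable implementations that blake3_impl.h already declares. The harness below is only a sketch and is not part of the patch: the file name and the test values are arbitrary, and it assumes a little-endian AArch64 build (where blake3_impl.h defaults BLAKE3_USE_NEON to 1), compiling just the two relevant translation units, e.g. with clang:

// check_neon_compress.c -- cross-check sketch, not part of this patch.
// Build on AArch64, for example:
//   cc -O2 -I c check_neon_compress.c c/blake3_portable.c c/blake3_neon.c -o check
#include <stdio.h>
#include <string.h>
#include "blake3_impl.h"

int main(void) {
  // Arbitrary block, counter, and flags for the comparison.
  uint8_t block[BLAKE3_BLOCK_LEN];
  for (size_t i = 0; i < BLAKE3_BLOCK_LEN; i++) {
    block[i] = (uint8_t)(31 * i + 7);
  }
  uint64_t counter = 0x0123456789abcdefULL;
  uint8_t flags = (uint8_t)(CHUNK_START | CHUNK_END | ROOT);

  // The in-place compression must leave the same chaining value as the
  // portable reference.
  uint32_t cv_portable[8], cv_neon[8];
  memcpy(cv_portable, IV, sizeof(cv_portable));
  memcpy(cv_neon, IV, sizeof(cv_neon));
  blake3_compress_in_place_portable(cv_portable, block, BLAKE3_BLOCK_LEN, counter, flags);
  blake3_compress_in_place_neon(cv_neon, block, BLAKE3_BLOCK_LEN, counter, flags);
  if (memcmp(cv_portable, cv_neon, sizeof(cv_neon)) != 0) {
    printf("compress_in_place mismatch\n");
    return 1;
  }

  // The full 64-byte XOF output must match as well.
  uint8_t out_portable[64], out_neon[64];
  blake3_compress_xof_portable(IV, block, BLAKE3_BLOCK_LEN, counter, flags, out_portable);
  blake3_compress_xof_neon(IV, block, BLAKE3_BLOCK_LEN, counter, flags, out_neon);
  if (memcmp(out_portable, out_neon, sizeof(out_neon)) != 0) {
    printf("compress_xof mismatch\n");
    return 1;
  }

  printf("NEON compress matches portable\n");
  return 0;
}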