diff --git a/rust/guts/Cargo.toml b/rust/guts/Cargo.toml new file mode 100644 index 0000000..ebcf77f --- /dev/null +++ b/rust/guts/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "blake3_guts" +version = "0.0.0" +authors = ["Jack O'Connor ", "Samuel Neves"] +description = "low-level building blocks for the BLAKE3 hash function" +repository = "https://github.com/BLAKE3-team/BLAKE3" +license = "CC0-1.0 OR Apache-2.0" +documentation = "https://docs.rs/blake3_guts" +readme = "readme.md" +edition = "2021" + +[dev-dependencies] +hex = "0.4.3" +reference_impl = { path = "../../reference_impl" } + +[features] +default = ["std"] +std = [] diff --git a/rust/guts/readme.md b/rust/guts/readme.md new file mode 100644 index 0000000..a1adbf1 --- /dev/null +++ b/rust/guts/readme.md @@ -0,0 +1,62 @@ +# The BLAKE3 Guts API + +## Introduction + +This crate contains low-level, high-performance, platform-specific +implementations of the BLAKE3 compression function. This API is complicated and +unsafe, and this crate will never have a stable release. For the standard +BLAKE3 hash function, see the [`blake3`](https://crates.io/crates/blake3) +crate, which depends on this one. + +The most important ingredient in a high-performance implementation of BLAKE3 is +parallelism. The BLAKE3 tree structure lets us hash different parts of the tree +in parallel, and modern computers have a _lot_ of parallelism to offer. +Sometimes that means using multiple threads running on multiple cores, but +multithreading isn't appropriate for all applications, and it's not the usual +default for library APIs. More commonly, BLAKE3 implementations use SIMD +instructions ("Single Instruction Multiple Data") to improve the performance of +a single thread. When we do use multithreading, the performance benefits +multiply. + +The tricky thing about SIMD is that each instruction set works differently. +Instead of writing portable code once and letting the compiler do most of the +optimization work, we need to write platform-specific implementations, and +sometimes more than one per platform. We maintain *four* different +implementations on x86 alone (targeting SSE2, SSE4.1, AVX2, and AVX-512), in +addition to ARM NEON and the RISC-V vector extensions. In the future we might +add ARM SVE2. + +All of that means a lot of duplicated logic and maintenance. So while the main +goal of this API is high performance, it's also important to keep the API as +small and simple as possible. Higher level details like the "CV stack", input +buffering, and multithreading are handled by portable code in the main `blake3` +crate. These are just building blocks. + +## The private API + +This is the API that each platform reimplements. It's completely `unsafe`, +inputs and outputs are allowed to alias, and bounds checking is the caller's +responsibility. + +- `degree` +- `compress` +- `hash_chunks` +- `hash_parents` +- `xof` +- `xof_xor` +- `universal_hash` + +## The public API + +This is the API that this crate exposes to callers, i.e. to the main `blake3` +crate. It's a thin, portable layer on top of the private API above. The Rust +version of this API is memory-safe. 
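+
+For orientation, here is a rough sketch (an illustration only, not code taken
+from this crate or its tests) of how a caller might use these wrappers to hash
+an input of exactly two chunks. The real logic for arbitrary input lengths
+lives in the main `blake3` crate:
+
+```rust
+use blake3_guts::{TransposedVectors, BLOCK_LEN, CHUNK_LEN, IV_BYTES, PARENT, ROOT};
+
+fn hash_two_chunks(input: &[u8]) -> [u8; 32] {
+    assert_eq!(input.len(), 2 * CHUNK_LEN);
+    let imp = &blake3_guts::DETECTED_IMPL;
+    let mut cvs = TransposedVectors::new();
+    let (left, _) = imp.split_transposed_vectors(&mut cvs);
+    // Hash both chunks (chunk counters 0 and 1) into the left columns of `cvs`.
+    let num_cvs = imp.hash_chunks(input, &IV_BYTES, 0, 0, left);
+    debug_assert_eq!(num_cvs, 2);
+    // The two chunk CVs form a single parent node, which is also the root.
+    imp.compress(
+        &cvs.extract_parent_node(0),
+        BLOCK_LEN as u32,
+        &IV_BYTES,
+        0,
+        PARENT | ROOT,
+    )
+}
+```
+
+The public entry points: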
+ +- `degree` +- `compress` +- `hash_chunks` +- `hash_parents` +- `reduce_parents` +- `xof` +- `xof_xor` +- `universal_hash` diff --git a/rust/guts/src/lib.rs b/rust/guts/src/lib.rs new file mode 100644 index 0000000..67f7a05 --- /dev/null +++ b/rust/guts/src/lib.rs @@ -0,0 +1,956 @@ +use core::cmp; +use core::marker::PhantomData; +use core::mem; +use core::ptr; +use core::sync::atomic::{AtomicPtr, Ordering::Relaxed}; + +pub mod portable; + +#[cfg(test)] +mod test; + +pub const OUT_LEN: usize = 32; +pub const BLOCK_LEN: usize = 64; +pub const CHUNK_LEN: usize = 1024; +pub const WORD_LEN: usize = 4; +pub const UNIVERSAL_HASH_LEN: usize = 16; + +pub const CHUNK_START: u32 = 1 << 0; +pub const CHUNK_END: u32 = 1 << 1; +pub const PARENT: u32 = 1 << 2; +pub const ROOT: u32 = 1 << 3; +pub const KEYED_HASH: u32 = 1 << 4; +pub const DERIVE_KEY_CONTEXT: u32 = 1 << 5; +pub const DERIVE_KEY_MATERIAL: u32 = 1 << 6; + +pub const IV: CVWords = [ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +]; +pub const IV_BYTES: CVBytes = le_bytes_from_words_32(&IV); + +pub const MSG_SCHEDULE: [[usize; 16]; 7] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8], + [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1], + [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6], + [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4], + [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7], + [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13], +]; + +// never less than 2 +pub const MAX_SIMD_DEGREE: usize = 2; + +pub type CVBytes = [u8; 32]; +pub type CVWords = [u32; 8]; +pub type BlockBytes = [u8; 64]; +pub type BlockWords = [u32; 16]; + +pub static DETECTED_IMPL: Implementation = Implementation::new( + degree_init, + compress_init, + hash_chunks_init, + hash_parents_init, + xof_init, + xof_xor_init, + universal_hash_init, +); + +fn detect() -> Implementation { + portable::implementation() +} + +fn init_detected_impl() { + let detected = detect(); + + DETECTED_IMPL + .degree_ptr + .store(detected.degree_ptr.load(Relaxed), Relaxed); + DETECTED_IMPL + .compress_ptr + .store(detected.compress_ptr.load(Relaxed), Relaxed); + DETECTED_IMPL + .hash_chunks_ptr + .store(detected.hash_chunks_ptr.load(Relaxed), Relaxed); + DETECTED_IMPL + .hash_parents_ptr + .store(detected.hash_parents_ptr.load(Relaxed), Relaxed); + DETECTED_IMPL + .xof_ptr + .store(detected.xof_ptr.load(Relaxed), Relaxed); + DETECTED_IMPL + .xof_xor_ptr + .store(detected.xof_xor_ptr.load(Relaxed), Relaxed); + DETECTED_IMPL + .universal_hash_ptr + .store(detected.universal_hash_ptr.load(Relaxed), Relaxed); +} + +pub struct Implementation { + degree_ptr: AtomicPtr<()>, + compress_ptr: AtomicPtr<()>, + hash_chunks_ptr: AtomicPtr<()>, + hash_parents_ptr: AtomicPtr<()>, + xof_ptr: AtomicPtr<()>, + xof_xor_ptr: AtomicPtr<()>, + universal_hash_ptr: AtomicPtr<()>, +} + +impl Implementation { + const fn new( + degree_fn: DegreeFn, + compress_fn: CompressFn, + hash_chunks_fn: HashChunksFn, + hash_parents_fn: HashParentsFn, + xof_fn: XofFn, + xof_xor_fn: XofFn, + universal_hash_fn: UniversalHashFn, + ) -> Self { + Self { + degree_ptr: AtomicPtr::new(degree_fn as *mut ()), + compress_ptr: AtomicPtr::new(compress_fn as *mut ()), + hash_chunks_ptr: AtomicPtr::new(hash_chunks_fn as *mut ()), + hash_parents_ptr: AtomicPtr::new(hash_parents_fn as *mut ()), + xof_ptr: AtomicPtr::new(xof_fn as *mut ()), + xof_xor_ptr: AtomicPtr::new(xof_xor_fn as *mut ()), + 
universal_hash_ptr: AtomicPtr::new(universal_hash_fn as *mut ()), + } + } + + #[inline] + fn degree_fn(&self) -> DegreeFn { + unsafe { mem::transmute(self.degree_ptr.load(Relaxed)) } + } + + #[inline] + pub fn degree(&self) -> usize { + let degree = unsafe { self.degree_fn()() }; + debug_assert!(degree >= 2); + debug_assert!(degree <= MAX_SIMD_DEGREE); + debug_assert_eq!(1, degree.count_ones(), "power of 2"); + degree + } + + #[inline] + pub fn split_transposed_vectors<'v>( + &self, + vectors: &'v mut TransposedVectors, + ) -> (TransposedSplit<'v>, TransposedSplit<'v>) { + unsafe { vectors.split(self.degree()) } + } + + #[inline] + fn compress_fn(&self) -> CompressFn { + unsafe { mem::transmute(self.compress_ptr.load(Relaxed)) } + } + + #[inline] + pub fn compress( + &self, + block: &BlockBytes, + block_len: u32, + cv: &CVBytes, + counter: u64, + flags: u32, + ) -> CVBytes { + let mut out = [0u8; 32]; + unsafe { + self.compress_fn()(block, block_len, cv, counter, flags, &mut out); + } + out + } + + // The contract for HashChunksFn doesn't require the implementation to support single-chunk + // inputs. Instead we handle that case here by calling compress in a loop. + #[inline] + fn hash_one_chunk( + &self, + mut input: &[u8], + key: &CVBytes, + counter: u64, + mut flags: u32, + output: TransposedSplit, + ) { + debug_assert!(input.len() <= CHUNK_LEN); + let mut cv = *key; + flags |= CHUNK_START; + while input.len() > BLOCK_LEN { + cv = self.compress( + input[..BLOCK_LEN].try_into().unwrap(), + BLOCK_LEN as u32, + &cv, + counter, + flags, + ); + input = &input[BLOCK_LEN..]; + flags &= !CHUNK_START; + } + let mut final_block = [0u8; BLOCK_LEN]; + final_block[..input.len()].copy_from_slice(input); + cv = self.compress( + &final_block, + input.len() as u32, + &cv, + counter, + flags | CHUNK_END, + ); + unsafe { + write_transposed_cv(&words_from_le_bytes_32(&cv), output.ptr); + } + } + + #[inline] + fn hash_chunks_fn(&self) -> HashChunksFn { + unsafe { mem::transmute(self.hash_chunks_ptr.load(Relaxed)) } + } + + #[inline] + pub fn hash_chunks( + &self, + input: &[u8], + key: &CVBytes, + counter: u64, + flags: u32, + transposed_output: TransposedSplit, + ) -> usize { + debug_assert!(input.len() <= self.degree() * CHUNK_LEN); + if input.len() <= CHUNK_LEN { + // The underlying hash_chunks_fn isn't required to support this case. Instead we handle + // it by calling compress_fn in a loop. But note that we still don't support root + // finalization or the empty input here. + self.hash_one_chunk(input, key, counter, flags, transposed_output); + return 1; + } + // SAFETY: If the caller passes in more than MAX_SIMD_DEGREE * CHUNK_LEN bytes, silently + // ignore the remainder. This makes it impossible to write out of bounds in a properly + // constructed TransposedSplit. 
+ let len = cmp::min(input.len(), MAX_SIMD_DEGREE * CHUNK_LEN); + unsafe { + self.hash_chunks_fn()( + input.as_ptr(), + len, + key, + counter, + flags, + transposed_output.ptr, + ); + } + if input.len() % CHUNK_LEN == 0 { + input.len() / CHUNK_LEN + } else { + (input.len() / CHUNK_LEN) + 1 + } + } + + #[inline] + fn hash_parents_fn(&self) -> HashParentsFn { + unsafe { mem::transmute(self.hash_parents_ptr.load(Relaxed)) } + } + + #[inline] + pub fn hash_parents( + &self, + transposed_input: &TransposedVectors, + mut num_cvs: usize, + key: &CVBytes, + flags: u32, + transposed_output: TransposedSplit, + ) -> usize { + debug_assert!(num_cvs <= 2 * MAX_SIMD_DEGREE); + // SAFETY: Cap num_cvs at 2 * MAX_SIMD_DEGREE, to guarantee no out-of-bounds accesses. + num_cvs = cmp::min(num_cvs, 2 * MAX_SIMD_DEGREE); + let mut odd_cv = [0u32; 8]; + if num_cvs % 2 == 1 { + unsafe { + odd_cv = read_transposed_cv(transposed_input.as_ptr().add(num_cvs - 1)); + } + } + let num_parents = num_cvs / 2; + unsafe { + self.hash_parents_fn()( + transposed_input.as_ptr(), + num_parents, + key, + flags | PARENT, + transposed_output.ptr, + ); + } + if num_cvs % 2 == 1 { + unsafe { + write_transposed_cv(&odd_cv, transposed_output.ptr.add(num_parents)); + } + num_parents + 1 + } else { + num_parents + } + } + + #[inline] + pub fn reduce_parents( + &self, + transposed_in_out: &mut TransposedVectors, + mut num_cvs: usize, + key: &CVBytes, + flags: u32, + ) -> usize { + debug_assert!(num_cvs <= 2 * MAX_SIMD_DEGREE); + // SAFETY: Cap num_cvs at 2 * MAX_SIMD_DEGREE, to guarantee no out-of-bounds accesses. + num_cvs = cmp::min(num_cvs, 2 * MAX_SIMD_DEGREE); + let in_out_ptr = transposed_in_out.as_mut_ptr(); + let mut odd_cv = [0u32; 8]; + if num_cvs % 2 == 1 { + unsafe { + odd_cv = read_transposed_cv(in_out_ptr.add(num_cvs - 1)); + } + } + let num_parents = num_cvs / 2; + unsafe { + self.hash_parents_fn()(in_out_ptr, num_parents, key, flags | PARENT, in_out_ptr); + } + if num_cvs % 2 == 1 { + unsafe { + write_transposed_cv(&odd_cv, in_out_ptr.add(num_parents)); + } + num_parents + 1 + } else { + num_parents + } + } + + #[inline] + fn xof_fn(&self) -> XofFn { + unsafe { mem::transmute(self.xof_ptr.load(Relaxed)) } + } + + #[inline] + pub fn xof( + &self, + block: &BlockBytes, + block_len: u32, + cv: &CVBytes, + mut counter: u64, + flags: u32, + mut out: &mut [u8], + ) { + let degree = self.degree(); + let simd_len = degree * BLOCK_LEN; + while !out.is_empty() { + let take = cmp::min(simd_len, out.len()); + unsafe { + self.xof_fn()( + block, + block_len, + cv, + counter, + flags | ROOT, + out.as_mut_ptr(), + take, + ); + } + out = &mut out[take..]; + counter += degree as u64; + } + } + + #[inline] + fn xof_xor_fn(&self) -> XofFn { + unsafe { mem::transmute(self.xof_xor_ptr.load(Relaxed)) } + } + + #[inline] + pub fn xof_xor( + &self, + block: &BlockBytes, + block_len: u32, + cv: &CVBytes, + mut counter: u64, + flags: u32, + mut out: &mut [u8], + ) { + let degree = self.degree(); + let simd_len = degree * BLOCK_LEN; + while !out.is_empty() { + let take = cmp::min(simd_len, out.len()); + unsafe { + self.xof_xor_fn()( + block, + block_len, + cv, + counter, + flags | ROOT, + out.as_mut_ptr(), + take, + ); + } + out = &mut out[take..]; + counter += degree as u64; + } + } + + #[inline] + fn universal_hash_fn(&self) -> UniversalHashFn { + unsafe { mem::transmute(self.universal_hash_ptr.load(Relaxed)) } + } + + #[inline] + pub fn universal_hash(&self, mut input: &[u8], key: &CVBytes, mut counter: u64) -> [u8; 16] { + let degree = 
self.degree(); + let simd_len = degree * BLOCK_LEN; + let mut ret = [0u8; 16]; + while !input.is_empty() { + let take = cmp::min(simd_len, input.len()); + let mut output = [0u8; 16]; + unsafe { + self.universal_hash_fn()(input.as_ptr(), take, key, counter, &mut output); + } + input = &input[take..]; + counter += degree as u64; + for byte_index in 0..16 { + ret[byte_index] ^= output[byte_index]; + } + } + ret + } +} + +impl Clone for Implementation { + fn clone(&self) -> Self { + Self { + degree_ptr: AtomicPtr::new(self.degree_ptr.load(Relaxed)), + compress_ptr: AtomicPtr::new(self.compress_ptr.load(Relaxed)), + hash_chunks_ptr: AtomicPtr::new(self.hash_chunks_ptr.load(Relaxed)), + hash_parents_ptr: AtomicPtr::new(self.hash_parents_ptr.load(Relaxed)), + xof_ptr: AtomicPtr::new(self.xof_ptr.load(Relaxed)), + xof_xor_ptr: AtomicPtr::new(self.xof_xor_ptr.load(Relaxed)), + universal_hash_ptr: AtomicPtr::new(self.universal_hash_ptr.load(Relaxed)), + } + } +} + +// never less than 2 +type DegreeFn = unsafe extern "C" fn() -> usize; + +unsafe extern "C" fn degree_init() -> usize { + init_detected_impl(); + DETECTED_IMPL.degree_fn()() +} + +type CompressFn = unsafe extern "C" fn( + block: *const BlockBytes, // zero padded to 64 bytes + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut CVBytes, // may overlap the input +); + +unsafe extern "C" fn compress_init( + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut CVBytes, +) { + init_detected_impl(); + DETECTED_IMPL.compress_fn()(block, block_len, cv, counter, flags, out); +} + +type CompressXofFn = unsafe extern "C" fn( + block: *const BlockBytes, // zero padded to 64 bytes + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut BlockBytes, // may overlap the input +); + +type HashChunksFn = unsafe extern "C" fn( + input: *const u8, + input_len: usize, + key: *const CVBytes, + counter: u64, + flags: u32, + transposed_output: *mut u32, +); + +unsafe extern "C" fn hash_chunks_init( + input: *const u8, + input_len: usize, + key: *const CVBytes, + counter: u64, + flags: u32, + transposed_output: *mut u32, +) { + init_detected_impl(); + DETECTED_IMPL.hash_chunks_fn()(input, input_len, key, counter, flags, transposed_output); +} + +type HashParentsFn = unsafe extern "C" fn( + transposed_input: *const u32, + num_parents: usize, + key: *const CVBytes, + flags: u32, + transposed_output: *mut u32, // may overlap the input +); + +unsafe extern "C" fn hash_parents_init( + transposed_input: *const u32, + num_parents: usize, + key: *const CVBytes, + flags: u32, + transposed_output: *mut u32, +) { + init_detected_impl(); + DETECTED_IMPL.hash_parents_fn()(transposed_input, num_parents, key, flags, transposed_output); +} + +// This signature covers both xof() and xof_xor(). 
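+// `counter` here is in units of 64-byte output blocks. The safe `xof` and
+// `xof_xor` wrappers above request at most `degree * BLOCK_LEN` bytes per call
+// and advance the counter by `degree` between calls.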
+type XofFn = unsafe extern "C" fn( + block: *const BlockBytes, // zero padded to 64 bytes + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut u8, + out_len: usize, +); + +unsafe extern "C" fn xof_init( + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut u8, + out_len: usize, +) { + init_detected_impl(); + DETECTED_IMPL.xof_fn()(block, block_len, cv, counter, flags, out, out_len); +} + +unsafe extern "C" fn xof_xor_init( + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut u8, + out_len: usize, +) { + init_detected_impl(); + DETECTED_IMPL.xof_xor_fn()(block, block_len, cv, counter, flags, out, out_len); +} + +type UniversalHashFn = unsafe extern "C" fn( + input: *const u8, + input_len: usize, + key: *const CVBytes, + counter: u64, + out: *mut [u8; 16], +); + +unsafe extern "C" fn universal_hash_init( + input: *const u8, + input_len: usize, + key: *const CVBytes, + counter: u64, + out: *mut [u8; 16], +) { + init_detected_impl(); + DETECTED_IMPL.universal_hash_fn()(input, input_len, key, counter, out); +} + +// The implicit degree of this implementation is MAX_SIMD_DEGREE. +#[inline(always)] +unsafe fn hash_chunks_using_compress( + compress: CompressFn, + mut input: *const u8, + mut input_len: usize, + key: *const CVBytes, + mut counter: u64, + flags: u32, + mut transposed_output: *mut u32, +) { + debug_assert!(input_len > 0); + debug_assert!(input_len <= MAX_SIMD_DEGREE * CHUNK_LEN); + input_len = cmp::min(input_len, MAX_SIMD_DEGREE * CHUNK_LEN); + while input_len > 0 { + let mut chunk_len = cmp::min(input_len, CHUNK_LEN); + input_len -= chunk_len; + // We only use 8 words of the CV, but compress returns 16. + let mut cv = *key; + let cv_ptr: *mut CVBytes = &mut cv; + let mut chunk_flags = flags | CHUNK_START; + while chunk_len > BLOCK_LEN { + compress( + input as *const BlockBytes, + BLOCK_LEN as u32, + cv_ptr, + counter, + chunk_flags, + cv_ptr, + ); + input = input.add(BLOCK_LEN); + chunk_len -= BLOCK_LEN; + chunk_flags &= !CHUNK_START; + } + let mut last_block = [0u8; BLOCK_LEN]; + ptr::copy_nonoverlapping(input, last_block.as_mut_ptr(), chunk_len); + input = input.add(chunk_len); + compress( + &last_block, + chunk_len as u32, + cv_ptr, + counter, + chunk_flags | CHUNK_END, + cv_ptr, + ); + let cv_words = words_from_le_bytes_32(&cv); + for word_index in 0..8 { + transposed_output + .add(word_index * TRANSPOSED_STRIDE) + .write(cv_words[word_index]); + } + transposed_output = transposed_output.add(1); + counter += 1; + } +} + +// The implicit degree of this implementation is MAX_SIMD_DEGREE. 
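+// Parent i reads its left child CV from transposed column 2*i and its right
+// child CV from column 2*i + 1, and writes the parent CV to output column i.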
+#[inline(always)] +unsafe fn hash_parents_using_compress( + compress: CompressFn, + mut transposed_input: *const u32, + mut num_parents: usize, + key: *const CVBytes, + flags: u32, + mut transposed_output: *mut u32, // may overlap the input +) { + debug_assert!(num_parents > 0); + debug_assert!(num_parents <= MAX_SIMD_DEGREE); + while num_parents > 0 { + let mut block_bytes = [0u8; 64]; + for word_index in 0..8 { + let left_child_word = transposed_input.add(word_index * TRANSPOSED_STRIDE).read(); + block_bytes[WORD_LEN * word_index..][..WORD_LEN] + .copy_from_slice(&left_child_word.to_le_bytes()); + let right_child_word = transposed_input + .add(word_index * TRANSPOSED_STRIDE + 1) + .read(); + block_bytes[WORD_LEN * (word_index + 8)..][..WORD_LEN] + .copy_from_slice(&right_child_word.to_le_bytes()); + } + let mut cv = [0u8; 32]; + compress(&block_bytes, BLOCK_LEN as u32, key, 0, flags, &mut cv); + let cv_words = words_from_le_bytes_32(&cv); + for word_index in 0..8 { + transposed_output + .add(word_index * TRANSPOSED_STRIDE) + .write(cv_words[word_index]); + } + transposed_input = transposed_input.add(2); + transposed_output = transposed_output.add(1); + num_parents -= 1; + } +} + +#[inline(always)] +unsafe fn xof_using_compress_xof( + compress_xof: CompressXofFn, + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + mut counter: u64, + flags: u32, + mut out: *mut u8, + mut out_len: usize, +) { + debug_assert!(out_len <= MAX_SIMD_DEGREE * BLOCK_LEN); + while out_len > 0 { + let mut block_output = [0u8; 64]; + compress_xof(block, block_len, cv, counter, flags, &mut block_output); + let take = cmp::min(out_len, BLOCK_LEN); + ptr::copy_nonoverlapping(block_output.as_ptr(), out, take); + out = out.add(take); + out_len -= take; + counter += 1; + } +} + +#[inline(always)] +unsafe fn xof_xor_using_compress_xof( + compress_xof: CompressXofFn, + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + mut counter: u64, + flags: u32, + mut out: *mut u8, + mut out_len: usize, +) { + debug_assert!(out_len <= MAX_SIMD_DEGREE * BLOCK_LEN); + while out_len > 0 { + let mut block_output = [0u8; 64]; + compress_xof(block, block_len, cv, counter, flags, &mut block_output); + let take = cmp::min(out_len, BLOCK_LEN); + for i in 0..take { + *out.add(i) ^= block_output[i]; + } + out = out.add(take); + out_len -= take; + counter += 1; + } +} + +#[inline(always)] +unsafe fn universal_hash_using_compress( + compress: CompressFn, + mut input: *const u8, + mut input_len: usize, + key: *const CVBytes, + mut counter: u64, + out: *mut [u8; 16], +) { + let flags = KEYED_HASH | CHUNK_START | CHUNK_END | ROOT; + let mut result = [0u8; 16]; + while input_len > 0 { + let block_len = cmp::min(input_len, BLOCK_LEN); + let mut block = [0u8; BLOCK_LEN]; + ptr::copy_nonoverlapping(input, block.as_mut_ptr(), block_len); + let mut block_output = [0u8; 32]; + compress( + &block, + block_len as u32, + key, + counter, + flags, + &mut block_output, + ); + for i in 0..16 { + result[i] ^= block_output[i]; + } + input = input.add(block_len); + input_len -= block_len; + counter += 1; + } + *out = result; +} + +// this is in units of *words*, for pointer operations on *const/*mut u32 +const TRANSPOSED_STRIDE: usize = 2 * MAX_SIMD_DEGREE; + +#[cfg_attr(any(target_arch = "x86", target_arch = "x86_64"), repr(C, align(64)))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TransposedVectors([[u32; 2 * MAX_SIMD_DEGREE]; 8]); + +impl TransposedVectors { + pub fn new() -> Self { + Self([[0; 2 * MAX_SIMD_DEGREE]; 8]) 
+ } + + pub fn extract_cv(&self, cv_index: usize) -> CVBytes { + let mut words = [0u32; 8]; + for word_index in 0..8 { + words[word_index] = self.0[word_index][cv_index]; + } + le_bytes_from_words_32(&words) + } + + pub fn extract_parent_node(&self, parent_index: usize) -> BlockBytes { + let mut bytes = [0u8; 64]; + bytes[..32].copy_from_slice(&self.extract_cv(parent_index / 2)); + bytes[32..].copy_from_slice(&self.extract_cv(parent_index / 2 + 1)); + bytes + } + + fn as_ptr(&self) -> *const u32 { + self.0[0].as_ptr() + } + + fn as_mut_ptr(&mut self) -> *mut u32 { + self.0[0].as_mut_ptr() + } + + // SAFETY: This function is just pointer arithmetic, but callers assume that it's safe (not + // necessarily correct) to write up to `degree` words to either side of the split, possibly + // from different threads. + unsafe fn split(&mut self, degree: usize) -> (TransposedSplit, TransposedSplit) { + debug_assert!(degree > 0); + debug_assert!(degree <= MAX_SIMD_DEGREE); + debug_assert_eq!(degree.count_ones(), 1, "power of 2"); + let ptr = self.as_mut_ptr(); + let left = TransposedSplit { + ptr, + phantom_data: PhantomData, + }; + let right = TransposedSplit { + ptr: ptr.wrapping_add(degree), + phantom_data: PhantomData, + }; + (left, right) + } +} + +pub struct TransposedSplit<'vectors> { + ptr: *mut u32, + phantom_data: PhantomData<&'vectors mut u32>, +} + +unsafe impl<'vectors> Send for TransposedSplit<'vectors> {} +unsafe impl<'vectors> Sync for TransposedSplit<'vectors> {} + +unsafe fn read_transposed_cv(src: *const u32) -> CVWords { + let mut cv = [0u32; 8]; + for word_index in 0..8 { + let offset_words = word_index * TRANSPOSED_STRIDE; + cv[word_index] = src.add(offset_words).read(); + } + cv +} + +unsafe fn write_transposed_cv(cv: &CVWords, dest: *mut u32) { + for word_index in 0..8 { + let offset_words = word_index * TRANSPOSED_STRIDE; + dest.add(offset_words).write(cv[word_index]); + } +} + +#[inline(always)] +pub const fn le_bytes_from_words_32(words: &CVWords) -> CVBytes { + let mut bytes = [0u8; 32]; + // This loop is super verbose because currently that's what it takes to be const. + let mut word_index = 0; + while word_index < bytes.len() / WORD_LEN { + let word_bytes = words[word_index].to_le_bytes(); + let mut byte_index = 0; + while byte_index < WORD_LEN { + bytes[word_index * WORD_LEN + byte_index] = word_bytes[byte_index]; + byte_index += 1; + } + word_index += 1; + } + bytes +} + +#[inline(always)] +pub const fn le_bytes_from_words_64(words: &BlockWords) -> BlockBytes { + let mut bytes = [0u8; 64]; + // This loop is super verbose because currently that's what it takes to be const. + let mut word_index = 0; + while word_index < bytes.len() / WORD_LEN { + let word_bytes = words[word_index].to_le_bytes(); + let mut byte_index = 0; + while byte_index < WORD_LEN { + bytes[word_index * WORD_LEN + byte_index] = word_bytes[byte_index]; + byte_index += 1; + } + word_index += 1; + } + bytes +} + +#[inline(always)] +pub const fn words_from_le_bytes_32(bytes: &CVBytes) -> CVWords { + let mut words = [0u32; 8]; + // This loop is super verbose because currently that's what it takes to be const. 
+ let mut word_index = 0; + while word_index < words.len() { + let mut word_bytes = [0u8; WORD_LEN]; + let mut byte_index = 0; + while byte_index < WORD_LEN { + word_bytes[byte_index] = bytes[word_index * WORD_LEN + byte_index]; + byte_index += 1; + } + words[word_index] = u32::from_le_bytes(word_bytes); + word_index += 1; + } + words +} + +#[inline(always)] +pub const fn words_from_le_bytes_64(bytes: &BlockBytes) -> BlockWords { + let mut words = [0u32; 16]; + // This loop is super verbose because currently that's what it takes to be const. + let mut word_index = 0; + while word_index < words.len() { + let mut word_bytes = [0u8; WORD_LEN]; + let mut byte_index = 0; + while byte_index < WORD_LEN { + word_bytes[byte_index] = bytes[word_index * WORD_LEN + byte_index]; + byte_index += 1; + } + words[word_index] = u32::from_le_bytes(word_bytes); + word_index += 1; + } + words +} + +#[test] +fn test_byte_word_round_trips() { + let cv = *b"This is 32 LE bytes/eight words."; + assert_eq!(cv, le_bytes_from_words_32(&words_from_le_bytes_32(&cv))); + let block = *b"This is sixty-four little-endian bytes, or sixteen 32-bit words."; + assert_eq!( + block, + le_bytes_from_words_64(&words_from_le_bytes_64(&block)), + ); +} + +// The largest power of two less than or equal to `n`, used for left_len() +// immediately below, and also directly in Hasher::update(). +pub fn largest_power_of_two_leq(n: usize) -> usize { + ((n / 2) + 1).next_power_of_two() +} + +#[test] +fn test_largest_power_of_two_leq() { + let input_output = &[ + // The zero case is nonsensical, but it does work. + (0, 1), + (1, 1), + (2, 2), + (3, 2), + (4, 4), + (5, 4), + (6, 4), + (7, 4), + (8, 8), + // the largest possible usize + (usize::MAX, (usize::MAX >> 1) + 1), + ]; + for &(input, output) in input_output { + assert_eq!( + output, + crate::largest_power_of_two_leq(input), + "wrong output for n={}", + input + ); + } +} + +// Given some input larger than one chunk, return the number of bytes that +// should go in the left subtree. This is the largest power-of-2 number of +// chunks that leaves at least 1 byte for the right subtree. +pub fn left_len(content_len: usize) -> usize { + debug_assert!(content_len > CHUNK_LEN); + // Subtract 1 to reserve at least one byte for the right side. 
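+    // For example: content_len = 4 * CHUNK_LEN + 1 gives full_chunks = 4 and
+    // left_len = 4 * CHUNK_LEN, leaving exactly one byte for the right side,
+    // while content_len = 4 * CHUNK_LEN gives full_chunks = 3 and
+    // left_len = 2 * CHUNK_LEN.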
+ let full_chunks = (content_len - 1) / CHUNK_LEN; + largest_power_of_two_leq(full_chunks) * CHUNK_LEN +} + +#[test] +fn test_left_len() { + let input_output = &[ + (CHUNK_LEN + 1, CHUNK_LEN), + (2 * CHUNK_LEN - 1, CHUNK_LEN), + (2 * CHUNK_LEN, CHUNK_LEN), + (2 * CHUNK_LEN + 1, 2 * CHUNK_LEN), + (4 * CHUNK_LEN - 1, 2 * CHUNK_LEN), + (4 * CHUNK_LEN, 2 * CHUNK_LEN), + (4 * CHUNK_LEN + 1, 4 * CHUNK_LEN), + ]; + for &(input, output) in input_output { + assert_eq!(left_len(input), output); + } +} diff --git a/rust/guts/src/portable.rs b/rust/guts/src/portable.rs new file mode 100644 index 0000000..d597644 --- /dev/null +++ b/rust/guts/src/portable.rs @@ -0,0 +1,262 @@ +use crate::{ + le_bytes_from_words_32, le_bytes_from_words_64, words_from_le_bytes_32, words_from_le_bytes_64, + BlockBytes, BlockWords, CVBytes, CVWords, Implementation, IV, MAX_SIMD_DEGREE, MSG_SCHEDULE, +}; + +const DEGREE: usize = MAX_SIMD_DEGREE; + +unsafe extern "C" fn degree() -> usize { + DEGREE +} + +#[inline(always)] +fn g(state: &mut BlockWords, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) { + state[a] = state[a].wrapping_add(state[b]).wrapping_add(x); + state[d] = (state[d] ^ state[a]).rotate_right(16); + state[c] = state[c].wrapping_add(state[d]); + state[b] = (state[b] ^ state[c]).rotate_right(12); + state[a] = state[a].wrapping_add(state[b]).wrapping_add(y); + state[d] = (state[d] ^ state[a]).rotate_right(8); + state[c] = state[c].wrapping_add(state[d]); + state[b] = (state[b] ^ state[c]).rotate_right(7); +} + +#[inline(always)] +fn round(state: &mut [u32; 16], msg: &BlockWords, round: usize) { + // Select the message schedule based on the round. + let schedule = MSG_SCHEDULE[round]; + + // Mix the columns. + g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]); + g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]); + g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]); + g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]); + + // Mix the diagonals. 
+ g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]); + g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]); + g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]); + g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]); +} + +#[inline(always)] +fn compress_inner( + block_words: &BlockWords, + block_len: u32, + cv_words: &CVWords, + counter: u64, + flags: u32, +) -> [u32; 16] { + let mut state = [ + cv_words[0], + cv_words[1], + cv_words[2], + cv_words[3], + cv_words[4], + cv_words[5], + cv_words[6], + cv_words[7], + IV[0], + IV[1], + IV[2], + IV[3], + counter as u32, + (counter >> 32) as u32, + block_len as u32, + flags as u32, + ]; + for round_number in 0..7 { + round(&mut state, &block_words, round_number); + } + state +} + +pub(crate) unsafe extern "C" fn compress( + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut CVBytes, +) { + let block_words = words_from_le_bytes_64(&*block); + let cv_words = words_from_le_bytes_32(&*cv); + let mut state = compress_inner(&block_words, block_len, &cv_words, counter, flags); + for word_index in 0..8 { + state[word_index] ^= state[word_index + 8]; + } + *out = le_bytes_from_words_32(state[..8].try_into().unwrap()); +} + +pub(crate) unsafe extern "C" fn compress_xof( + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut BlockBytes, +) { + let block_words = words_from_le_bytes_64(&*block); + let cv_words = words_from_le_bytes_32(&*cv); + let mut state = compress_inner(&block_words, block_len, &cv_words, counter, flags); + for word_index in 0..8 { + state[word_index] ^= state[word_index + 8]; + state[word_index + 8] ^= cv_words[word_index]; + } + *out = le_bytes_from_words_64(&state); +} + +pub(crate) unsafe extern "C" fn hash_chunks( + input: *const u8, + input_len: usize, + key: *const CVBytes, + counter: u64, + flags: u32, + transposed_output: *mut u32, +) { + crate::hash_chunks_using_compress( + compress, + input, + input_len, + key, + counter, + flags, + transposed_output, + ) +} + +pub(crate) unsafe extern "C" fn hash_parents( + transposed_input: *const u32, + num_parents: usize, + key: *const CVBytes, + flags: u32, + transposed_output: *mut u32, // may overlap the input +) { + crate::hash_parents_using_compress( + compress, + transposed_input, + num_parents, + key, + flags, + transposed_output, + ) +} + +pub(crate) unsafe extern "C" fn xof( + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut u8, + out_len: usize, +) { + crate::xof_using_compress_xof( + compress_xof, + block, + block_len, + cv, + counter, + flags, + out, + out_len, + ) +} + +pub(crate) unsafe extern "C" fn xof_xor( + block: *const BlockBytes, + block_len: u32, + cv: *const CVBytes, + counter: u64, + flags: u32, + out: *mut u8, + out_len: usize, +) { + crate::xof_xor_using_compress_xof( + compress_xof, + block, + block_len, + cv, + counter, + flags, + out, + out_len, + ) +} + +pub(crate) unsafe extern "C" fn universal_hash( + input: *const u8, + input_len: usize, + key: *const CVBytes, + counter: u64, + out: *mut [u8; 16], +) { + crate::universal_hash_using_compress(compress, input, input_len, key, counter, out) +} + +pub fn implementation() -> Implementation { + Implementation::new( + degree, + compress, + hash_chunks, + hash_parents, + xof, + xof_xor, + universal_hash, + ) +} + +#[cfg(test)] +mod test { + use super::*; + + // This is circular but do it anyway. 
+ #[test] + fn test_compress_vs_portable() { + crate::test::test_compress_vs_portable(&implementation()); + } + + #[test] + fn test_compress_vs_reference() { + crate::test::test_compress_vs_reference(&implementation()); + } + + // This is circular but do it anyway. + #[test] + fn test_hash_chunks_vs_portable() { + crate::test::test_hash_chunks_vs_portable(&implementation()); + } + + // This is circular but do it anyway. + #[test] + fn test_hash_parents_vs_portable() { + crate::test::test_hash_parents_vs_portable(&implementation()); + } + + #[test] + fn test_chunks_and_parents_vs_reference() { + crate::test::test_chunks_and_parents_vs_reference(&implementation()); + } + + // This is circular but do it anyway. + #[test] + fn test_xof_vs_portable() { + crate::test::test_xof_vs_portable(&implementation()); + } + + #[test] + fn test_xof_vs_reference() { + crate::test::test_xof_vs_reference(&implementation()); + } + + // This is circular but do it anyway. + #[test] + fn test_universal_hash_vs_portable() { + crate::test::test_universal_hash_vs_portable(&implementation()); + } + + #[test] + fn test_universal_hash_vs_reference() { + crate::test::test_universal_hash_vs_reference(&implementation()); + } +} diff --git a/rust/guts/src/test.rs b/rust/guts/src/test.rs new file mode 100644 index 0000000..83bd790 --- /dev/null +++ b/rust/guts/src/test.rs @@ -0,0 +1,523 @@ +use crate::*; + +pub const TEST_KEY: CVBytes = *b"whats the Elvish word for friend"; + +// Test a few different initial counter values. +// - 0: The base case. +// - i32::MAX: *No* overflow. But carry bugs in tricky SIMD code can screw this up, if you XOR when +// you're supposed to ANDNOT. +// - u32::MAX: The low word of the counter overflows for all inputs except the first. +// - (42 << 32) + u32::MAX: Same but with a non-zero value in the high word. 
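+// For example, starting from u32::MAX, the second chunk or block uses counter
+// u32::MAX + 1 = 1 << 32, whose low word is 0 and whose high word is 1, so an
+// implementation that drops the carry produces a wrong high word there.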
+const INITIAL_COUNTERS: [u64; 4] = [ + 0, + i32::MAX as u64, + u32::MAX as u64, + (42u64 << 32) + u32::MAX as u64, +]; + +const BLOCK_LENGTHS: [usize; 4] = [0, 1, 63, 64]; + +pub fn paint_test_input(buf: &mut [u8]) { + for (i, b) in buf.iter_mut().enumerate() { + *b = (i % 251) as u8; + } +} + +pub fn test_compress_vs_portable(test_impl: &Implementation) { + for block_len in BLOCK_LENGTHS { + dbg!(block_len); + let mut block = [0; BLOCK_LEN]; + paint_test_input(&mut block[..block_len]); + for counter in INITIAL_COUNTERS { + dbg!(counter); + let portable_cv = portable::implementation().compress( + &block, + block_len as u32, + &TEST_KEY, + counter, + KEYED_HASH, + ); + + let test_cv = + test_impl.compress(&block, block_len as u32, &TEST_KEY, counter, KEYED_HASH); + + assert_eq!(portable_cv, test_cv); + } + } +} + +pub fn test_compress_vs_reference(test_impl: &Implementation) { + for block_len in BLOCK_LENGTHS { + dbg!(block_len); + let mut block = [0; BLOCK_LEN]; + paint_test_input(&mut block[..block_len]); + + let mut ref_hasher = reference_impl::Hasher::new_keyed(&TEST_KEY); + ref_hasher.update(&block[..block_len]); + let mut ref_hash = [0u8; 32]; + ref_hasher.finalize(&mut ref_hash); + + let test_cv = test_impl.compress( + &block, + block_len as u32, + &TEST_KEY, + 0, + CHUNK_START | CHUNK_END | ROOT | KEYED_HASH, + ); + + assert_eq!(ref_hash, test_cv); + } +} + +fn check_transposed_eq(output_a: &TransposedVectors, output_b: &TransposedVectors) { + if output_a == output_b { + return; + } + for cv_index in 0..2 * MAX_SIMD_DEGREE { + let cv_a = output_a.extract_cv(cv_index); + let cv_b = output_b.extract_cv(cv_index); + if cv_a == [0; 32] && cv_b == [0; 32] { + println!("CV {cv_index:2} empty"); + } else if cv_a == cv_b { + println!("CV {cv_index:2} matches"); + } else { + println!("CV {cv_index:2} mismatch:"); + println!(" {}", hex::encode(cv_a)); + println!(" {}", hex::encode(cv_b)); + } + } + panic!("transposed outputs are not equal"); +} + +pub fn test_hash_chunks_vs_portable(test_impl: &Implementation) { + assert!(test_impl.degree() <= MAX_SIMD_DEGREE); + dbg!(test_impl.degree() * CHUNK_LEN); + // Allocate 4 extra bytes of padding so we can make aligned slices. + let mut input_buf = [0u8; 2 * 2 * MAX_SIMD_DEGREE * CHUNK_LEN + 4]; + let mut input_slice = &mut input_buf[..]; + // Make sure the start of the input is word-aligned. + while input_slice.as_ptr() as usize % 4 != 0 { + input_slice = &mut input_slice[1..]; + } + let (aligned_input, mut unaligned_input) = + input_slice.split_at_mut(2 * MAX_SIMD_DEGREE * CHUNK_LEN); + unaligned_input = &mut unaligned_input[1..][..2 * MAX_SIMD_DEGREE * CHUNK_LEN]; + assert_eq!(aligned_input.as_ptr() as usize % 4, 0); + assert_eq!(unaligned_input.as_ptr() as usize % 4, 1); + paint_test_input(aligned_input); + paint_test_input(unaligned_input); + // Try just below, equal to, and just above every whole number of chunks. 
+ let mut input_2_lengths = Vec::new(); + let mut next_len = 2 * CHUNK_LEN; + loop { + // 95 is one whole block plus one interesting part of another + input_2_lengths.push(next_len - 95); + input_2_lengths.push(next_len); + if next_len == test_impl.degree() * CHUNK_LEN { + break; + } + input_2_lengths.push(next_len + 95); + next_len += CHUNK_LEN; + } + for input_2_len in input_2_lengths { + dbg!(input_2_len); + let aligned_input1 = &aligned_input[..test_impl.degree() * CHUNK_LEN]; + let aligned_input2 = &aligned_input[test_impl.degree() * CHUNK_LEN..][..input_2_len]; + let unaligned_input1 = &unaligned_input[..test_impl.degree() * CHUNK_LEN]; + let unaligned_input2 = &unaligned_input[test_impl.degree() * CHUNK_LEN..][..input_2_len]; + for initial_counter in INITIAL_COUNTERS { + dbg!(initial_counter); + // Make two calls, to test the output_column parameter. + let mut portable_output = TransposedVectors::new(); + let (portable_left, portable_right) = + test_impl.split_transposed_vectors(&mut portable_output); + portable::implementation().hash_chunks( + aligned_input1, + &IV_BYTES, + initial_counter, + 0, + portable_left, + ); + portable::implementation().hash_chunks( + aligned_input2, + &TEST_KEY, + initial_counter + test_impl.degree() as u64, + KEYED_HASH, + portable_right, + ); + + let mut test_output = TransposedVectors::new(); + let (test_left, test_right) = test_impl.split_transposed_vectors(&mut test_output); + test_impl.hash_chunks(aligned_input1, &IV_BYTES, initial_counter, 0, test_left); + test_impl.hash_chunks( + aligned_input2, + &TEST_KEY, + initial_counter + test_impl.degree() as u64, + KEYED_HASH, + test_right, + ); + check_transposed_eq(&portable_output, &test_output); + + // Do the same thing with unaligned input. + let mut unaligned_test_output = TransposedVectors::new(); + let (unaligned_left, unaligned_right) = + test_impl.split_transposed_vectors(&mut unaligned_test_output); + test_impl.hash_chunks( + unaligned_input1, + &IV_BYTES, + initial_counter, + 0, + unaligned_left, + ); + test_impl.hash_chunks( + unaligned_input2, + &TEST_KEY, + initial_counter + test_impl.degree() as u64, + KEYED_HASH, + unaligned_right, + ); + check_transposed_eq(&portable_output, &unaligned_test_output); + } + } +} + +fn painted_transposed_input() -> TransposedVectors { + let mut vectors = TransposedVectors::new(); + let mut val = 0; + for col in 0..2 * MAX_SIMD_DEGREE { + for row in 0..8 { + vectors.0[row][col] = val; + val += 1; + } + } + vectors +} + +pub fn test_hash_parents_vs_portable(test_impl: &Implementation) { + assert!(test_impl.degree() <= MAX_SIMD_DEGREE); + let input = painted_transposed_input(); + for num_parents in 2..=(test_impl.degree() / 2) { + dbg!(num_parents); + let mut portable_output = TransposedVectors::new(); + let (portable_left, portable_right) = + test_impl.split_transposed_vectors(&mut portable_output); + portable::implementation().hash_parents( + &input, + 2 * num_parents, // num_cvs + &IV_BYTES, + 0, + portable_left, + ); + portable::implementation().hash_parents( + &input, + 2 * num_parents, // num_cvs + &TEST_KEY, + KEYED_HASH, + portable_right, + ); + + let mut test_output = TransposedVectors::new(); + let (test_left, test_right) = test_impl.split_transposed_vectors(&mut test_output); + test_impl.hash_parents( + &input, + 2 * num_parents, // num_cvs + &IV_BYTES, + 0, + test_left, + ); + test_impl.hash_parents( + &input, + 2 * num_parents, // num_cvs + &TEST_KEY, + KEYED_HASH, + test_right, + ); + + check_transposed_eq(&portable_output, &test_output); + } +} + 
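+// Recursively hash a subtree: inputs of at most degree() chunks go straight to
+// hash_chunks(), and larger inputs split at left_len() and reduce the child CVs
+// with hash_parents(). Because left_len() always gives the left side a whole
+// power-of-two number of chunks, the left recursion returns exactly
+// test_impl.degree() CVs, which the assert below checks.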
+fn hash_with_chunks_and_parents_recurse( + test_impl: &Implementation, + input: &[u8], + counter: u64, + output: TransposedSplit, +) -> usize { + assert!(input.len() > 0); + if input.len() <= test_impl.degree() * CHUNK_LEN { + return test_impl.hash_chunks(input, &IV_BYTES, counter, 0, output); + } + let (left_input, right_input) = input.split_at(left_len(input.len())); + let mut child_output = TransposedVectors::new(); + let (left_output, right_output) = test_impl.split_transposed_vectors(&mut child_output); + let mut children = + hash_with_chunks_and_parents_recurse(test_impl, left_input, counter, left_output); + assert_eq!(children, test_impl.degree()); + children += hash_with_chunks_and_parents_recurse( + test_impl, + right_input, + counter + (left_input.len() / CHUNK_LEN) as u64, + right_output, + ); + test_impl.hash_parents(&child_output, children, &IV_BYTES, PARENT, output) +} + +// Note: This test implementation doesn't support the 1-chunk-or-less case. +fn root_hash_with_chunks_and_parents(test_impl: &Implementation, input: &[u8]) -> CVBytes { + // TODO: handle the 1-chunk case? + assert!(input.len() > CHUNK_LEN); + let mut cvs = TransposedVectors::new(); + // The right half of these vectors are never used. + let (cvs_left, _) = test_impl.split_transposed_vectors(&mut cvs); + let mut num_cvs = hash_with_chunks_and_parents_recurse(test_impl, input, 0, cvs_left); + while num_cvs > 2 { + num_cvs = test_impl.reduce_parents(&mut cvs, num_cvs, &IV_BYTES, 0); + } + test_impl.compress( + &cvs.extract_parent_node(0), + BLOCK_LEN as u32, + &IV_BYTES, + 0, + PARENT | ROOT, + ) +} + +pub fn test_chunks_and_parents_vs_reference(test_impl: &Implementation) { + assert_eq!(test_impl.degree().count_ones(), 1, "power of 2"); + const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * CHUNK_LEN; + let mut input_buf = [0u8; MAX_INPUT_LEN]; + paint_test_input(&mut input_buf); + // Try just below, equal to, and just above every whole number of chunks, except that + // root_hash_with_chunks_and_parents doesn't support the 1-chunk-or-less case. + let mut test_lengths = vec![CHUNK_LEN + 1]; + let mut next_len = 2 * CHUNK_LEN; + loop { + // 95 is one whole block plus one interesting part of another + test_lengths.push(next_len - 95); + test_lengths.push(next_len); + if next_len == MAX_INPUT_LEN { + break; + } + test_lengths.push(next_len + 95); + next_len += CHUNK_LEN; + } + for test_len in test_lengths { + dbg!(test_len); + let input = &input_buf[..test_len]; + + let mut ref_hasher = reference_impl::Hasher::new(); + ref_hasher.update(&input); + let mut ref_hash = [0u8; 32]; + ref_hasher.finalize(&mut ref_hash); + + let test_hash = root_hash_with_chunks_and_parents(test_impl, input); + + assert_eq!(ref_hash, test_hash); + } +} + +pub fn test_xof_vs_portable(test_impl: &Implementation) { + let flags = CHUNK_START | CHUNK_END | KEYED_HASH; + for counter in INITIAL_COUNTERS { + dbg!(counter); + for input_len in [0, 1, BLOCK_LEN] { + dbg!(input_len); + let mut input_block = [0u8; BLOCK_LEN]; + for byte_index in 0..input_len { + input_block[byte_index] = byte_index as u8 + 42; + } + // Try equal to and partway through every whole number of output blocks. 
+ const MAX_OUTPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN; + let mut output_lengths = Vec::new(); + let mut next_len = 0; + loop { + output_lengths.push(next_len); + if next_len == MAX_OUTPUT_LEN { + break; + } + output_lengths.push(next_len + 31); + next_len += BLOCK_LEN; + } + for output_len in output_lengths { + dbg!(output_len); + let mut portable_output = [0xff; MAX_OUTPUT_LEN]; + portable::implementation().xof( + &input_block, + input_len as u32, + &TEST_KEY, + counter, + flags, + &mut portable_output[..output_len], + ); + let mut test_output = [0xff; MAX_OUTPUT_LEN]; + test_impl.xof( + &input_block, + input_len as u32, + &TEST_KEY, + counter, + flags, + &mut test_output[..output_len], + ); + assert_eq!(portable_output, test_output); + + // Double check that the implementation didn't overwrite. + assert!(test_output[output_len..].iter().all(|&b| b == 0xff)); + + // The first XOR cancels out the output. + test_impl.xof_xor( + &input_block, + input_len as u32, + &TEST_KEY, + counter, + flags, + &mut test_output[..output_len], + ); + assert!(test_output[..output_len].iter().all(|&b| b == 0)); + assert!(test_output[output_len..].iter().all(|&b| b == 0xff)); + + // The second XOR restores out the output. + test_impl.xof_xor( + &input_block, + input_len as u32, + &TEST_KEY, + counter, + flags, + &mut test_output[..output_len], + ); + assert_eq!(portable_output, test_output); + assert!(test_output[output_len..].iter().all(|&b| b == 0xff)); + } + } + } +} + +pub fn test_xof_vs_reference(test_impl: &Implementation) { + let input = b"hello world"; + let mut input_block = [0; BLOCK_LEN]; + input_block[..input.len()].copy_from_slice(input); + + const MAX_OUTPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN; + let mut ref_output = [0; MAX_OUTPUT_LEN]; + let mut ref_hasher = reference_impl::Hasher::new_keyed(&TEST_KEY); + ref_hasher.update(input); + ref_hasher.finalize(&mut ref_output); + + // Try equal to and partway through every whole number of output blocks. + let mut output_lengths = vec![0, 1, 31]; + let mut next_len = BLOCK_LEN; + loop { + output_lengths.push(next_len); + if next_len == MAX_OUTPUT_LEN { + break; + } + output_lengths.push(next_len + 31); + next_len += BLOCK_LEN; + } + + for output_len in output_lengths { + dbg!(output_len); + let mut test_output = [0; MAX_OUTPUT_LEN]; + test_impl.xof( + &input_block, + input.len() as u32, + &TEST_KEY, + 0, + KEYED_HASH | CHUNK_START | CHUNK_END, + &mut test_output[..output_len], + ); + assert_eq!(ref_output[..output_len], test_output[..output_len]); + + // Double check that the implementation didn't overwrite. + assert!(test_output[output_len..].iter().all(|&b| b == 0)); + + // Do it again starting from block 1. + if output_len >= BLOCK_LEN { + test_impl.xof( + &input_block, + input.len() as u32, + &TEST_KEY, + 1, + KEYED_HASH | CHUNK_START | CHUNK_END, + &mut test_output[..output_len - BLOCK_LEN], + ); + assert_eq!( + ref_output[BLOCK_LEN..output_len], + test_output[..output_len - BLOCK_LEN], + ); + } + } +} + +pub fn test_universal_hash_vs_portable(test_impl: &Implementation) { + const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN; + let mut input_buf = [0; MAX_INPUT_LEN]; + paint_test_input(&mut input_buf); + // Try equal to and partway through every whole number of input blocks. 
+ let mut input_lengths = vec![0, 1, 31]; + let mut next_len = BLOCK_LEN; + loop { + input_lengths.push(next_len); + if next_len == MAX_INPUT_LEN { + break; + } + input_lengths.push(next_len + 31); + next_len += BLOCK_LEN; + } + for input_len in input_lengths { + dbg!(input_len); + for counter in INITIAL_COUNTERS { + dbg!(counter); + let portable_output = portable::implementation().universal_hash( + &input_buf[..input_len], + &TEST_KEY, + counter, + ); + let test_output = test_impl.universal_hash(&input_buf[..input_len], &TEST_KEY, counter); + assert_eq!(portable_output, test_output); + } + } +} + +fn reference_impl_universal_hash(input: &[u8], key: &CVBytes) -> [u8; UNIVERSAL_HASH_LEN] { + // The reference_impl doesn't support XOF seeking, so we have to materialize an entire extended + // output to seek to a block. + const MAX_BLOCKS: usize = 2 * MAX_SIMD_DEGREE; + assert!(input.len() / BLOCK_LEN <= MAX_BLOCKS); + let mut output_buffer: [u8; BLOCK_LEN * MAX_BLOCKS] = [0u8; BLOCK_LEN * MAX_BLOCKS]; + let mut result = [0u8; UNIVERSAL_HASH_LEN]; + let mut block_start = 0; + while block_start < input.len() { + let block_len = cmp::min(input.len() - block_start, BLOCK_LEN); + let mut ref_hasher = reference_impl::Hasher::new_keyed(key); + ref_hasher.update(&input[block_start..block_start + block_len]); + ref_hasher.finalize(&mut output_buffer[..block_start + UNIVERSAL_HASH_LEN]); + for byte_index in 0..UNIVERSAL_HASH_LEN { + result[byte_index] ^= output_buffer[block_start + byte_index]; + } + block_start += BLOCK_LEN; + } + result +} + +pub fn test_universal_hash_vs_reference(test_impl: &Implementation) { + const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN; + let mut input_buf = [0; MAX_INPUT_LEN]; + paint_test_input(&mut input_buf); + // Try equal to and partway through every whole number of input blocks. + let mut input_lengths = vec![0, 1, 31]; + let mut next_len = BLOCK_LEN; + loop { + input_lengths.push(next_len); + if next_len == MAX_INPUT_LEN { + break; + } + input_lengths.push(next_len + 31); + next_len += BLOCK_LEN; + } + for input_len in input_lengths { + dbg!(input_len); + let ref_output = reference_impl_universal_hash(&input_buf[..input_len], &TEST_KEY); + let test_output = test_impl.universal_hash(&input_buf[..input_len], &TEST_KEY, 0); + assert_eq!(ref_output, test_output); + } +}