
factor out just the portable parts of the guts_api branch

Jack O'Connor 2024-01-21 17:35:29 -08:00
parent 6f3e6fc86c
commit fc75227170
5 changed files with 1821 additions and 0 deletions

rust/guts/Cargo.toml (new file, 18 lines)

@@ -0,0 +1,18 @@
[package]
name = "blake3_guts"
version = "0.0.0"
authors = ["Jack O'Connor <oconnor663@gmail.com>", "Samuel Neves"]
description = "low-level building blocks for the BLAKE3 hash function"
repository = "https://github.com/BLAKE3-team/BLAKE3"
license = "CC0-1.0 OR Apache-2.0"
documentation = "https://docs.rs/blake3_guts"
readme = "readme.md"
edition = "2021"

[dev-dependencies]
hex = "0.4.3"
reference_impl = { path = "../../reference_impl" }

[features]
default = ["std"]
std = []

rust/guts/readme.md (new file, 62 lines)

@@ -0,0 +1,62 @@
# The BLAKE3 Guts API

## Introduction

This crate contains low-level, high-performance, platform-specific
implementations of the BLAKE3 compression function. This API is complicated and
unsafe, and this crate will never have a stable release. For the standard
BLAKE3 hash function, see the [`blake3`](https://crates.io/crates/blake3)
crate, which depends on this one.

The most important ingredient in a high-performance implementation of BLAKE3 is
parallelism. The BLAKE3 tree structure lets us hash different parts of the tree
in parallel, and modern computers have a _lot_ of parallelism to offer.
Sometimes that means using multiple threads running on multiple cores, but
multithreading isn't appropriate for all applications, and it's not the usual
default for library APIs. More commonly, BLAKE3 implementations use SIMD
instructions ("Single Instruction Multiple Data") to improve the performance of
a single thread. When we do use multithreading, the performance benefits
multiply.

The tricky thing about SIMD is that each instruction set works differently.
Instead of writing portable code once and letting the compiler do most of the
optimization work, we need to write platform-specific implementations, and
sometimes more than one per platform. We maintain *four* different
implementations on x86 alone (targeting SSE2, SSE4.1, AVX2, and AVX-512), in
addition to ARM NEON and the RISC-V vector extensions. In the future we might
add ARM SVE2.

All of that means a lot of duplicated logic and maintenance. So while the main
goal of this API is high performance, it's also important to keep the API as
small and simple as possible. Higher level details like the "CV stack", input
buffering, and multithreading are handled by portable code in the main `blake3`
crate. These are just building blocks.

## The private API

This is the API that each platform reimplements. It's completely `unsafe`:
inputs and outputs are allowed to alias, and bounds checking is the caller's
responsibility. The function-pointer types are sketched after the list below.

- `degree`
- `compress`
- `hash_chunks`
- `hash_parents`
- `xof`
- `xof_xor`
- `universal_hash`
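
For reference, these are the raw function-pointer types behind that list, copied
from `src/lib.rs` in this commit; each platform provides `unsafe extern "C"`
functions matching these signatures:

```rust
type DegreeFn = unsafe extern "C" fn() -> usize;

type CompressFn = unsafe extern "C" fn(
    block: *const BlockBytes, // zero padded to 64 bytes
    block_len: u32,
    cv: *const CVBytes,
    counter: u64,
    flags: u32,
    out: *mut CVBytes, // may overlap the input
);

type HashChunksFn = unsafe extern "C" fn(
    input: *const u8,
    input_len: usize,
    key: *const CVBytes,
    counter: u64,
    flags: u32,
    transposed_output: *mut u32,
);

type HashParentsFn = unsafe extern "C" fn(
    transposed_input: *const u32,
    num_parents: usize,
    key: *const CVBytes,
    flags: u32,
    transposed_output: *mut u32, // may overlap the input
);

// This signature covers both xof() and xof_xor().
type XofFn = unsafe extern "C" fn(
    block: *const BlockBytes, // zero padded to 64 bytes
    block_len: u32,
    cv: *const CVBytes,
    counter: u64,
    flags: u32,
    out: *mut u8,
    out_len: usize,
);

type UniversalHashFn = unsafe extern "C" fn(
    input: *const u8,
    input_len: usize,
    key: *const CVBytes,
    counter: u64,
    out: *mut [u8; 16],
);
```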

## The public API

This is the API that this crate exposes to callers, i.e. to the main `blake3`
crate. It's a thin, portable layer on top of the private API above. The Rust
version of this API is memory-safe. A short usage sketch follows the list.

- `degree`
- `compress`
- `hash_chunks`
- `hash_parents`
- `reduce_parents`
- `xof`
- `xof_xor`
- `universal_hash`
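
As a rough illustration of how a caller would drive this layer, here is a minimal
sketch, not part of this commit, that hashes exactly two chunks with the portable
implementation and finalizes the root. It assumes a path dependency on this crate
and mirrors the helpers in `src/test.rs`; the input bytes and length are arbitrary:

```rust
use blake3_guts::{portable, TransposedVectors, BLOCK_LEN, CHUNK_LEN, IV_BYTES, PARENT, ROOT};

fn main() {
    let imp = portable::implementation();
    // Exactly two chunks of arbitrary input, hashed in the default (non-keyed) mode.
    let input = [0xabu8; 2 * CHUNK_LEN];

    // hash_chunks writes one transposed CV per chunk into the left split.
    let mut cvs = TransposedVectors::new();
    let (left, _) = imp.split_transposed_vectors(&mut cvs);
    let num_cvs = imp.hash_chunks(&input, &IV_BYTES, 0, 0, left);
    assert_eq!(num_cvs, 2);

    // The two chunk CVs form the root parent node; compressing it with
    // PARENT | ROOT produces the final 32-byte hash.
    let root_hash = imp.compress(
        &cvs.extract_parent_node(0),
        BLOCK_LEN as u32,
        &IV_BYTES,
        0,
        PARENT | ROOT,
    );
    println!("{:02x?}", root_hash);
}
```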

rust/guts/src/lib.rs (new file, 956 lines)

@@ -0,0 +1,956 @@
use core::cmp;
use core::marker::PhantomData;
use core::mem;
use core::ptr;
use core::sync::atomic::{AtomicPtr, Ordering::Relaxed};
pub mod portable;
#[cfg(test)]
mod test;
pub const OUT_LEN: usize = 32;
pub const BLOCK_LEN: usize = 64;
pub const CHUNK_LEN: usize = 1024;
pub const WORD_LEN: usize = 4;
pub const UNIVERSAL_HASH_LEN: usize = 16;
pub const CHUNK_START: u32 = 1 << 0;
pub const CHUNK_END: u32 = 1 << 1;
pub const PARENT: u32 = 1 << 2;
pub const ROOT: u32 = 1 << 3;
pub const KEYED_HASH: u32 = 1 << 4;
pub const DERIVE_KEY_CONTEXT: u32 = 1 << 5;
pub const DERIVE_KEY_MATERIAL: u32 = 1 << 6;
pub const IV: CVWords = [
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
];
pub const IV_BYTES: CVBytes = le_bytes_from_words_32(&IV);
pub const MSG_SCHEDULE: [[usize; 16]; 7] = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8],
[3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1],
[10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6],
[12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4],
[9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7],
[11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13],
];
// never less than 2
pub const MAX_SIMD_DEGREE: usize = 2;
pub type CVBytes = [u8; 32];
pub type CVWords = [u32; 8];
pub type BlockBytes = [u8; 64];
pub type BlockWords = [u32; 16];
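// DETECTED_IMPL is a table of function pointers. It starts out pointing at the
// *_init shims defined below; the first call through any entry runs
// init_detected_impl(), which replaces every pointer with the detected
// implementation. In this commit detect() always returns the portable
// implementation.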
pub static DETECTED_IMPL: Implementation = Implementation::new(
degree_init,
compress_init,
hash_chunks_init,
hash_parents_init,
xof_init,
xof_xor_init,
universal_hash_init,
);
fn detect() -> Implementation {
portable::implementation()
}
fn init_detected_impl() {
let detected = detect();
DETECTED_IMPL
.degree_ptr
.store(detected.degree_ptr.load(Relaxed), Relaxed);
DETECTED_IMPL
.compress_ptr
.store(detected.compress_ptr.load(Relaxed), Relaxed);
DETECTED_IMPL
.hash_chunks_ptr
.store(detected.hash_chunks_ptr.load(Relaxed), Relaxed);
DETECTED_IMPL
.hash_parents_ptr
.store(detected.hash_parents_ptr.load(Relaxed), Relaxed);
DETECTED_IMPL
.xof_ptr
.store(detected.xof_ptr.load(Relaxed), Relaxed);
DETECTED_IMPL
.xof_xor_ptr
.store(detected.xof_xor_ptr.load(Relaxed), Relaxed);
DETECTED_IMPL
.universal_hash_ptr
.store(detected.universal_hash_ptr.load(Relaxed), Relaxed);
}
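// A set of function pointers for one platform's implementation of the private
// API. The memory-safe public methods below clamp or check their arguments and
// then dispatch through these pointers.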
pub struct Implementation {
degree_ptr: AtomicPtr<()>,
compress_ptr: AtomicPtr<()>,
hash_chunks_ptr: AtomicPtr<()>,
hash_parents_ptr: AtomicPtr<()>,
xof_ptr: AtomicPtr<()>,
xof_xor_ptr: AtomicPtr<()>,
universal_hash_ptr: AtomicPtr<()>,
}
impl Implementation {
const fn new(
degree_fn: DegreeFn,
compress_fn: CompressFn,
hash_chunks_fn: HashChunksFn,
hash_parents_fn: HashParentsFn,
xof_fn: XofFn,
xof_xor_fn: XofFn,
universal_hash_fn: UniversalHashFn,
) -> Self {
Self {
degree_ptr: AtomicPtr::new(degree_fn as *mut ()),
compress_ptr: AtomicPtr::new(compress_fn as *mut ()),
hash_chunks_ptr: AtomicPtr::new(hash_chunks_fn as *mut ()),
hash_parents_ptr: AtomicPtr::new(hash_parents_fn as *mut ()),
xof_ptr: AtomicPtr::new(xof_fn as *mut ()),
xof_xor_ptr: AtomicPtr::new(xof_xor_fn as *mut ()),
universal_hash_ptr: AtomicPtr::new(universal_hash_fn as *mut ()),
}
}
#[inline]
fn degree_fn(&self) -> DegreeFn {
unsafe { mem::transmute(self.degree_ptr.load(Relaxed)) }
}
#[inline]
pub fn degree(&self) -> usize {
let degree = unsafe { self.degree_fn()() };
debug_assert!(degree >= 2);
debug_assert!(degree <= MAX_SIMD_DEGREE);
debug_assert_eq!(1, degree.count_ones(), "power of 2");
degree
}
#[inline]
pub fn split_transposed_vectors<'v>(
&self,
vectors: &'v mut TransposedVectors,
) -> (TransposedSplit<'v>, TransposedSplit<'v>) {
unsafe { vectors.split(self.degree()) }
}
#[inline]
fn compress_fn(&self) -> CompressFn {
unsafe { mem::transmute(self.compress_ptr.load(Relaxed)) }
}
#[inline]
pub fn compress(
&self,
block: &BlockBytes,
block_len: u32,
cv: &CVBytes,
counter: u64,
flags: u32,
) -> CVBytes {
let mut out = [0u8; 32];
unsafe {
self.compress_fn()(block, block_len, cv, counter, flags, &mut out);
}
out
}
// The contract for HashChunksFn doesn't require the implementation to support single-chunk
// inputs. Instead we handle that case here by calling compress in a loop.
#[inline]
fn hash_one_chunk(
&self,
mut input: &[u8],
key: &CVBytes,
counter: u64,
mut flags: u32,
output: TransposedSplit,
) {
debug_assert!(input.len() <= CHUNK_LEN);
let mut cv = *key;
flags |= CHUNK_START;
while input.len() > BLOCK_LEN {
cv = self.compress(
input[..BLOCK_LEN].try_into().unwrap(),
BLOCK_LEN as u32,
&cv,
counter,
flags,
);
input = &input[BLOCK_LEN..];
flags &= !CHUNK_START;
}
let mut final_block = [0u8; BLOCK_LEN];
final_block[..input.len()].copy_from_slice(input);
cv = self.compress(
&final_block,
input.len() as u32,
&cv,
counter,
flags | CHUNK_END,
);
unsafe {
write_transposed_cv(&words_from_le_bytes_32(&cv), output.ptr);
}
}
#[inline]
fn hash_chunks_fn(&self) -> HashChunksFn {
unsafe { mem::transmute(self.hash_chunks_ptr.load(Relaxed)) }
}
#[inline]
pub fn hash_chunks(
&self,
input: &[u8],
key: &CVBytes,
counter: u64,
flags: u32,
transposed_output: TransposedSplit,
) -> usize {
debug_assert!(input.len() <= self.degree() * CHUNK_LEN);
if input.len() <= CHUNK_LEN {
// The underlying hash_chunks_fn isn't required to support this case. Instead we handle
// it by calling compress_fn in a loop. But note that we still don't support root
// finalization or the empty input here.
self.hash_one_chunk(input, key, counter, flags, transposed_output);
return 1;
}
// SAFETY: If the caller passes in more than MAX_SIMD_DEGREE * CHUNK_LEN bytes, silently
// ignore the remainder. This makes it impossible to write out of bounds in a properly
// constructed TransposedSplit.
let len = cmp::min(input.len(), MAX_SIMD_DEGREE * CHUNK_LEN);
unsafe {
self.hash_chunks_fn()(
input.as_ptr(),
len,
key,
counter,
flags,
transposed_output.ptr,
);
}
if input.len() % CHUNK_LEN == 0 {
input.len() / CHUNK_LEN
} else {
(input.len() / CHUNK_LEN) + 1
}
}
#[inline]
fn hash_parents_fn(&self) -> HashParentsFn {
unsafe { mem::transmute(self.hash_parents_ptr.load(Relaxed)) }
}
#[inline]
pub fn hash_parents(
&self,
transposed_input: &TransposedVectors,
mut num_cvs: usize,
key: &CVBytes,
flags: u32,
transposed_output: TransposedSplit,
) -> usize {
debug_assert!(num_cvs <= 2 * MAX_SIMD_DEGREE);
// SAFETY: Cap num_cvs at 2 * MAX_SIMD_DEGREE, to guarantee no out-of-bounds accesses.
num_cvs = cmp::min(num_cvs, 2 * MAX_SIMD_DEGREE);
let mut odd_cv = [0u32; 8];
if num_cvs % 2 == 1 {
unsafe {
odd_cv = read_transposed_cv(transposed_input.as_ptr().add(num_cvs - 1));
}
}
let num_parents = num_cvs / 2;
unsafe {
self.hash_parents_fn()(
transposed_input.as_ptr(),
num_parents,
key,
flags | PARENT,
transposed_output.ptr,
);
}
if num_cvs % 2 == 1 {
unsafe {
write_transposed_cv(&odd_cv, transposed_output.ptr.add(num_parents));
}
num_parents + 1
} else {
num_parents
}
}
#[inline]
pub fn reduce_parents(
&self,
transposed_in_out: &mut TransposedVectors,
mut num_cvs: usize,
key: &CVBytes,
flags: u32,
) -> usize {
debug_assert!(num_cvs <= 2 * MAX_SIMD_DEGREE);
// SAFETY: Cap num_cvs at 2 * MAX_SIMD_DEGREE, to guarantee no out-of-bounds accesses.
num_cvs = cmp::min(num_cvs, 2 * MAX_SIMD_DEGREE);
let in_out_ptr = transposed_in_out.as_mut_ptr();
let mut odd_cv = [0u32; 8];
if num_cvs % 2 == 1 {
unsafe {
odd_cv = read_transposed_cv(in_out_ptr.add(num_cvs - 1));
}
}
let num_parents = num_cvs / 2;
unsafe {
self.hash_parents_fn()(in_out_ptr, num_parents, key, flags | PARENT, in_out_ptr);
}
if num_cvs % 2 == 1 {
unsafe {
write_transposed_cv(&odd_cv, in_out_ptr.add(num_parents));
}
num_parents + 1
} else {
num_parents
}
}
#[inline]
fn xof_fn(&self) -> XofFn {
unsafe { mem::transmute(self.xof_ptr.load(Relaxed)) }
}
#[inline]
pub fn xof(
&self,
block: &BlockBytes,
block_len: u32,
cv: &CVBytes,
mut counter: u64,
flags: u32,
mut out: &mut [u8],
) {
let degree = self.degree();
let simd_len = degree * BLOCK_LEN;
while !out.is_empty() {
let take = cmp::min(simd_len, out.len());
unsafe {
self.xof_fn()(
block,
block_len,
cv,
counter,
flags | ROOT,
out.as_mut_ptr(),
take,
);
}
out = &mut out[take..];
counter += degree as u64;
}
}
#[inline]
fn xof_xor_fn(&self) -> XofFn {
unsafe { mem::transmute(self.xof_xor_ptr.load(Relaxed)) }
}
#[inline]
pub fn xof_xor(
&self,
block: &BlockBytes,
block_len: u32,
cv: &CVBytes,
mut counter: u64,
flags: u32,
mut out: &mut [u8],
) {
let degree = self.degree();
let simd_len = degree * BLOCK_LEN;
while !out.is_empty() {
let take = cmp::min(simd_len, out.len());
unsafe {
self.xof_xor_fn()(
block,
block_len,
cv,
counter,
flags | ROOT,
out.as_mut_ptr(),
take,
);
}
out = &mut out[take..];
counter += degree as u64;
}
}
#[inline]
fn universal_hash_fn(&self) -> UniversalHashFn {
unsafe { mem::transmute(self.universal_hash_ptr.load(Relaxed)) }
}
#[inline]
pub fn universal_hash(&self, mut input: &[u8], key: &CVBytes, mut counter: u64) -> [u8; 16] {
let degree = self.degree();
let simd_len = degree * BLOCK_LEN;
let mut ret = [0u8; 16];
while !input.is_empty() {
let take = cmp::min(simd_len, input.len());
let mut output = [0u8; 16];
unsafe {
self.universal_hash_fn()(input.as_ptr(), take, key, counter, &mut output);
}
input = &input[take..];
counter += degree as u64;
for byte_index in 0..16 {
ret[byte_index] ^= output[byte_index];
}
}
ret
}
}
impl Clone for Implementation {
fn clone(&self) -> Self {
Self {
degree_ptr: AtomicPtr::new(self.degree_ptr.load(Relaxed)),
compress_ptr: AtomicPtr::new(self.compress_ptr.load(Relaxed)),
hash_chunks_ptr: AtomicPtr::new(self.hash_chunks_ptr.load(Relaxed)),
hash_parents_ptr: AtomicPtr::new(self.hash_parents_ptr.load(Relaxed)),
xof_ptr: AtomicPtr::new(self.xof_ptr.load(Relaxed)),
xof_xor_ptr: AtomicPtr::new(self.xof_xor_ptr.load(Relaxed)),
universal_hash_ptr: AtomicPtr::new(self.universal_hash_ptr.load(Relaxed)),
}
}
}
// never less than 2
type DegreeFn = unsafe extern "C" fn() -> usize;
unsafe extern "C" fn degree_init() -> usize {
init_detected_impl();
DETECTED_IMPL.degree_fn()()
}
type CompressFn = unsafe extern "C" fn(
block: *const BlockBytes, // zero padded to 64 bytes
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut CVBytes, // may overlap the input
);
unsafe extern "C" fn compress_init(
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut CVBytes,
) {
init_detected_impl();
DETECTED_IMPL.compress_fn()(block, block_len, cv, counter, flags, out);
}
type CompressXofFn = unsafe extern "C" fn(
block: *const BlockBytes, // zero padded to 64 bytes
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut BlockBytes, // may overlap the input
);
type HashChunksFn = unsafe extern "C" fn(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
flags: u32,
transposed_output: *mut u32,
);
unsafe extern "C" fn hash_chunks_init(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
flags: u32,
transposed_output: *mut u32,
) {
init_detected_impl();
DETECTED_IMPL.hash_chunks_fn()(input, input_len, key, counter, flags, transposed_output);
}
type HashParentsFn = unsafe extern "C" fn(
transposed_input: *const u32,
num_parents: usize,
key: *const CVBytes,
flags: u32,
transposed_output: *mut u32, // may overlap the input
);
unsafe extern "C" fn hash_parents_init(
transposed_input: *const u32,
num_parents: usize,
key: *const CVBytes,
flags: u32,
transposed_output: *mut u32,
) {
init_detected_impl();
DETECTED_IMPL.hash_parents_fn()(transposed_input, num_parents, key, flags, transposed_output);
}
// This signature covers both xof() and xof_xor().
type XofFn = unsafe extern "C" fn(
block: *const BlockBytes, // zero padded to 64 bytes
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut u8,
out_len: usize,
);
unsafe extern "C" fn xof_init(
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut u8,
out_len: usize,
) {
init_detected_impl();
DETECTED_IMPL.xof_fn()(block, block_len, cv, counter, flags, out, out_len);
}
unsafe extern "C" fn xof_xor_init(
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut u8,
out_len: usize,
) {
init_detected_impl();
DETECTED_IMPL.xof_xor_fn()(block, block_len, cv, counter, flags, out, out_len);
}
type UniversalHashFn = unsafe extern "C" fn(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
out: *mut [u8; 16],
);
unsafe extern "C" fn universal_hash_init(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
out: *mut [u8; 16],
) {
init_detected_impl();
DETECTED_IMPL.universal_hash_fn()(input, input_len, key, counter, out);
}
// The implicit degree of this implementation is MAX_SIMD_DEGREE.
#[inline(always)]
unsafe fn hash_chunks_using_compress(
compress: CompressFn,
mut input: *const u8,
mut input_len: usize,
key: *const CVBytes,
mut counter: u64,
flags: u32,
mut transposed_output: *mut u32,
) {
debug_assert!(input_len > 0);
debug_assert!(input_len <= MAX_SIMD_DEGREE * CHUNK_LEN);
input_len = cmp::min(input_len, MAX_SIMD_DEGREE * CHUNK_LEN);
while input_len > 0 {
let mut chunk_len = cmp::min(input_len, CHUNK_LEN);
input_len -= chunk_len;
// Each chunk's CV starts from the key and is overwritten in place by each block's compression.
let mut cv = *key;
let cv_ptr: *mut CVBytes = &mut cv;
let mut chunk_flags = flags | CHUNK_START;
while chunk_len > BLOCK_LEN {
compress(
input as *const BlockBytes,
BLOCK_LEN as u32,
cv_ptr,
counter,
chunk_flags,
cv_ptr,
);
input = input.add(BLOCK_LEN);
chunk_len -= BLOCK_LEN;
chunk_flags &= !CHUNK_START;
}
let mut last_block = [0u8; BLOCK_LEN];
ptr::copy_nonoverlapping(input, last_block.as_mut_ptr(), chunk_len);
input = input.add(chunk_len);
compress(
&last_block,
chunk_len as u32,
cv_ptr,
counter,
chunk_flags | CHUNK_END,
cv_ptr,
);
let cv_words = words_from_le_bytes_32(&cv);
for word_index in 0..8 {
transposed_output
.add(word_index * TRANSPOSED_STRIDE)
.write(cv_words[word_index]);
}
transposed_output = transposed_output.add(1);
counter += 1;
}
}
// The implicit degree of this implementation is MAX_SIMD_DEGREE.
#[inline(always)]
unsafe fn hash_parents_using_compress(
compress: CompressFn,
mut transposed_input: *const u32,
mut num_parents: usize,
key: *const CVBytes,
flags: u32,
mut transposed_output: *mut u32, // may overlap the input
) {
debug_assert!(num_parents > 0);
debug_assert!(num_parents <= MAX_SIMD_DEGREE);
while num_parents > 0 {
let mut block_bytes = [0u8; 64];
for word_index in 0..8 {
let left_child_word = transposed_input.add(word_index * TRANSPOSED_STRIDE).read();
block_bytes[WORD_LEN * word_index..][..WORD_LEN]
.copy_from_slice(&left_child_word.to_le_bytes());
let right_child_word = transposed_input
.add(word_index * TRANSPOSED_STRIDE + 1)
.read();
block_bytes[WORD_LEN * (word_index + 8)..][..WORD_LEN]
.copy_from_slice(&right_child_word.to_le_bytes());
}
let mut cv = [0u8; 32];
compress(&block_bytes, BLOCK_LEN as u32, key, 0, flags, &mut cv);
let cv_words = words_from_le_bytes_32(&cv);
for word_index in 0..8 {
transposed_output
.add(word_index * TRANSPOSED_STRIDE)
.write(cv_words[word_index]);
}
transposed_input = transposed_input.add(2);
transposed_output = transposed_output.add(1);
num_parents -= 1;
}
}
#[inline(always)]
unsafe fn xof_using_compress_xof(
compress_xof: CompressXofFn,
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
mut counter: u64,
flags: u32,
mut out: *mut u8,
mut out_len: usize,
) {
debug_assert!(out_len <= MAX_SIMD_DEGREE * BLOCK_LEN);
while out_len > 0 {
let mut block_output = [0u8; 64];
compress_xof(block, block_len, cv, counter, flags, &mut block_output);
let take = cmp::min(out_len, BLOCK_LEN);
ptr::copy_nonoverlapping(block_output.as_ptr(), out, take);
out = out.add(take);
out_len -= take;
counter += 1;
}
}
#[inline(always)]
unsafe fn xof_xor_using_compress_xof(
compress_xof: CompressXofFn,
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
mut counter: u64,
flags: u32,
mut out: *mut u8,
mut out_len: usize,
) {
debug_assert!(out_len <= MAX_SIMD_DEGREE * BLOCK_LEN);
while out_len > 0 {
let mut block_output = [0u8; 64];
compress_xof(block, block_len, cv, counter, flags, &mut block_output);
let take = cmp::min(out_len, BLOCK_LEN);
for i in 0..take {
*out.add(i) ^= block_output[i];
}
out = out.add(take);
out_len -= take;
counter += 1;
}
}
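// Universal hashing built directly on the compression function: each 64-byte
// block of input is compressed as its own single-block keyed root
// (KEYED_HASH | CHUNK_START | CHUNK_END | ROOT) with an incrementing counter,
// and the first 16 bytes of each output are XORed into the result.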
#[inline(always)]
unsafe fn universal_hash_using_compress(
compress: CompressFn,
mut input: *const u8,
mut input_len: usize,
key: *const CVBytes,
mut counter: u64,
out: *mut [u8; 16],
) {
let flags = KEYED_HASH | CHUNK_START | CHUNK_END | ROOT;
let mut result = [0u8; 16];
while input_len > 0 {
let block_len = cmp::min(input_len, BLOCK_LEN);
let mut block = [0u8; BLOCK_LEN];
ptr::copy_nonoverlapping(input, block.as_mut_ptr(), block_len);
let mut block_output = [0u8; 32];
compress(
&block,
block_len as u32,
key,
counter,
flags,
&mut block_output,
);
for i in 0..16 {
result[i] ^= block_output[i];
}
input = input.add(block_len);
input_len -= block_len;
counter += 1;
}
*out = result;
}
// this is in units of *words*, for pointer operations on *const/*mut u32
const TRANSPOSED_STRIDE: usize = 2 * MAX_SIMD_DEGREE;
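// Chaining values are stored transposed: row i holds word i of up to
// 2 * MAX_SIMD_DEGREE CVs, one CV per column, so a SIMD kernel can load the
// same word of many CVs with a single vector load. TRANSPOSED_STRIDE above is
// the row stride in words.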
#[cfg_attr(any(target_arch = "x86", target_arch = "x86_64"), repr(C, align(64)))]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TransposedVectors([[u32; 2 * MAX_SIMD_DEGREE]; 8]);
impl TransposedVectors {
pub fn new() -> Self {
Self([[0; 2 * MAX_SIMD_DEGREE]; 8])
}
pub fn extract_cv(&self, cv_index: usize) -> CVBytes {
let mut words = [0u32; 8];
for word_index in 0..8 {
words[word_index] = self.0[word_index][cv_index];
}
le_bytes_from_words_32(&words)
}
pub fn extract_parent_node(&self, parent_index: usize) -> BlockBytes {
let mut bytes = [0u8; 64];
bytes[..32].copy_from_slice(&self.extract_cv(parent_index / 2));
bytes[32..].copy_from_slice(&self.extract_cv(parent_index / 2 + 1));
bytes
}
fn as_ptr(&self) -> *const u32 {
self.0[0].as_ptr()
}
fn as_mut_ptr(&mut self) -> *mut u32 {
self.0[0].as_mut_ptr()
}
// SAFETY: This function is just pointer arithmetic, but callers assume that it's safe (not
// necessarily correct) to write up to `degree` words to either side of the split, possibly
// from different threads.
unsafe fn split(&mut self, degree: usize) -> (TransposedSplit, TransposedSplit) {
debug_assert!(degree > 0);
debug_assert!(degree <= MAX_SIMD_DEGREE);
debug_assert_eq!(degree.count_ones(), 1, "power of 2");
let ptr = self.as_mut_ptr();
let left = TransposedSplit {
ptr,
phantom_data: PhantomData,
};
let right = TransposedSplit {
ptr: ptr.wrapping_add(degree),
phantom_data: PhantomData,
};
(left, right)
}
}
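// A writable view of the left or right half of a TransposedVectors, produced by
// split(). It carries a raw column pointer plus a lifetime, which lets the two
// halves be filled independently, possibly from different threads (see the Send
// and Sync impls below).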
pub struct TransposedSplit<'vectors> {
ptr: *mut u32,
phantom_data: PhantomData<&'vectors mut u32>,
}
unsafe impl<'vectors> Send for TransposedSplit<'vectors> {}
unsafe impl<'vectors> Sync for TransposedSplit<'vectors> {}
unsafe fn read_transposed_cv(src: *const u32) -> CVWords {
let mut cv = [0u32; 8];
for word_index in 0..8 {
let offset_words = word_index * TRANSPOSED_STRIDE;
cv[word_index] = src.add(offset_words).read();
}
cv
}
unsafe fn write_transposed_cv(cv: &CVWords, dest: *mut u32) {
for word_index in 0..8 {
let offset_words = word_index * TRANSPOSED_STRIDE;
dest.add(offset_words).write(cv[word_index]);
}
}
#[inline(always)]
pub const fn le_bytes_from_words_32(words: &CVWords) -> CVBytes {
let mut bytes = [0u8; 32];
// This loop is super verbose because currently that's what it takes to be const.
let mut word_index = 0;
while word_index < bytes.len() / WORD_LEN {
let word_bytes = words[word_index].to_le_bytes();
let mut byte_index = 0;
while byte_index < WORD_LEN {
bytes[word_index * WORD_LEN + byte_index] = word_bytes[byte_index];
byte_index += 1;
}
word_index += 1;
}
bytes
}
#[inline(always)]
pub const fn le_bytes_from_words_64(words: &BlockWords) -> BlockBytes {
let mut bytes = [0u8; 64];
// This loop is super verbose because currently that's what it takes to be const.
let mut word_index = 0;
while word_index < bytes.len() / WORD_LEN {
let word_bytes = words[word_index].to_le_bytes();
let mut byte_index = 0;
while byte_index < WORD_LEN {
bytes[word_index * WORD_LEN + byte_index] = word_bytes[byte_index];
byte_index += 1;
}
word_index += 1;
}
bytes
}
#[inline(always)]
pub const fn words_from_le_bytes_32(bytes: &CVBytes) -> CVWords {
let mut words = [0u32; 8];
// This loop is super verbose because currently that's what it takes to be const.
let mut word_index = 0;
while word_index < words.len() {
let mut word_bytes = [0u8; WORD_LEN];
let mut byte_index = 0;
while byte_index < WORD_LEN {
word_bytes[byte_index] = bytes[word_index * WORD_LEN + byte_index];
byte_index += 1;
}
words[word_index] = u32::from_le_bytes(word_bytes);
word_index += 1;
}
words
}
#[inline(always)]
pub const fn words_from_le_bytes_64(bytes: &BlockBytes) -> BlockWords {
let mut words = [0u32; 16];
// This loop is super verbose because currently that's what it takes to be const.
let mut word_index = 0;
while word_index < words.len() {
let mut word_bytes = [0u8; WORD_LEN];
let mut byte_index = 0;
while byte_index < WORD_LEN {
word_bytes[byte_index] = bytes[word_index * WORD_LEN + byte_index];
byte_index += 1;
}
words[word_index] = u32::from_le_bytes(word_bytes);
word_index += 1;
}
words
}
#[test]
fn test_byte_word_round_trips() {
let cv = *b"This is 32 LE bytes/eight words.";
assert_eq!(cv, le_bytes_from_words_32(&words_from_le_bytes_32(&cv)));
let block = *b"This is sixty-four little-endian bytes, or sixteen 32-bit words.";
assert_eq!(
block,
le_bytes_from_words_64(&words_from_le_bytes_64(&block)),
);
}
// The largest power of two less than or equal to `n`, used for left_len()
// immediately below, and also directly in Hasher::update().
pub fn largest_power_of_two_leq(n: usize) -> usize {
((n / 2) + 1).next_power_of_two()
}
#[test]
fn test_largest_power_of_two_leq() {
let input_output = &[
// The zero case is nonsensical, but it does work.
(0, 1),
(1, 1),
(2, 2),
(3, 2),
(4, 4),
(5, 4),
(6, 4),
(7, 4),
(8, 8),
// the largest possible usize
(usize::MAX, (usize::MAX >> 1) + 1),
];
for &(input, output) in input_output {
assert_eq!(
output,
crate::largest_power_of_two_leq(input),
"wrong output for n={}",
input
);
}
}
// Given some input larger than one chunk, return the number of bytes that
// should go in the left subtree. This is the largest power-of-2 number of
// chunks that leaves at least 1 byte for the right subtree.
pub fn left_len(content_len: usize) -> usize {
debug_assert!(content_len > CHUNK_LEN);
// Subtract 1 to reserve at least one byte for the right side.
let full_chunks = (content_len - 1) / CHUNK_LEN;
largest_power_of_two_leq(full_chunks) * CHUNK_LEN
}
#[test]
fn test_left_len() {
let input_output = &[
(CHUNK_LEN + 1, CHUNK_LEN),
(2 * CHUNK_LEN - 1, CHUNK_LEN),
(2 * CHUNK_LEN, CHUNK_LEN),
(2 * CHUNK_LEN + 1, 2 * CHUNK_LEN),
(4 * CHUNK_LEN - 1, 2 * CHUNK_LEN),
(4 * CHUNK_LEN, 2 * CHUNK_LEN),
(4 * CHUNK_LEN + 1, 4 * CHUNK_LEN),
];
for &(input, output) in input_output {
assert_eq!(left_len(input), output);
}
}

rust/guts/src/portable.rs (new file, 262 lines)

@@ -0,0 +1,262 @@
use crate::{
le_bytes_from_words_32, le_bytes_from_words_64, words_from_le_bytes_32, words_from_le_bytes_64,
BlockBytes, BlockWords, CVBytes, CVWords, Implementation, IV, MAX_SIMD_DEGREE, MSG_SCHEDULE,
};
const DEGREE: usize = MAX_SIMD_DEGREE;
unsafe extern "C" fn degree() -> usize {
DEGREE
}
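// The g mixing function, shared with BLAKE2s: it adds two message words into
// the state and applies the 16, 12, 8, and 7 bit rotations.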
#[inline(always)]
fn g(state: &mut BlockWords, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) {
state[a] = state[a].wrapping_add(state[b]).wrapping_add(x);
state[d] = (state[d] ^ state[a]).rotate_right(16);
state[c] = state[c].wrapping_add(state[d]);
state[b] = (state[b] ^ state[c]).rotate_right(12);
state[a] = state[a].wrapping_add(state[b]).wrapping_add(y);
state[d] = (state[d] ^ state[a]).rotate_right(8);
state[c] = state[c].wrapping_add(state[d]);
state[b] = (state[b] ^ state[c]).rotate_right(7);
}
#[inline(always)]
fn round(state: &mut [u32; 16], msg: &BlockWords, round: usize) {
// Select the message schedule based on the round.
let schedule = MSG_SCHEDULE[round];
// Mix the columns.
g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
// Mix the diagonals.
g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
}
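// Load the CV, the first four IV words, the counter, the block length, and the
// flags into a 16-word state and run all 7 rounds. Callers apply the output
// feed-forward (XORing the two halves of the state) themselves.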
#[inline(always)]
fn compress_inner(
block_words: &BlockWords,
block_len: u32,
cv_words: &CVWords,
counter: u64,
flags: u32,
) -> [u32; 16] {
let mut state = [
cv_words[0],
cv_words[1],
cv_words[2],
cv_words[3],
cv_words[4],
cv_words[5],
cv_words[6],
cv_words[7],
IV[0],
IV[1],
IV[2],
IV[3],
counter as u32,
(counter >> 32) as u32,
block_len as u32,
flags as u32,
];
for round_number in 0..7 {
round(&mut state, &block_words, round_number);
}
state
}
pub(crate) unsafe extern "C" fn compress(
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut CVBytes,
) {
let block_words = words_from_le_bytes_64(&*block);
let cv_words = words_from_le_bytes_32(&*cv);
let mut state = compress_inner(&block_words, block_len, &cv_words, counter, flags);
for word_index in 0..8 {
state[word_index] ^= state[word_index + 8];
}
*out = le_bytes_from_words_32(state[..8].try_into().unwrap());
}
pub(crate) unsafe extern "C" fn compress_xof(
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut BlockBytes,
) {
let block_words = words_from_le_bytes_64(&*block);
let cv_words = words_from_le_bytes_32(&*cv);
let mut state = compress_inner(&block_words, block_len, &cv_words, counter, flags);
for word_index in 0..8 {
state[word_index] ^= state[word_index + 8];
state[word_index + 8] ^= cv_words[word_index];
}
*out = le_bytes_from_words_64(&state);
}
pub(crate) unsafe extern "C" fn hash_chunks(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
flags: u32,
transposed_output: *mut u32,
) {
crate::hash_chunks_using_compress(
compress,
input,
input_len,
key,
counter,
flags,
transposed_output,
)
}
pub(crate) unsafe extern "C" fn hash_parents(
transposed_input: *const u32,
num_parents: usize,
key: *const CVBytes,
flags: u32,
transposed_output: *mut u32, // may overlap the input
) {
crate::hash_parents_using_compress(
compress,
transposed_input,
num_parents,
key,
flags,
transposed_output,
)
}
pub(crate) unsafe extern "C" fn xof(
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut u8,
out_len: usize,
) {
crate::xof_using_compress_xof(
compress_xof,
block,
block_len,
cv,
counter,
flags,
out,
out_len,
)
}
pub(crate) unsafe extern "C" fn xof_xor(
block: *const BlockBytes,
block_len: u32,
cv: *const CVBytes,
counter: u64,
flags: u32,
out: *mut u8,
out_len: usize,
) {
crate::xof_xor_using_compress_xof(
compress_xof,
block,
block_len,
cv,
counter,
flags,
out,
out_len,
)
}
pub(crate) unsafe extern "C" fn universal_hash(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
out: *mut [u8; 16],
) {
crate::universal_hash_using_compress(compress, input, input_len, key, counter, out)
}
pub fn implementation() -> Implementation {
Implementation::new(
degree,
compress,
hash_chunks,
hash_parents,
xof,
xof_xor,
universal_hash,
)
}
#[cfg(test)]
mod test {
use super::*;
// This is circular but do it anyway.
#[test]
fn test_compress_vs_portable() {
crate::test::test_compress_vs_portable(&implementation());
}
#[test]
fn test_compress_vs_reference() {
crate::test::test_compress_vs_reference(&implementation());
}
// This is circular but do it anyway.
#[test]
fn test_hash_chunks_vs_portable() {
crate::test::test_hash_chunks_vs_portable(&implementation());
}
// This is circular but do it anyway.
#[test]
fn test_hash_parents_vs_portable() {
crate::test::test_hash_parents_vs_portable(&implementation());
}
#[test]
fn test_chunks_and_parents_vs_reference() {
crate::test::test_chunks_and_parents_vs_reference(&implementation());
}
// This is circular but do it anyway.
#[test]
fn test_xof_vs_portable() {
crate::test::test_xof_vs_portable(&implementation());
}
#[test]
fn test_xof_vs_reference() {
crate::test::test_xof_vs_reference(&implementation());
}
// This is circular but do it anyway.
#[test]
fn test_universal_hash_vs_portable() {
crate::test::test_universal_hash_vs_portable(&implementation());
}
#[test]
fn test_universal_hash_vs_reference() {
crate::test::test_universal_hash_vs_reference(&implementation());
}
}

rust/guts/src/test.rs (new file, 523 lines)

@@ -0,0 +1,523 @@
use crate::*;
pub const TEST_KEY: CVBytes = *b"whats the Elvish word for friend";
// Test a few different initial counter values.
// - 0: The base case.
// - i32::MAX: *No* overflow. But carry bugs in tricky SIMD code can screw this up, if you XOR when
// you're supposed to ANDNOT.
// - u32::MAX: The low word of the counter overflows for all inputs except the first.
// - (42 << 32) + u32::MAX: Same but with a non-zero value in the high word.
const INITIAL_COUNTERS: [u64; 4] = [
0,
i32::MAX as u64,
u32::MAX as u64,
(42u64 << 32) + u32::MAX as u64,
];
const BLOCK_LENGTHS: [usize; 4] = [0, 1, 63, 64];
pub fn paint_test_input(buf: &mut [u8]) {
for (i, b) in buf.iter_mut().enumerate() {
*b = (i % 251) as u8;
}
}
pub fn test_compress_vs_portable(test_impl: &Implementation) {
for block_len in BLOCK_LENGTHS {
dbg!(block_len);
let mut block = [0; BLOCK_LEN];
paint_test_input(&mut block[..block_len]);
for counter in INITIAL_COUNTERS {
dbg!(counter);
let portable_cv = portable::implementation().compress(
&block,
block_len as u32,
&TEST_KEY,
counter,
KEYED_HASH,
);
let test_cv =
test_impl.compress(&block, block_len as u32, &TEST_KEY, counter, KEYED_HASH);
assert_eq!(portable_cv, test_cv);
}
}
}
pub fn test_compress_vs_reference(test_impl: &Implementation) {
for block_len in BLOCK_LENGTHS {
dbg!(block_len);
let mut block = [0; BLOCK_LEN];
paint_test_input(&mut block[..block_len]);
let mut ref_hasher = reference_impl::Hasher::new_keyed(&TEST_KEY);
ref_hasher.update(&block[..block_len]);
let mut ref_hash = [0u8; 32];
ref_hasher.finalize(&mut ref_hash);
let test_cv = test_impl.compress(
&block,
block_len as u32,
&TEST_KEY,
0,
CHUNK_START | CHUNK_END | ROOT | KEYED_HASH,
);
assert_eq!(ref_hash, test_cv);
}
}
fn check_transposed_eq(output_a: &TransposedVectors, output_b: &TransposedVectors) {
if output_a == output_b {
return;
}
for cv_index in 0..2 * MAX_SIMD_DEGREE {
let cv_a = output_a.extract_cv(cv_index);
let cv_b = output_b.extract_cv(cv_index);
if cv_a == [0; 32] && cv_b == [0; 32] {
println!("CV {cv_index:2} empty");
} else if cv_a == cv_b {
println!("CV {cv_index:2} matches");
} else {
println!("CV {cv_index:2} mismatch:");
println!(" {}", hex::encode(cv_a));
println!(" {}", hex::encode(cv_b));
}
}
panic!("transposed outputs are not equal");
}
pub fn test_hash_chunks_vs_portable(test_impl: &Implementation) {
assert!(test_impl.degree() <= MAX_SIMD_DEGREE);
dbg!(test_impl.degree() * CHUNK_LEN);
// Allocate 4 extra bytes of padding so we can make aligned slices.
let mut input_buf = [0u8; 2 * 2 * MAX_SIMD_DEGREE * CHUNK_LEN + 4];
let mut input_slice = &mut input_buf[..];
// Make sure the start of the input is word-aligned.
while input_slice.as_ptr() as usize % 4 != 0 {
input_slice = &mut input_slice[1..];
}
let (aligned_input, mut unaligned_input) =
input_slice.split_at_mut(2 * MAX_SIMD_DEGREE * CHUNK_LEN);
unaligned_input = &mut unaligned_input[1..][..2 * MAX_SIMD_DEGREE * CHUNK_LEN];
assert_eq!(aligned_input.as_ptr() as usize % 4, 0);
assert_eq!(unaligned_input.as_ptr() as usize % 4, 1);
paint_test_input(aligned_input);
paint_test_input(unaligned_input);
// Try just below, equal to, and just above every whole number of chunks.
let mut input_2_lengths = Vec::new();
let mut next_len = 2 * CHUNK_LEN;
loop {
// 95 is one whole block plus one interesting part of another
input_2_lengths.push(next_len - 95);
input_2_lengths.push(next_len);
if next_len == test_impl.degree() * CHUNK_LEN {
break;
}
input_2_lengths.push(next_len + 95);
next_len += CHUNK_LEN;
}
for input_2_len in input_2_lengths {
dbg!(input_2_len);
let aligned_input1 = &aligned_input[..test_impl.degree() * CHUNK_LEN];
let aligned_input2 = &aligned_input[test_impl.degree() * CHUNK_LEN..][..input_2_len];
let unaligned_input1 = &unaligned_input[..test_impl.degree() * CHUNK_LEN];
let unaligned_input2 = &unaligned_input[test_impl.degree() * CHUNK_LEN..][..input_2_len];
for initial_counter in INITIAL_COUNTERS {
dbg!(initial_counter);
// Make two calls, to test the output_column parameter.
let mut portable_output = TransposedVectors::new();
let (portable_left, portable_right) =
test_impl.split_transposed_vectors(&mut portable_output);
portable::implementation().hash_chunks(
aligned_input1,
&IV_BYTES,
initial_counter,
0,
portable_left,
);
portable::implementation().hash_chunks(
aligned_input2,
&TEST_KEY,
initial_counter + test_impl.degree() as u64,
KEYED_HASH,
portable_right,
);
let mut test_output = TransposedVectors::new();
let (test_left, test_right) = test_impl.split_transposed_vectors(&mut test_output);
test_impl.hash_chunks(aligned_input1, &IV_BYTES, initial_counter, 0, test_left);
test_impl.hash_chunks(
aligned_input2,
&TEST_KEY,
initial_counter + test_impl.degree() as u64,
KEYED_HASH,
test_right,
);
check_transposed_eq(&portable_output, &test_output);
// Do the same thing with unaligned input.
let mut unaligned_test_output = TransposedVectors::new();
let (unaligned_left, unaligned_right) =
test_impl.split_transposed_vectors(&mut unaligned_test_output);
test_impl.hash_chunks(
unaligned_input1,
&IV_BYTES,
initial_counter,
0,
unaligned_left,
);
test_impl.hash_chunks(
unaligned_input2,
&TEST_KEY,
initial_counter + test_impl.degree() as u64,
KEYED_HASH,
unaligned_right,
);
check_transposed_eq(&portable_output, &unaligned_test_output);
}
}
}
fn painted_transposed_input() -> TransposedVectors {
let mut vectors = TransposedVectors::new();
let mut val = 0;
for col in 0..2 * MAX_SIMD_DEGREE {
for row in 0..8 {
vectors.0[row][col] = val;
val += 1;
}
}
vectors
}
pub fn test_hash_parents_vs_portable(test_impl: &Implementation) {
assert!(test_impl.degree() <= MAX_SIMD_DEGREE);
let input = painted_transposed_input();
for num_parents in 2..=(test_impl.degree() / 2) {
dbg!(num_parents);
let mut portable_output = TransposedVectors::new();
let (portable_left, portable_right) =
test_impl.split_transposed_vectors(&mut portable_output);
portable::implementation().hash_parents(
&input,
2 * num_parents, // num_cvs
&IV_BYTES,
0,
portable_left,
);
portable::implementation().hash_parents(
&input,
2 * num_parents, // num_cvs
&TEST_KEY,
KEYED_HASH,
portable_right,
);
let mut test_output = TransposedVectors::new();
let (test_left, test_right) = test_impl.split_transposed_vectors(&mut test_output);
test_impl.hash_parents(
&input,
2 * num_parents, // num_cvs
&IV_BYTES,
0,
test_left,
);
test_impl.hash_parents(
&input,
2 * num_parents, // num_cvs
&TEST_KEY,
KEYED_HASH,
test_right,
);
check_transposed_eq(&portable_output, &test_output);
}
}
fn hash_with_chunks_and_parents_recurse(
test_impl: &Implementation,
input: &[u8],
counter: u64,
output: TransposedSplit,
) -> usize {
assert!(input.len() > 0);
if input.len() <= test_impl.degree() * CHUNK_LEN {
return test_impl.hash_chunks(input, &IV_BYTES, counter, 0, output);
}
let (left_input, right_input) = input.split_at(left_len(input.len()));
let mut child_output = TransposedVectors::new();
let (left_output, right_output) = test_impl.split_transposed_vectors(&mut child_output);
let mut children =
hash_with_chunks_and_parents_recurse(test_impl, left_input, counter, left_output);
assert_eq!(children, test_impl.degree());
children += hash_with_chunks_and_parents_recurse(
test_impl,
right_input,
counter + (left_input.len() / CHUNK_LEN) as u64,
right_output,
);
test_impl.hash_parents(&child_output, children, &IV_BYTES, PARENT, output)
}
// Note: This test implementation doesn't support the 1-chunk-or-less case.
fn root_hash_with_chunks_and_parents(test_impl: &Implementation, input: &[u8]) -> CVBytes {
// TODO: handle the 1-chunk case?
assert!(input.len() > CHUNK_LEN);
let mut cvs = TransposedVectors::new();
// The right half of these vectors are never used.
let (cvs_left, _) = test_impl.split_transposed_vectors(&mut cvs);
let mut num_cvs = hash_with_chunks_and_parents_recurse(test_impl, input, 0, cvs_left);
while num_cvs > 2 {
num_cvs = test_impl.reduce_parents(&mut cvs, num_cvs, &IV_BYTES, 0);
}
test_impl.compress(
&cvs.extract_parent_node(0),
BLOCK_LEN as u32,
&IV_BYTES,
0,
PARENT | ROOT,
)
}
pub fn test_chunks_and_parents_vs_reference(test_impl: &Implementation) {
assert_eq!(test_impl.degree().count_ones(), 1, "power of 2");
const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * CHUNK_LEN;
let mut input_buf = [0u8; MAX_INPUT_LEN];
paint_test_input(&mut input_buf);
// Try just below, equal to, and just above every whole number of chunks, except that
// root_hash_with_chunks_and_parents doesn't support the 1-chunk-or-less case.
let mut test_lengths = vec![CHUNK_LEN + 1];
let mut next_len = 2 * CHUNK_LEN;
loop {
// 95 is one whole block plus one interesting part of another
test_lengths.push(next_len - 95);
test_lengths.push(next_len);
if next_len == MAX_INPUT_LEN {
break;
}
test_lengths.push(next_len + 95);
next_len += CHUNK_LEN;
}
for test_len in test_lengths {
dbg!(test_len);
let input = &input_buf[..test_len];
let mut ref_hasher = reference_impl::Hasher::new();
ref_hasher.update(&input);
let mut ref_hash = [0u8; 32];
ref_hasher.finalize(&mut ref_hash);
let test_hash = root_hash_with_chunks_and_parents(test_impl, input);
assert_eq!(ref_hash, test_hash);
}
}
pub fn test_xof_vs_portable(test_impl: &Implementation) {
let flags = CHUNK_START | CHUNK_END | KEYED_HASH;
for counter in INITIAL_COUNTERS {
dbg!(counter);
for input_len in [0, 1, BLOCK_LEN] {
dbg!(input_len);
let mut input_block = [0u8; BLOCK_LEN];
for byte_index in 0..input_len {
input_block[byte_index] = byte_index as u8 + 42;
}
// Try equal to and partway through every whole number of output blocks.
const MAX_OUTPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
let mut output_lengths = Vec::new();
let mut next_len = 0;
loop {
output_lengths.push(next_len);
if next_len == MAX_OUTPUT_LEN {
break;
}
output_lengths.push(next_len + 31);
next_len += BLOCK_LEN;
}
for output_len in output_lengths {
dbg!(output_len);
let mut portable_output = [0xff; MAX_OUTPUT_LEN];
portable::implementation().xof(
&input_block,
input_len as u32,
&TEST_KEY,
counter,
flags,
&mut portable_output[..output_len],
);
let mut test_output = [0xff; MAX_OUTPUT_LEN];
test_impl.xof(
&input_block,
input_len as u32,
&TEST_KEY,
counter,
flags,
&mut test_output[..output_len],
);
assert_eq!(portable_output, test_output);
// Double check that the implementation didn't write past the end of the output.
assert!(test_output[output_len..].iter().all(|&b| b == 0xff));
// The first XOR cancels out the output.
test_impl.xof_xor(
&input_block,
input_len as u32,
&TEST_KEY,
counter,
flags,
&mut test_output[..output_len],
);
assert!(test_output[..output_len].iter().all(|&b| b == 0));
assert!(test_output[output_len..].iter().all(|&b| b == 0xff));
// The second XOR restores the output.
test_impl.xof_xor(
&input_block,
input_len as u32,
&TEST_KEY,
counter,
flags,
&mut test_output[..output_len],
);
assert_eq!(portable_output, test_output);
assert!(test_output[output_len..].iter().all(|&b| b == 0xff));
}
}
}
}
pub fn test_xof_vs_reference(test_impl: &Implementation) {
let input = b"hello world";
let mut input_block = [0; BLOCK_LEN];
input_block[..input.len()].copy_from_slice(input);
const MAX_OUTPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
let mut ref_output = [0; MAX_OUTPUT_LEN];
let mut ref_hasher = reference_impl::Hasher::new_keyed(&TEST_KEY);
ref_hasher.update(input);
ref_hasher.finalize(&mut ref_output);
// Try equal to and partway through every whole number of output blocks.
let mut output_lengths = vec![0, 1, 31];
let mut next_len = BLOCK_LEN;
loop {
output_lengths.push(next_len);
if next_len == MAX_OUTPUT_LEN {
break;
}
output_lengths.push(next_len + 31);
next_len += BLOCK_LEN;
}
for output_len in output_lengths {
dbg!(output_len);
let mut test_output = [0; MAX_OUTPUT_LEN];
test_impl.xof(
&input_block,
input.len() as u32,
&TEST_KEY,
0,
KEYED_HASH | CHUNK_START | CHUNK_END,
&mut test_output[..output_len],
);
assert_eq!(ref_output[..output_len], test_output[..output_len]);
// Double check that the implementation didn't write past the end of the output.
assert!(test_output[output_len..].iter().all(|&b| b == 0));
// Do it again starting from block 1.
if output_len >= BLOCK_LEN {
test_impl.xof(
&input_block,
input.len() as u32,
&TEST_KEY,
1,
KEYED_HASH | CHUNK_START | CHUNK_END,
&mut test_output[..output_len - BLOCK_LEN],
);
assert_eq!(
ref_output[BLOCK_LEN..output_len],
test_output[..output_len - BLOCK_LEN],
);
}
}
}
pub fn test_universal_hash_vs_portable(test_impl: &Implementation) {
const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
let mut input_buf = [0; MAX_INPUT_LEN];
paint_test_input(&mut input_buf);
// Try equal to and partway through every whole number of input blocks.
let mut input_lengths = vec![0, 1, 31];
let mut next_len = BLOCK_LEN;
loop {
input_lengths.push(next_len);
if next_len == MAX_INPUT_LEN {
break;
}
input_lengths.push(next_len + 31);
next_len += BLOCK_LEN;
}
for input_len in input_lengths {
dbg!(input_len);
for counter in INITIAL_COUNTERS {
dbg!(counter);
let portable_output = portable::implementation().universal_hash(
&input_buf[..input_len],
&TEST_KEY,
counter,
);
let test_output = test_impl.universal_hash(&input_buf[..input_len], &TEST_KEY, counter);
assert_eq!(portable_output, test_output);
}
}
}
fn reference_impl_universal_hash(input: &[u8], key: &CVBytes) -> [u8; UNIVERSAL_HASH_LEN] {
// The reference_impl doesn't support XOF seeking, so we have to materialize an entire extended
// output to seek to a block.
const MAX_BLOCKS: usize = 2 * MAX_SIMD_DEGREE;
assert!(input.len() / BLOCK_LEN <= MAX_BLOCKS);
let mut output_buffer: [u8; BLOCK_LEN * MAX_BLOCKS] = [0u8; BLOCK_LEN * MAX_BLOCKS];
let mut result = [0u8; UNIVERSAL_HASH_LEN];
let mut block_start = 0;
while block_start < input.len() {
let block_len = cmp::min(input.len() - block_start, BLOCK_LEN);
let mut ref_hasher = reference_impl::Hasher::new_keyed(key);
ref_hasher.update(&input[block_start..block_start + block_len]);
ref_hasher.finalize(&mut output_buffer[..block_start + UNIVERSAL_HASH_LEN]);
for byte_index in 0..UNIVERSAL_HASH_LEN {
result[byte_index] ^= output_buffer[block_start + byte_index];
}
block_start += BLOCK_LEN;
}
result
}
pub fn test_universal_hash_vs_reference(test_impl: &Implementation) {
const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
let mut input_buf = [0; MAX_INPUT_LEN];
paint_test_input(&mut input_buf);
// Try equal to and partway through every whole number of input blocks.
let mut input_lengths = vec![0, 1, 31];
let mut next_len = BLOCK_LEN;
loop {
input_lengths.push(next_len);
if next_len == MAX_INPUT_LEN {
break;
}
input_lengths.push(next_len + 31);
next_len += BLOCK_LEN;
}
for input_len in input_lengths {
dbg!(input_len);
let ref_output = reference_impl_universal_hash(&input_buf[..input_len], &TEST_KEY);
let test_output = test_impl.universal_hash(&input_buf[..input_len], &TEST_KEY, 0);
assert_eq!(ref_output, test_output);
}
}