1
0
Fork 0
mirror of https://github.com/BLAKE3-team/BLAKE3 synced 2024-05-27 00:16:03 +02:00

AVX-512 universal_hash

This commit is contained in:
Jack O'Connor 2023-07-18 21:16:54 -07:00
parent e86e18cb7a
commit e56c6a814f
4 changed files with 242 additions and 6 deletions

View File

@ -2,7 +2,8 @@
extern crate test;
use blake3_guts::BLOCK_LEN;
use blake3_guts as guts;
use guts::BLOCK_LEN;
use rand::prelude::*;
use test::Bencher;
@ -375,3 +376,34 @@ fn bench_xof_0512(b: &mut Bencher) {
fn bench_xof_1024(b: &mut Bencher) {
bench_xof(b, 1024);
}
/// Shared driver for the `universal_hash` benchmarks.
///
/// Each iteration hashes `len` bytes obtained from a `RandomInput` using the detected
/// implementation, a fixed all-99 key and a starting counter of 0.
fn bench_universal_hash(b: &mut Bencher, len: usize) {
    let key = [99; 32];
    let mut random_input = RandomInput::new(b, len);
    b.iter(|| {
        let bytes = random_input.get();
        guts::DETECTED_IMPL.universal_hash(bytes, &key, 0)
    });
}
// Benchmarks `universal_hash` on a 64-byte input.
#[bench]
fn bench_universal_hash_0064(b: &mut Bencher) {
bench_universal_hash(b, 64);
}
// Benchmarks `universal_hash` on a 128-byte input.
#[bench]
fn bench_universal_hash_0128(b: &mut Bencher) {
bench_universal_hash(b, 128);
}
// Benchmarks `universal_hash` on a 256-byte input.
#[bench]
fn bench_universal_hash_0256(b: &mut Bencher) {
bench_universal_hash(b, 256);
}
// Benchmarks `universal_hash` on a 512-byte input.
#[bench]
fn bench_universal_hash_0512(b: &mut Bencher) {
bench_universal_hash(b, 512);
}
// Benchmarks `universal_hash` on a 1024-byte input.
#[bench]
fn bench_universal_hash_1024(b: &mut Bencher) {
bench_universal_hash(b, 1024);
}

View File

@ -31,6 +31,8 @@
.global _blake3_guts_avx512_xof_16_exact
.global blake3_guts_avx512_xof_xor_16_exact
.global _blake3_guts_avx512_xof_xor_16_exact
.global blake3_guts_avx512_universal_hash_16_exact
.global _blake3_guts_avx512_universal_hash_16_exact
#ifdef __APPLE__
.text
@ -4056,6 +4058,183 @@ blake3_guts_avx512_xof_xor_16_exact:
vmovdqu32 ZMMWORD PTR [r9+0x3c0],zmm15
ret
// Universal hash of exactly 16 input blocks of 64 bytes each (0x10 * 0x40 bytes are read
// from rdi). The 16-byte result is written to [r8].
// rdi: input pointer
// rsi: input length [unused]
// rdx: key
// rcx: counter
// r8: out pointer
.p2align 6
_blake3_guts_avx512_universal_hash_16_exact:
blake3_guts_avx512_universal_hash_16_exact:
// load the message words
// First 32 bytes of every block: block i fills the low 256 bits of a register and
// block i+8 is inserted into the high 256 bits.
vmovdqu32 ymm16, ymmword ptr [rdi+0x0*0x40]
vinserti64x4 zmm16, zmm16, ymmword ptr [rdi+0x8*0x40], 0x01
vmovdqu32 ymm17, ymmword ptr [rdi+0x1*0x40]
vinserti64x4 zmm17, zmm17, ymmword ptr [rdi+0x9*0x40], 0x01
vpunpcklqdq zmm8, zmm16, zmm17
vpunpckhqdq zmm9, zmm16, zmm17
vmovdqu32 ymm18, ymmword ptr [rdi+0x2*0x40]
vinserti64x4 zmm18, zmm18, ymmword ptr [rdi+0xa*0x40], 0x01
vmovdqu32 ymm19, ymmword ptr [rdi+0x3*0x40]
vinserti64x4 zmm19, zmm19, ymmword ptr [rdi+0xb*0x40], 0x01
vpunpcklqdq zmm10, zmm18, zmm19
vpunpckhqdq zmm11, zmm18, zmm19
vmovdqu32 ymm16, ymmword ptr [rdi+0x4*0x40]
vinserti64x4 zmm16, zmm16, ymmword ptr [rdi+0xc*0x40], 0x01
vmovdqu32 ymm17, ymmword ptr [rdi+0x5*0x40]
vinserti64x4 zmm17, zmm17, ymmword ptr [rdi+0xd*0x40], 0x01
vpunpcklqdq zmm12, zmm16, zmm17
vpunpckhqdq zmm13, zmm16, zmm17
vmovdqu32 ymm18, ymmword ptr [rdi+0x6*0x40]
vinserti64x4 zmm18, zmm18, ymmword ptr [rdi+0xe*0x40], 0x01
vmovdqu32 ymm19, ymmword ptr [rdi+0x7*0x40]
vinserti64x4 zmm19, zmm19, ymmword ptr [rdi+0xf*0x40], 0x01
vpunpcklqdq zmm14, zmm18, zmm19
vpunpckhqdq zmm15, zmm18, zmm19
// zmm27 and zmm31 hold the permutation index tables INDEX0 and INDEX1 (defined elsewhere).
// The shuffle/permute sequence below leaves its results in zmm16-zmm23.
// NOTE(review): this looks like a 16x16 dword transpose, so that each register ends up
// holding one message word for all 16 blocks -- confirm against INDEX0/INDEX1.
vmovdqa32 zmm27, zmmword ptr [INDEX0+rip]
vmovdqa32 zmm31, zmmword ptr [INDEX1+rip]
vshufps zmm16, zmm8, zmm10, 136
vshufps zmm17, zmm12, zmm14, 136
vmovdqa32 zmm20, zmm16
vpermt2d zmm16, zmm27, zmm17
vpermt2d zmm20, zmm31, zmm17
vshufps zmm17, zmm8, zmm10, 221
vshufps zmm30, zmm12, zmm14, 221
vmovdqa32 zmm21, zmm17
vpermt2d zmm17, zmm27, zmm30
vpermt2d zmm21, zmm31, zmm30
vshufps zmm18, zmm9, zmm11, 136
vshufps zmm8, zmm13, zmm15, 136
vmovdqa32 zmm22, zmm18
vpermt2d zmm18, zmm27, zmm8
vpermt2d zmm22, zmm31, zmm8
vshufps zmm19, zmm9, zmm11, 221
vshufps zmm8, zmm13, zmm15, 221
vmovdqa32 zmm23, zmm19
vpermt2d zmm19, zmm27, zmm8
vpermt2d zmm23, zmm31, zmm8
// Second 32 bytes of every block (offset +0x20), loaded and rearranged the same way.
// The results land in zmm24-zmm31.
vmovdqu32 ymm24, ymmword ptr [rdi+0x0*0x40+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0x8*0x40+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x1*0x40+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0x9*0x40+0x20], 0x01
vpunpcklqdq zmm8, zmm24, zmm25
vpunpckhqdq zmm9, zmm24, zmm25
vmovdqu32 ymm24, ymmword ptr [rdi+0x2*0x40+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0xa*0x40+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x3*0x40+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0xb*0x40+0x20], 0x01
vpunpcklqdq zmm10, zmm24, zmm25
vpunpckhqdq zmm11, zmm24, zmm25
vmovdqu32 ymm24, ymmword ptr [rdi+0x4*0x40+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0xc*0x40+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x5*0x40+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0xd*0x40+0x20], 0x01
vpunpcklqdq zmm12, zmm24, zmm25
vpunpckhqdq zmm13, zmm24, zmm25
vmovdqu32 ymm24, ymmword ptr [rdi+0x6*0x40+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0xe*0x40+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x7*0x40+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0xf*0x40+0x20], 0x01
vpunpcklqdq zmm14, zmm24, zmm25
vpunpckhqdq zmm15, zmm24, zmm25
vshufps zmm24, zmm8, zmm10, 136
vshufps zmm30, zmm12, zmm14, 136
vmovdqa32 zmm28, zmm24
vpermt2d zmm24, zmm27, zmm30
vpermt2d zmm28, zmm31, zmm30
vshufps zmm25, zmm8, zmm10, 221
vshufps zmm30, zmm12, zmm14, 221
vmovdqa32 zmm29, zmm25
vpermt2d zmm25, zmm27, zmm30
vpermt2d zmm29, zmm31, zmm30
vshufps zmm26, zmm9, zmm11, 136
vshufps zmm8, zmm13, zmm15, 136
vmovdqa32 zmm30, zmm26
vpermt2d zmm26, zmm27, zmm8
vpermt2d zmm30, zmm31, zmm8
vshufps zmm8, zmm9, zmm11, 221
vshufps zmm10, zmm13, zmm15, 221
// vpermi2d overwrites its index operand, so the last two results replace the index
// tables in zmm27 and zmm31, which are not needed after this point.
vpermi2d zmm27, zmm8, zmm10
vpermi2d zmm31, zmm8, zmm10
// broadcast the key
// Each of the eight 32-bit key words is replicated across all lanes of zmm0-zmm7.
vpbroadcastd zmm0,DWORD PTR [rdx]
vpbroadcastd zmm1,DWORD PTR [rdx+0x4]
vpbroadcastd zmm2,DWORD PTR [rdx+0x8]
vpbroadcastd zmm3,DWORD PTR [rdx+0xc]
vpbroadcastd zmm4,DWORD PTR [rdx+0x10]
vpbroadcastd zmm5,DWORD PTR [rdx+0x14]
vpbroadcastd zmm6,DWORD PTR [rdx+0x18]
vpbroadcastd zmm7,DWORD PTR [rdx+0x1c]
// increment and broadcast the counter
// zmm14 = low 32 bits of the counter, zmm13 = high 32 bits (rcx >> 32).
// zmm12 = low word plus the per-lane offsets in ADD0 (defined elsewhere).
// k1 selects the lanes whose low word wrapped around (sum < original); only those lanes
// get ADD1 (defined elsewhere, presumably 1 -- confirm) added to the high word as a carry.
vpbroadcastd zmm14,ecx
mov rax, rcx
shr rax,0x20
vpbroadcastd zmm13,eax
vpaddd zmm12,zmm14,ZMMWORD PTR [ADD0+rip]
vpcmpltud k1,zmm12,zmm14
vpaddd zmm13{k1},zmm13,DWORD PTR [ADD1+rip]{1to16}
// broadcast the block length
// zmm14 is reused here; the low counter words now live in zmm12.
mov eax, 64
vpbroadcastd zmm14, eax
// broadcast the flags (always CHUNK_START|CHUNK_END|ROOT|KEYED_HASH = 27)
mov eax, 0b11011
vpbroadcastd zmm15, eax
// execute the kernel
call blake3_guts_avx512_kernel_16
// fold the first four words of the state (the rest are unused)
vpxord zmm0, zmm0, zmm8
vpxord zmm1, zmm1, zmm9
vpxord zmm2, zmm2, zmm10
vpxord zmm3, zmm3, zmm11
// xor-reduce zmm0-3
// Each 512-bit register is folded in halves: 512 -> 256 -> 128 -> 64 bits.
// zmm0 reduces into xmm0.
vmovdqa ymm4, ymm0
vextracti64x4 ymm0, zmm0, 0x1
vpxor ymm4, ymm4, ymm0
vmovdqa xmm0, xmm4
vextracti128 xmm4, ymm4, 0x1
vpxor xmm0, xmm0, xmm4
vpsrldq xmm4, xmm0, 8
vpxor xmm0, xmm0, xmm4
// zmm1 reduces into xmm4.
vmovdqa ymm4, ymm1
vextracti64x4 ymm1, zmm1, 0x1
vpxor ymm1, ymm4, ymm1
vmovdqa xmm4, xmm1
vextracti128 xmm1, ymm1, 0x1
vpxor xmm4, xmm4, xmm1
vpsrldq xmm1, xmm4, 8
vpxor xmm4, xmm4, xmm1
// zmm2 reduces into xmm1.
vmovdqa ymm1, ymm2
vextracti64x4 ymm2, zmm2, 0x1
vpxor ymm2, ymm1, ymm2
vmovdqa xmm1, xmm2
vextracti128 xmm2, ymm2, 0x1
vpxor xmm1, xmm1, xmm2
vpsrldq xmm2, xmm1, 8
vpxor xmm1, xmm1, xmm2
// zmm3 reduces into xmm2.
vextracti64x4 ymm2, zmm3, 0x1
vpxor ymm3, ymm2, ymm3
vextracti128 xmm2, ymm3, 0x1
vpxor xmm2, xmm2, xmm3
vpsrldq xmm3, xmm2, 8
vpxor xmm2, xmm2, xmm3
// Final fold from 64 to 32 bits: xor each value with itself shifted right by 4 bytes, so
// the low dword of xmm0 / xmm4 / xmm1 / xmm2 holds the fully reduced word 0 / 1 / 2 / 3.
vpsrldq xmm3, xmm0, 4
vpxor xmm0, xmm0, xmm3
vpsrldq xmm3, xmm4, 4
vpxor xmm4, xmm4, xmm3
vpsrldq xmm3, xmm1, 4
vpxor xmm1, xmm1, xmm3
vpsrldq xmm3, xmm2, 4
// Pack the four dwords: word 1 goes into dword 1 of xmm0, word 3 into dword 1 of xmm1,
// then the low qwords of xmm0 and xmm1 are joined and the 16 bytes are stored to [r8].
vmovd edx, xmm4
vpxor xmm2, xmm2, xmm3
vpinsrd xmm0, xmm0, edx, 1
vmovd eax, xmm2
vpinsrd xmm1, xmm1, eax, 1
vpunpcklqdq xmm0, xmm0, xmm1
vmovdqu XMMWORD PTR [r8], xmm0
ret
#ifdef __APPLE__
.static_data
#else

View File

@ -50,6 +50,13 @@ extern "C" {
flags: u32,
out: *mut u8,
);
/// AVX-512 assembly routine that writes a 16-byte universal hash of `input` to `out`.
///
/// The Rust wrapper only calls this when `input_len == 16 * BLOCK_LEN`. The assembly's
/// register comments mark the length argument as unused, so shorter inputs must not be
/// passed here.
fn blake3_guts_avx512_universal_hash_16_exact(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
out: *mut [u8; 16],
);
}
unsafe extern "C" fn hash_chunks(
@ -172,6 +179,11 @@ unsafe extern "C" fn universal_hash(
counter: u64,
out: *mut [u8; 16],
) {
debug_assert!(input_len <= 16 * BLOCK_LEN);
if input_len == 16 * BLOCK_LEN {
blake3_guts_avx512_universal_hash_16_exact(input, input_len, key, counter, out);
return;
}
crate::universal_hash_using_compress(
blake3_guts_avx512_compress,
input,

View File

@ -142,6 +142,8 @@ impl Implementation {
/// Returns the SIMD degree reported by this implementation's degree function.
///
/// Debug builds also check that the value is at least 2, does not exceed
/// `MAX_SIMD_DEGREE`, and has exactly one bit set (i.e. is a power of two).
pub fn degree(&self) -> usize {
    let query_degree = self.degree_fn();
    let degree = query_degree();
    debug_assert!(degree >= 2);
    debug_assert!(degree <= MAX_SIMD_DEGREE);
    debug_assert_eq!(1, degree.count_ones(), "power of 2");
    degree
}
@ -333,12 +335,23 @@ impl Implementation {
}
#[inline]
pub fn universal_hash(&self, input: &[u8], key: &CVBytes, counter: u64) -> [u8; 16] {
let mut out = [0u8; 16];
unsafe {
self.universal_hash_fn()(input.as_ptr(), input.len(), key, counter, &mut out);
/// Computes a 16-byte universal hash of `input` under `key`.
///
/// The input is consumed in pieces of at most `degree() * BLOCK_LEN` bytes. Each piece is
/// passed to this implementation's universal-hash function together with the current
/// `counter`, and the 16-byte results of all pieces are XORed together. `counter` is used
/// as given for the first piece and advanced by `degree` after every piece. An empty
/// `input` makes no calls and returns all zeros.
pub fn universal_hash(&self, mut input: &[u8], key: &CVBytes, mut counter: u64) -> [u8; 16] {
let degree = self.degree();
// Largest number of input bytes handed to the underlying function in a single call.
let simd_len = degree * BLOCK_LEN;
// Running XOR of the per-piece outputs.
let mut ret = [0u8; 16];
while !input.is_empty() {
// The final piece may be shorter than `simd_len`.
let take = cmp::min(simd_len, input.len());
let mut output = [0u8; 16];
unsafe {
self.universal_hash_fn()(input.as_ptr(), take, key, counter, &mut output);
}
input = &input[take..];
// NOTE(review): plain `+=` panics on u64 overflow when overflow checks are enabled
// (e.g. debug builds), and it runs even after the final piece -- confirm callers never
// pass a counter within `degree` of u64::MAX.
counter += degree as u64;
for byte_index in 0..16 {
ret[byte_index] ^= output[byte_index];
}
}
out
ret
}
}