mirror of
https://github.com/oconnor663/bao
synced 2025-02-21 23:21:05 +01:00
397 lines
13 KiB
Python
397 lines
13 KiB
Python
#! /usr/bin/env python3
|
|
|
|
# Run this file using pytest, either in this folder or at the root of the
|
|
# project. Since test_vectors.json is generated from bao.py, it's slightly
|
|
# cheating to then test bao.py against its own output. But at least this helps
|
|
# us notice changes, since the vectors are checked in rather than generated
|
|
# every time. Testing the Rust implementation against the same test vectors
|
|
# gives us some confidence that they're correct.
|
|
|
|
from binascii import hexlify, unhexlify
import io
import json
from pathlib import Path
import subprocess
import sys
import tempfile
|
|
|
|
# Imports from this directory.
|
|
import bao
|
|
import generate_input
|
|
|
|
HERE = Path(__file__).parent
BAO_PATH = HERE / "bao.py"
VECTORS_PATH = HERE / "test_vectors.json"
# Load the checked-in test vectors once at import time. read_text() opens and
# closes the file itself, so no handle is left dangling for the life of the
# test process. JSON is UTF-8 by specification, so don't depend on the locale.
VECTORS = json.loads(VECTORS_PATH.read_text(encoding="utf-8"))
|
|
|
|
# Wrapper functions
|
|
# =================
|
|
#
|
|
# Most of the functions in bao.py (except bao_encode) work with streams. These
|
|
# wrappers work with bytes, and return hashes as strings, which makes them
|
|
# easier to test.
|
|
|
|
|
|
def bao_hash(content):
    """Return the Bao hash of the bytes `content` as a lowercase hex string.

    bao.bao_hash reads from a stream, so the bytes are wrapped in a BytesIO.
    """
    stream = io.BytesIO(content)
    digest = bao.bao_hash(stream)
    return hexlify(digest).decode("utf-8")
|
|
|
|
|
|
def blake3(b):
    """Return the hex hash of the bytes `b`, used as a checksum of outputs.

    This simply delegates to bao_hash; presumably the Bao root hash coincides
    with plain BLAKE3 (confirm against bao.py).
    """
    result = bao_hash(b)
    return result
|
|
|
|
|
|
def bao_encode(content):
    """Produce a combined (non-outboard) encoding of `content`.

    Returns a tuple of (encoded bytes, root hash as a hex string).
    Note that unlike the other functions, this one already takes bytes.
    """
    encoded, root_hash = bao.bao_encode(content, outboard=False)
    hex_hash = root_hash.hex()
    return encoded, hex_hash
|
|
|
|
|
|
def bao_encode_outboard(content):
    """Produce an outboard encoding (the tree without the content) of `content`.

    Returns a tuple of (outboard bytes, root hash as a hex string).
    Note that unlike the other functions, this one already takes bytes.
    """
    tree, root_hash = bao.bao_encode(content, outboard=True)
    hex_hash = root_hash.hex()
    return tree, hex_hash
|
|
|
|
|
|
def bao_decode(hash, encoded):
    """Decode a combined encoding and return the recovered content bytes.

    `hash` is the expected root hash as a hex string. Any exception raised by
    bao.bao_decode (e.g. on a verification failure) propagates to the caller.
    """
    expected = unhexlify(hash)
    sink = io.BytesIO()
    source = io.BytesIO(encoded)
    bao.bao_decode(source, sink, expected)
    return sink.getvalue()
|
|
|
|
|
|
def bao_decode_outboard(hash, content, outboard):
    """Verify `content` against an outboard tree and return the decoded bytes.

    `hash` is the expected root hash as a hex string; `content` and
    `outboard` are bytes-like. Exceptions from bao.bao_decode propagate.
    """
    expected = unhexlify(hash)
    sink = io.BytesIO()
    content_stream = io.BytesIO(content)
    tree_stream = io.BytesIO(outboard)
    bao.bao_decode(content_stream, sink, expected, outboard_stream=tree_stream)
    return sink.getvalue()
|
|
|
|
|
|
def bao_slice(encoded, slice_start, slice_len):
    """Extract the slice [slice_start, slice_start + slice_len) from a
    combined encoding and return the slice bytes produced by bao.bao_slice.
    """
    sink = io.BytesIO()
    source = io.BytesIO(encoded)
    bao.bao_slice(source, sink, slice_start, slice_len)
    return sink.getvalue()
|
|
|
|
|
|
def bao_slice_outboard(content, outboard, slice_start, slice_len):
    """Extract a slice using raw `content` plus its `outboard` tree.

    Returns the slice bytes produced by bao.bao_slice; the tests expect these
    to equal the slice taken from the corresponding combined encoding.
    """
    sink = io.BytesIO()
    content_stream = io.BytesIO(content)
    tree_stream = io.BytesIO(outboard)
    bao.bao_slice(
        content_stream, sink, slice_start, slice_len, outboard_stream=tree_stream
    )
    return sink.getvalue()
|
|
|
|
|
|
def bao_decode_slice(slice_bytes, hash, slice_start, slice_len):
    """Verify and decode a slice, returning the content bytes it covers.

    `hash` is the expected root hash as a hex string. Exceptions from
    bao.bao_decode_slice propagate to the caller.
    """
    expected = unhexlify(hash)
    sink = io.BytesIO()
    source = io.BytesIO(slice_bytes)
    bao.bao_decode_slice(source, sink, expected, slice_start, slice_len)
    return sink.getvalue()
|
|
|
|
|
|
# Tests
|
|
# =====
|
|
|
|
|
|
def test_hashes():
    """Check bao_hash against every case in the "hash" test vectors."""
    for vector in VECTORS["hash"]:
        data = generate_input.input_bytes(vector["input_len"])
        assert vector["bao_hash"] == bao_hash(data)
|
|
|
|
|
|
def bao_cli(*args, input=None, should_fail=False):
    """Run bao.py as a subprocess and return its captured stdout bytes.

    `args` are passed through as command-line arguments and `input` (bytes or
    None) is fed to the child's stdin. When `should_fail` is true, the child's
    stderr is discarded and a nonzero exit status is asserted. Otherwise a
    zero exit status is asserted and stderr is inherited from the test run.
    """
    # Launch bao.py with the interpreter that is running the tests, rather
    # than whatever "python3" resolves to on PATH. That keeps virtualenvs and
    # alternate interpreters consistent, and doesn't require a "python3"
    # executable to exist at all (it often doesn't on Windows).
    output = subprocess.run(
        [sys.executable, str(BAO_PATH), *args],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL if should_fail else None,
        input=input,
    )
    cmd = " ".join(["bao.py"] + list(args))
    if should_fail:
        assert output.returncode != 0, "`{}` should've failed".format(cmd)
    else:
        assert output.returncode == 0, "`{}` failed".format(cmd)
    return output.stdout
|
|
|
|
|
|
def test_hash_cli():
    """Check that `bao.py hash` reads stdin and prints the expected hash."""
    # CLI tests just use the final (largest) test vector in each set, to avoid
    # shelling out hundreds of times. There's no need to exhaustively test the
    # implementation via the CLI, because it's tested on its own above.
    # Instead, we just need to verify once that it's hooked up properly.
    last = VECTORS["hash"][-1]
    data = generate_input.input_bytes(last["input_len"])
    expected = last["bao_hash"]
    stdout = bao_cli("hash", input=data)
    assert expected == stdout.decode().strip()
|
|
|
|
|
|
def assert_decode_failure(f, *args):
    """Assert that calling `f(*args)` raises AssertionError or IOError.

    Those two exception types count as the expected decode failure. Any other
    exception propagates unchanged. If `f` returns normally, an
    AssertionError is raised instead.
    """
    try:
        f(*args)
    except (AssertionError, IOError):
        return
    raise AssertionError("failure expected, but no exception raised")
|
|
|
|
|
|
def test_encoded():
    """Check combined encoding and decoding against every "encode" vector.

    For each case this verifies the root hash, encoded length, and BLAKE3
    checksum of the encoding, then a successful round trip. Decoding must fail
    with a wrong hash and at each listed corruption offset.
    """
    for vector in VECTORS["encode"]:
        content = generate_input.input_bytes(vector["input_len"])
        expected_len = vector["output_len"]
        expected_hash = vector["bao_hash"]
        expected_blake3 = vector["encoded_blake3"]
        corruption_points = vector["corruptions"]

        # The encoding must match the recorded hash, length, and checksum.
        encoded, root_hash = bao_encode(content)
        assert expected_hash == root_hash
        assert expected_len == len(encoded)
        assert expected_blake3 == blake3(encoded)

        # Decoding with the right hash recovers the original content.
        assert content == bao_decode(root_hash, encoded)

        # An all-zero hash of the same length must be rejected.
        zero_hash = "0" * len(root_hash)
        assert_decode_failure(bao_decode, zero_hash, encoded)

        # Flipping the low bit at any listed offset must be detected.
        for offset in corruption_points:
            damaged = bytearray(encoded)
            damaged[offset] ^= 1
            assert_decode_failure(bao_decode, root_hash, damaged)
|
|
|
|
|
|
def make_tempfile(b=b""):
    """Return a NamedTemporaryFile holding `b`, flushed and rewound to 0.

    The data is flushed so that a subprocess opening the file by name sees it.
    The caller owns the returned handle; the file is deleted when it's closed.
    """
    handle = tempfile.NamedTemporaryFile()
    handle.write(b)
    # Push the bytes to the OS before anything else opens the path.
    handle.flush()
    handle.seek(0)
    return handle
|
|
|
|
|
|
def test_encoded_cli():
    """Smoke-test the `encode` and `decode` CLI subcommands.

    Uses the last (largest) "encode" vector. The encoding written by the CLI
    is checked against the recorded length and BLAKE3 checksum. It is then
    decoded with the correct hash (must reproduce the input) and with a wrong
    hash (must fail).
    """
    case = VECTORS["encode"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    output_len = case["output_len"]
    expected_bao_hash = case["bao_hash"]
    encoded_blake3 = case["encoded_blake3"]

    # Close (and thereby delete) the temporary files as soon as the test is
    # done with them, including when an assertion fails part way through and
    # pytest keeps the failing frame alive for its report.
    input_file = make_tempfile(input_bytes)
    encoded_file = make_tempfile()
    with input_file, encoded_file:
        # First make sure the encoded output is what it's supposed to be.
        bao_cli("encode", input_file.name, encoded_file.name)
        encoded = encoded_file.read()
        assert output_len == len(encoded)
        assert encoded_blake3 == blake3(encoded)

        # Now test decoding.
        output = bao_cli("decode", expected_bao_hash, encoded_file.name)
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails.
        wrong_hash = "0" * len(expected_bao_hash)
        bao_cli("decode", wrong_hash, encoded_file.name, should_fail=True)
|
|
|
|
|
|
def test_outboard():
    """Check outboard encoding and decoding against every "outboard" vector.

    For each case this verifies the outboard tree's root hash, length, and
    BLAKE3 checksum, then a successful round trip. Decoding must fail for a
    wrong hash, for each listed outboard corruption, and for each listed input
    corruption.
    """
    for vector in VECTORS["outboard"]:
        content = generate_input.input_bytes(vector["input_len"])
        expected_len = vector["output_len"]
        expected_hash = vector["bao_hash"]
        expected_blake3 = vector["encoded_blake3"]
        tree_flips = vector["outboard_corruptions"]
        content_flips = vector["input_corruptions"]

        # The outboard tree must match the recorded hash, length, and checksum.
        tree, root_hash = bao_encode_outboard(content)
        assert expected_hash == root_hash
        assert expected_len == len(tree)
        assert expected_blake3 == blake3(tree)

        # Decoding with the right hash recovers the original content.
        assert content == bao_decode_outboard(root_hash, content, tree)

        # An all-zero hash of the same length must be rejected.
        zero_hash = "0" * len(root_hash)
        assert_decode_failure(bao_decode_outboard, zero_hash, content, tree)

        # Flipping the low bit at any listed offset of the tree must be caught.
        for offset in tree_flips:
            damaged_tree = bytearray(tree)
            damaged_tree[offset] ^= 1
            assert_decode_failure(
                bao_decode_outboard, root_hash, content, damaged_tree
            )

        # The same goes for each listed offset of the content itself.
        for offset in content_flips:
            damaged_content = bytearray(content)
            damaged_content[offset] ^= 1
            assert_decode_failure(
                bao_decode_outboard, root_hash, damaged_content, tree
            )
|
|
|
|
|
|
def test_outboard_cli():
    """Smoke-test `encode --outboard` and `decode --outboard` via the CLI.

    Uses the last (largest) "outboard" vector. The outboard tree written by
    the CLI is checked against the recorded length and BLAKE3 checksum. It is
    then decoded with the correct hash (must reproduce the input) and with a
    wrong hash (must fail).
    """
    case = VECTORS["outboard"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    output_len = case["output_len"]
    expected_bao_hash = case["bao_hash"]
    encoded_blake3 = case["encoded_blake3"]

    # Close (and thereby delete) the temporary files as soon as the test is
    # done with them, including when an assertion fails part way through.
    input_file = make_tempfile(input_bytes)
    outboard_file = make_tempfile()
    with input_file, outboard_file:
        # First make sure the encoded output is what it's supposed to be.
        bao_cli("encode", input_file.name, "--outboard", outboard_file.name)
        outboard = outboard_file.read()
        assert output_len == len(outboard)
        assert encoded_blake3 == blake3(outboard)

        # Now test decoding.
        output = bao_cli(
            "decode",
            expected_bao_hash,
            input_file.name,
            "--outboard",
            outboard_file.name,
        )
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails. Only the exit status
        # matters here (bao_cli asserts it), so stdout is not kept.
        wrong_hash = "0" * len(expected_bao_hash)
        bao_cli(
            "decode",
            wrong_hash,
            input_file.name,
            "--outboard",
            outboard_file.name,
            should_fail=True,
        )
|
|
|
|
|
|
def test_slices():
    """Check slicing and slice decoding against every "slice" vector.

    Both the combined and outboard encodings must carry the recorded root
    hash. For every slice case, the extracted slice must match the recorded
    length and checksum and must be identical whether taken from the combined
    or the outboard encoding. It must decode back to the matching input range,
    and it must be rejected with a wrong hash or at any listed corruption.
    """
    for vector in VECTORS["slice"]:
        content = generate_input.input_bytes(vector["input_len"])
        expected_hash = vector["bao_hash"]
        slice_cases = vector["slices"]

        encoded, combined_hash = bao_encode(content)
        tree, outboard_hash = bao_encode_outboard(content)
        assert expected_hash == combined_hash
        assert expected_hash == outboard_hash

        for sc in slice_cases:
            start = sc["start"]
            length = sc["len"]
            expected_len = sc["output_len"]
            expected_blake3 = sc["output_blake3"]
            flips = sc["corruptions"]

            # The slice must match the recorded length and checksum.
            sliced = bao_slice(encoded, start, length)
            assert expected_len == len(sliced)
            assert expected_blake3 == blake3(sliced)

            # Slicing via the outboard tree must give byte-identical output.
            assert sliced == bao_slice_outboard(content, tree, start, length)

            # Python slicing silently clamps indices past the end of the
            # bytes, which mirrors how a slice past EOF is expected to decode.
            expected_content = content[start:][:length]
            assert expected_content == bao_decode_slice(
                sliced, combined_hash, start, length
            )

            # An all-zero hash of the same length must be rejected.
            zero_hash = "0" * len(combined_hash)
            assert_decode_failure(bao_decode_slice, sliced, zero_hash, start, length)

            # Flipping the low bit at any listed offset must be detected.
            for offset in flips:
                damaged = bytearray(sliced)
                damaged[offset] ^= 1
                assert_decode_failure(
                    bao_decode_slice, damaged, combined_hash, start, length
                )
|
|
|
|
|
|
def test_slices_cli():
    """Smoke-test the `slice` and `decode-slice` CLI subcommands.

    Uses the last (largest) "slice" vector and only its first slice case. The
    slice is taken from both a combined and an outboard encoding produced by
    the CLI and checked against the vector. It is then decoded from stdin
    with the correct hash (must reproduce the input range) and with a wrong
    hash (must fail).
    """
    case = VECTORS["slice"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    expected_bao_hash = case["bao_hash"]
    slices = case["slices"]

    # Use the first slice in the list. Currently they're all the same length.
    slice_case = slices[0]
    slice_start = slice_case["start"]
    slice_len = slice_case["len"]
    output_len = slice_case["output_len"]
    output_blake3 = slice_case["output_blake3"]

    # The temporary files are only needed for encoding and slicing. The `with`
    # statement closes (and so deletes) them promptly, including when an
    # assertion fails part way through.
    input_file = make_tempfile(input_bytes)
    encoded_file = make_tempfile()
    outboard_file = make_tempfile()
    with input_file, encoded_file, outboard_file:
        bao_cli("encode", input_file.name, encoded_file.name)
        bao_cli("encode", input_file.name, "--outboard", outboard_file.name)

        # Make sure the slice output is what it should be.
        slice_bytes = bao_cli(
            "slice", str(slice_start), str(slice_len), encoded_file.name
        )
        assert output_len == len(slice_bytes)
        assert output_blake3 == blake3(slice_bytes)

        # Make sure slicing an outboard tree is the same.
        outboard_slice_bytes = bao_cli(
            "slice",
            str(slice_start),
            str(slice_len),
            input_file.name,
            "--outboard",
            outboard_file.name,
        )
        assert slice_bytes == outboard_slice_bytes

    # Test decoding the slice (fed on stdin), and compare it to the input.
    # Note that slicing a byte array in Python allows indices past the end of
    # the array, and sort of silently caps them.
    input_slice = input_bytes[slice_start:][:slice_len]
    output = bao_cli(
        "decode-slice",
        expected_bao_hash,
        str(slice_start),
        str(slice_len),
        input=slice_bytes,
    )
    assert input_slice == output

    # Make sure decoding with the wrong hash fails.
    wrong_hash = "0" * len(expected_bao_hash)
    bao_cli(
        "decode-slice",
        wrong_hash,
        str(slice_start),
        str(slice_len),
        input=slice_bytes,
        should_fail=True,
    )
|