1
0
mirror of https://github.com/oconnor663/bao synced 2025-02-21 23:21:05 +01:00
bao/tests/test_bao.py
Jack O'Connor 431882d657 reformat Python files with Black
tests/docopt.py is vendored from upstream, so leave that as-is.
2022-07-11 20:31:25 -04:00

397 lines
13 KiB
Python

#! /usr/bin/env python3
# Run this file using pytest, either in this folder or at the root of the
# project. Since test_vectors.json is generated from bao.py, it's slightly
# cheating to then test bao.py against its own output. But at least this helps
# us notice changes, since the vectors are checked in rather than generated
# every time. Testing the Rust implementation against the same test vectors
# gives us some confidence that they're correct.
from binascii import hexlify, unhexlify
import io
import json
from pathlib import Path
import subprocess
import sys
import tempfile

# Imports from this directory.
import bao
import generate_input
# All paths are resolved relative to this file, so the tests behave the same
# whether pytest is started from this folder or from the project root.
HERE = Path(__file__).parent
BAO_PATH = HERE / "bao.py"
VECTORS_PATH = HERE / "test_vectors.json"
# Load the checked-in test vectors once at import time. read_text() opens and
# closes the file itself, so no file handle is left dangling for the garbage
# collector to clean up.
VECTORS = json.loads(VECTORS_PATH.read_text())
# Wrapper functions
# =================
#
# Most of the functions in bao.py (except bao_encode) work with streams. These
# wrappers work with bytes, and return hashes as strings, which makes them
# easier to test.
def bao_hash(content):
    """Hash the bytes `content` with bao.bao_hash and return a hex string."""
    # bao.bao_hash reads from a stream, so wrap the bytes first.
    stream = io.BytesIO(content)
    digest = bao.bao_hash(stream)
    return hexlify(digest).decode("utf-8")
def blake3(b):
    """Return the hex hash of the bytes `b`; a plain alias for bao_hash."""
    hex_digest = bao_hash(b)
    return hex_digest
def bao_encode(content):
    """Encode `content` with outboard=False.

    Returns a tuple of (encoded output, hash as a hex string).
    """
    # Unlike the other wrappers, bao.bao_encode already takes bytes, so no
    # stream wrapping is needed here.
    result = bao.bao_encode(content, outboard=False)
    encoded, digest = result
    return encoded, digest.hex()
def bao_encode_outboard(content):
    """Encode `content` with outboard=True.

    Returns a tuple of (outboard output, hash as a hex string).
    """
    # Unlike the other wrappers, bao.bao_encode already takes bytes, so no
    # stream wrapping is needed here.
    result = bao.bao_encode(content, outboard=True)
    outboard_bytes, digest = result
    return outboard_bytes, digest.hex()
def bao_decode(hash, encoded):
    """Decode the bytes `encoded` against the hex string `hash`.

    Returns whatever bao.bao_decode wrote to its output stream, as bytes.
    """
    # Convert the hex hash first so a malformed hash fails before decoding.
    raw_hash = unhexlify(hash)
    sink = io.BytesIO()
    source = io.BytesIO(encoded)
    bao.bao_decode(source, sink, raw_hash)
    return sink.getvalue()
def bao_decode_outboard(hash, content, outboard):
    """Decode `content` using a separate `outboard` blob and a hex `hash`.

    Returns whatever bao.bao_decode wrote to its output stream, as bytes.
    """
    raw_hash = unhexlify(hash)
    sink = io.BytesIO()
    content_stream = io.BytesIO(content)
    outboard_stream = io.BytesIO(outboard)
    bao.bao_decode(
        content_stream, sink, raw_hash, outboard_stream=outboard_stream
    )
    return sink.getvalue()
def bao_slice(encoded, slice_start, slice_len):
    """Run bao.bao_slice over the bytes `encoded` and return the slice bytes."""
    sink = io.BytesIO()
    source = io.BytesIO(encoded)
    bao.bao_slice(source, sink, slice_start, slice_len)
    return sink.getvalue()
def bao_slice_outboard(content, outboard, slice_start, slice_len):
    """Run bao.bao_slice over `content` plus a separate `outboard` blob.

    Returns the slice written to the output stream, as bytes.
    """
    sink = io.BytesIO()
    content_stream = io.BytesIO(content)
    tree_stream = io.BytesIO(outboard)
    bao.bao_slice(
        content_stream,
        sink,
        slice_start,
        slice_len,
        outboard_stream=tree_stream,
    )
    return sink.getvalue()
def bao_decode_slice(slice_bytes, hash, slice_start, slice_len):
    """Decode `slice_bytes` against the hex string `hash`.

    `slice_start` and `slice_len` are forwarded unchanged to
    bao.bao_decode_slice. Returns the decoded output as bytes.
    """
    raw_hash = unhexlify(hash)
    sink = io.BytesIO()
    source = io.BytesIO(slice_bytes)
    bao.bao_decode_slice(source, sink, raw_hash, slice_start, slice_len)
    return sink.getvalue()
# Tests
# =====
def test_hashes():
    """Check bao_hash against every entry in the "hash" test vectors."""
    for vector in VECTORS["hash"]:
        # Inputs are regenerated from their length rather than stored.
        data = generate_input.input_bytes(vector["input_len"])
        assert vector["bao_hash"] == bao_hash(data)
def bao_cli(*args, input=None, should_fail=False):
    """Run bao.py as a subprocess and return its captured stdout as bytes.

    Args:
        *args: command-line arguments passed to bao.py (strings).
        input: optional bytes fed to the child's stdin.
        should_fail: if true, assert that the command exits non-zero and
            silence its stderr; otherwise assert that it exits with zero.
    """
    # Use the interpreter that is running the tests rather than whatever
    # "python3" happens to resolve to on PATH (which may be a different
    # version, or missing entirely on some platforms).
    output = subprocess.run(
        [sys.executable, str(BAO_PATH), *args],
        stdout=subprocess.PIPE,
        # Expected failures print errors; keep them out of the test output.
        stderr=subprocess.DEVNULL if should_fail else None,
        input=input,
    )
    cmd = " ".join(["bao.py"] + list(args))
    if should_fail:
        assert output.returncode != 0, "`{}` should've failed".format(cmd)
    else:
        assert output.returncode == 0, "`{}` failed".format(cmd)
    return output.stdout
def test_hash_cli():
    """Check that `bao.py hash` reproduces the largest "hash" test vector."""
    # The CLI tests only use the final (largest) vector of each set, so we
    # don't shell out hundreds of times. The implementation itself is covered
    # exhaustively by the non-CLI tests; here we only confirm the wiring.
    vector = VECTORS["hash"][-1]
    data = generate_input.input_bytes(vector["input_len"])
    expected = vector["bao_hash"]
    cli_stdout = bao_cli("hash", input=data)
    assert expected == cli_stdout.decode().strip()
def assert_decode_failure(f, *args):
    """Call `f(*args)` and require it to raise AssertionError or IOError.

    Raises AssertionError if the call returns normally. Any other exception
    type raised by `f` propagates unchanged.
    """
    try:
        f(*args)
    except (AssertionError, IOError):
        # This is the expected outcome.
        return
    raise AssertionError("failure expected, but no exception raised")
def test_encoded():
    """Check encoding and decoding against every "encode" test vector."""
    for vector in VECTORS["encode"]:
        input_bytes = generate_input.input_bytes(vector["input_len"])
        expected_len = vector["output_len"]
        expected_hash = vector["bao_hash"]
        expected_blake3 = vector["encoded_blake3"]
        corruption_offsets = vector["corruptions"]

        # The encoding must match the recorded hash, length and checksum.
        encoded, hash_ = bao_encode(input_bytes)
        assert expected_hash == hash_
        assert expected_len == len(encoded)
        assert expected_blake3 == blake3(encoded)

        # Decoding must round-trip back to the input.
        assert input_bytes == bao_decode(hash_, encoded)

        # An all-zero hash of the same length must be rejected.
        assert_decode_failure(bao_decode, "0" * len(hash_), encoded)

        # Flipping the low bit at each listed offset must break decoding.
        for offset in corruption_offsets:
            damaged = bytearray(encoded)
            damaged[offset] ^= 1
            assert_decode_failure(bao_decode, hash_, damaged)
def make_tempfile(b=b""):
    """Create a NamedTemporaryFile holding the bytes `b`, rewound to the start.

    The contents are flushed so they are on disk for anyone opening the file
    by its `.name`. The open file object is returned; the caller owns it.
    """
    handle = tempfile.NamedTemporaryFile()
    handle.write(b)
    handle.flush()
    # Rewind so a later read() on this handle starts at the beginning.
    handle.seek(0)
    return handle
def test_encoded_cli():
    """Check `bao.py encode` and `decode` against the largest "encode" vector."""
    case = VECTORS["encode"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    output_len = case["output_len"]
    expected_bao_hash = case["bao_hash"]
    encoded_blake3 = case["encoded_blake3"]

    # Manage the temp files with `with` so they are closed (and therefore
    # deleted) even when an assertion below fails.
    with make_tempfile(input_bytes) as input_file, make_tempfile() as encoded_file:
        # First make sure the encoded output is what it's supposed to be.
        # The CLI writes to the file by name; we read it back via our handle.
        bao_cli("encode", input_file.name, encoded_file.name)
        encoded = encoded_file.read()
        assert output_len == len(encoded)
        assert encoded_blake3 == blake3(encoded)

        # Now test decoding.
        output = bao_cli("decode", expected_bao_hash, encoded_file.name)
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails.
        wrong_hash = "0" * len(expected_bao_hash)
        bao_cli("decode", wrong_hash, encoded_file.name, should_fail=True)
def test_outboard():
    """Check outboard encoding and decoding against every "outboard" vector."""
    for vector in VECTORS["outboard"]:
        input_bytes = generate_input.input_bytes(vector["input_len"])
        expected_len = vector["output_len"]
        expected_hash = vector["bao_hash"]
        expected_blake3 = vector["encoded_blake3"]
        outboard_offsets = vector["outboard_corruptions"]
        input_offsets = vector["input_corruptions"]

        # The outboard data must match the recorded hash, length and checksum.
        outboard, hash_ = bao_encode_outboard(input_bytes)
        assert expected_hash == hash_
        assert expected_len == len(outboard)
        assert expected_blake3 == blake3(outboard)

        # Decoding must round-trip back to the input.
        assert input_bytes == bao_decode_outboard(hash_, input_bytes, outboard)

        # An all-zero hash of the same length must be rejected.
        assert_decode_failure(
            bao_decode_outboard, "0" * len(hash_), input_bytes, outboard
        )

        # Flipping the low bit at each listed outboard offset must break
        # decoding.
        for offset in outboard_offsets:
            damaged_outboard = bytearray(outboard)
            damaged_outboard[offset] ^= 1
            assert_decode_failure(
                bao_decode_outboard, hash_, input_bytes, damaged_outboard
            )

        # The same goes for each listed offset in the input itself.
        for offset in input_offsets:
            damaged_input = bytearray(input_bytes)
            damaged_input[offset] ^= 1
            assert_decode_failure(
                bao_decode_outboard, hash_, damaged_input, outboard
            )
def test_outboard_cli():
    """Check `bao.py encode/decode --outboard` against the largest vector."""
    case = VECTORS["outboard"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    output_len = case["output_len"]
    expected_bao_hash = case["bao_hash"]
    encoded_blake3 = case["encoded_blake3"]

    # Manage the temp files with `with` so they are closed (and therefore
    # deleted) even when an assertion below fails.
    with make_tempfile(input_bytes) as input_file, make_tempfile() as outboard_file:
        # First make sure the encoded output is what it's supposed to be.
        # The CLI writes to the file by name; we read it back via our handle.
        bao_cli("encode", input_file.name, "--outboard", outboard_file.name)
        outboard = outboard_file.read()
        assert output_len == len(outboard)
        assert encoded_blake3 == blake3(outboard)

        # Now test decoding.
        output = bao_cli(
            "decode",
            expected_bao_hash,
            input_file.name,
            "--outboard",
            outboard_file.name,
        )
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails. The stdout of a
        # failing command is not interesting, so it is not kept.
        wrong_hash = "0" * len(expected_bao_hash)
        bao_cli(
            "decode",
            wrong_hash,
            input_file.name,
            "--outboard",
            outboard_file.name,
            should_fail=True,
        )
def test_slices():
    """Check slice extraction and slice decoding against every "slice" vector."""
    for vector in VECTORS["slice"]:
        input_bytes = generate_input.input_bytes(vector["input_len"])
        expected_hash = vector["bao_hash"]
        slice_cases = vector["slices"]

        # Both encodings must report the recorded hash.
        encoded, hash_ = bao_encode(input_bytes)
        outboard, hash_outboard = bao_encode_outboard(input_bytes)
        assert expected_hash == hash_
        assert expected_hash == hash_outboard

        for slice_case in slice_cases:
            start = slice_case["start"]
            length = slice_case["len"]
            expected_len = slice_case["output_len"]
            expected_blake3 = slice_case["output_blake3"]
            corruption_offsets = slice_case["corruptions"]

            # The extracted slice must match the recorded length and checksum.
            slice_bytes = bao_slice(encoded, start, length)
            assert expected_len == len(slice_bytes)
            assert expected_blake3 == blake3(slice_bytes)

            # Slicing from the outboard encoding must give the same bytes.
            from_outboard = bao_slice_outboard(input_bytes, outboard, start, length)
            assert slice_bytes == from_outboard

            # Decode the slice and compare with the matching input range.
            # Python slicing silently caps indices that run past the end of
            # the bytes, which is the behavior we want here.
            expected_output = input_bytes[start:][:length]
            decoded = bao_decode_slice(slice_bytes, hash_, start, length)
            assert expected_output == decoded

            # An all-zero hash of the same length must be rejected.
            assert_decode_failure(
                bao_decode_slice, slice_bytes, "0" * len(hash_), start, length
            )

            # Flipping the low bit at each listed offset must break decoding.
            for offset in corruption_offsets:
                damaged = bytearray(slice_bytes)
                damaged[offset] ^= 1
                assert_decode_failure(
                    bao_decode_slice, damaged, hash_, start, length
                )
def test_slices_cli():
    """Check `bao.py slice` and `decode-slice` against the largest vector."""
    case = VECTORS["slice"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    expected_bao_hash = case["bao_hash"]
    slices = case["slices"]

    # Manage the temp files with `with` so they are closed (and therefore
    # deleted) even when an assertion below fails.
    with make_tempfile(input_bytes) as input_file, make_tempfile() as encoded_file:
        with make_tempfile() as outboard_file:
            bao_cli("encode", input_file.name, encoded_file.name)
            bao_cli("encode", input_file.name, "--outboard", outboard_file.name)

            # Use the first slice in the list. Currently they're all the same
            # length.
            slice_case = slices[0]
            slice_start = slice_case["start"]
            slice_len = slice_case["len"]
            output_len = slice_case["output_len"]
            output_blake3 = slice_case["output_blake3"]

            # Make sure the slice output is what it should be.
            slice_bytes = bao_cli(
                "slice", str(slice_start), str(slice_len), encoded_file.name
            )
            assert output_len == len(slice_bytes)
            assert output_blake3 == blake3(slice_bytes)

            # Make sure slicing an outboard tree is the same.
            outboard_slice_bytes = bao_cli(
                "slice",
                str(slice_start),
                str(slice_len),
                input_file.name,
                "--outboard",
                outboard_file.name,
            )
            assert slice_bytes == outboard_slice_bytes

    # The remaining checks feed the slice through stdin, so the temp files
    # are no longer needed.

    # Test decoding the slice, and compare it to the input. Note that
    # slicing a byte array in Python allows indices past the end of the
    # array, and sort of silently caps them.
    input_slice = input_bytes[slice_start:][:slice_len]
    output = bao_cli(
        "decode-slice",
        expected_bao_hash,
        str(slice_start),
        str(slice_len),
        input=slice_bytes,
    )
    assert input_slice == output

    # Make sure decoding with the wrong hash fails.
    wrong_hash = "0" * len(expected_bao_hash)
    bao_cli(
        "decode-slice",
        wrong_hash,
        str(slice_start),
        str(slice_len),
        input=slice_bytes,
        should_fail=True,
    )