#! /usr/bin/env python3

# This is a Python port of the Rust reference implementation of BLAKE3:
# https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs

from __future__ import annotations

from dataclasses import dataclass

OUT_LEN = 32
KEY_LEN = 32
BLOCK_LEN = 64
CHUNK_LEN = 1024

CHUNK_START = 1 << 0
CHUNK_END = 1 << 1
PARENT = 1 << 2
ROOT = 1 << 3
KEYED_HASH = 1 << 4
DERIVE_KEY_CONTEXT = 1 << 5
DERIVE_KEY_MATERIAL = 1 << 6

IV = [
    0x6A09E667,
    0xBB67AE85,
    0x3C6EF372,
    0xA54FF53A,
    0x510E527F,
    0x9B05688C,
    0x1F83D9AB,
    0x5BE0CD19,
]

MSG_PERMUTATION = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]


def mask32(x: int) -> int:
    return x & 0xFFFFFFFF


def add32(x: int, y: int) -> int:
    return mask32(x + y)


def rightrotate32(x: int, n: int) -> int:
    return mask32(x << (32 - n)) | (x >> n)


# The mixing function, G, which mixes either a column or a diagonal.
def g(state: list[int], a: int, b: int, c: int, d: int, mx: int, my: int) -> None:
    state[a] = add32(state[a], add32(state[b], mx))
    state[d] = rightrotate32(state[d] ^ state[a], 16)
    state[c] = add32(state[c], state[d])
    state[b] = rightrotate32(state[b] ^ state[c], 12)
    state[a] = add32(state[a], add32(state[b], my))
    state[d] = rightrotate32(state[d] ^ state[a], 8)
    state[c] = add32(state[c], state[d])
    state[b] = rightrotate32(state[b] ^ state[c], 7)


def round(state: list[int], m: list[int]) -> None:
    # Mix the columns.
    g(state, 0, 4, 8, 12, m[0], m[1])
    g(state, 1, 5, 9, 13, m[2], m[3])
    g(state, 2, 6, 10, 14, m[4], m[5])
    g(state, 3, 7, 11, 15, m[6], m[7])
    # Mix the diagonals.
    g(state, 0, 5, 10, 15, m[8], m[9])
    g(state, 1, 6, 11, 12, m[10], m[11])
    g(state, 2, 7, 8, 13, m[12], m[13])
    g(state, 3, 4, 9, 14, m[14], m[15])


def permute(m: list[int]) -> None:
    original = list(m)
    for i in range(16):
        m[i] = original[MSG_PERMUTATION[i]]


def compress(
    chaining_value: list[int],
    block_words: list[int],
    counter: int,
    block_len: int,
    flags: int,
) -> list[int]:
    state = [
        chaining_value[0],
        chaining_value[1],
        chaining_value[2],
        chaining_value[3],
        chaining_value[4],
        chaining_value[5],
        chaining_value[6],
        chaining_value[7],
        IV[0],
        IV[1],
        IV[2],
        IV[3],
        mask32(counter),
        mask32(counter >> 32),
        block_len,
        flags,
    ]

    assert len(block_words) == 16
    block = list(block_words)

    round(state, block)  # round 1
    permute(block)
    round(state, block)  # round 2
    permute(block)
    round(state, block)  # round 3
    permute(block)
    round(state, block)  # round 4
    permute(block)
    round(state, block)  # round 5
    permute(block)
    round(state, block)  # round 6
    permute(block)
    round(state, block)  # round 7

    for i in range(8):
        state[i] ^= state[i + 8]
        state[i + 8] ^= chaining_value[i]

    return state


def words_from_little_endian_bytes(b: bytes) -> list[int]:
    assert len(b) % 4 == 0
    return [int.from_bytes(b[i : i + 4], "little") for i in range(0, len(b), 4)]
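

# Illustration (not part of the reference implementation): each 4-byte group
# is decoded as an unsigned little-endian 32-bit integer, so for example
#
#     words_from_little_endian_bytes(b"\x01\x00\x00\x00\xff\x00\x00\x00")
#
# returns [1, 255]. BLAKE3 is defined entirely over little-endian 32-bit
# words, which is why compress() above takes and returns lists of ints.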


# Each chunk or parent node can produce either an 8-word chaining value or, by
# setting the ROOT flag, any number of final output bytes. The Output struct
# captures the state just prior to choosing between those two possibilities.
@dataclass
class Output:
    input_chaining_value: list[int]
    block_words: list[int]
    counter: int
    block_len: int
    flags: int

    def chaining_value(self) -> list[int]:
        return compress(
            self.input_chaining_value,
            self.block_words,
            self.counter,
            self.block_len,
            self.flags,
        )[:8]

    def root_output_bytes(self, length: int) -> bytes:
        output_bytes = bytearray()
        i = 0
        while i < length:
            words = compress(
                self.input_chaining_value,
                self.block_words,
                i // BLOCK_LEN,
                self.block_len,
                self.flags | ROOT,
            )
            # The output length might not be a multiple of 4.
            for word in words:
                word_bytes = word.to_bytes(4, "little")
                take = min(len(word_bytes), length - i)
                output_bytes.extend(word_bytes[:take])
                i += take
        return bytes(output_bytes)


@dataclass
class ChunkState:
    chaining_value: list[int]
    chunk_counter: int
    block: bytearray
    block_len: int
    blocks_compressed: int
    flags: int

    def __init__(self, key_words: list[int], chunk_counter: int, flags: int) -> None:
        self.chaining_value = key_words
        self.chunk_counter = chunk_counter
        self.block = bytearray(BLOCK_LEN)
        self.block_len = 0
        self.blocks_compressed = 0
        self.flags = flags

    def len(self) -> int:
        return BLOCK_LEN * self.blocks_compressed + self.block_len

    def start_flag(self) -> int:
        if self.blocks_compressed == 0:
            return CHUNK_START
        else:
            return 0

    def update(self, input_bytes: bytes) -> None:
        while input_bytes:
            # If the block buffer is full, compress it and clear it. More
            # input_bytes is coming, so this compression is not CHUNK_END.
            if self.block_len == BLOCK_LEN:
                block_words = words_from_little_endian_bytes(self.block)
                self.chaining_value = compress(
                    self.chaining_value,
                    block_words,
                    self.chunk_counter,
                    BLOCK_LEN,
                    self.flags | self.start_flag(),
                )[:8]
                self.blocks_compressed += 1
                self.block = bytearray(BLOCK_LEN)
                self.block_len = 0

            # Copy input bytes into the block buffer.
            want = BLOCK_LEN - self.block_len
            take = min(want, len(input_bytes))
            self.block[self.block_len : self.block_len + take] = input_bytes[:take]
            self.block_len += take
            input_bytes = input_bytes[take:]

    def output(self) -> Output:
        block_words = words_from_little_endian_bytes(self.block)
        return Output(
            self.chaining_value,
            block_words,
            self.chunk_counter,
            self.block_len,
            self.flags | self.start_flag() | CHUNK_END,
        )


def parent_output(
    left_child_cv: list[int],
    right_child_cv: list[int],
    key_words: list[int],
    flags: int,
) -> Output:
    return Output(
        key_words, left_child_cv + right_child_cv, 0, BLOCK_LEN, PARENT | flags
    )


def parent_cv(
    left_child_cv: list[int],
    right_child_cv: list[int],
    key_words: list[int],
    flags: int,
) -> list[int]:
    return parent_output(
        left_child_cv, right_child_cv, key_words, flags
    ).chaining_value()


# An incremental hasher that can accept any number of writes.
@dataclass
class Hasher:
    chunk_state: ChunkState
    key_words: list[int]
    cv_stack: list[list[int]]
    flags: int

    def _init(self, key_words: list[int], flags: int) -> None:
        assert len(key_words) == 8
        self.chunk_state = ChunkState(key_words, 0, flags)
        self.key_words = key_words
        self.cv_stack = []
        self.flags = flags

    # Construct a new `Hasher` for the regular hash function.
    def __init__(self) -> None:
        self._init(IV, 0)

    # Construct a new `Hasher` for the keyed hash function.
    @classmethod
    def new_keyed(cls, key: bytes) -> Hasher:
        keyed_hasher = cls()
        key_words = words_from_little_endian_bytes(key)
        keyed_hasher._init(key_words, KEYED_HASH)
        return keyed_hasher
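
    # Example usage (the all-zero key below is purely illustrative; any
    # 32-byte key works):
    #
    #     h = Hasher.new_keyed(bytes(KEY_LEN))
    #     h.update(b"some message")
    #     tag = h.finalize()  # 32-byte keyed hash, usable as a MAC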

    # Construct a new `Hasher` for the key derivation function. The context
    # string should be hardcoded, globally unique, and application-specific.
    @classmethod
    def new_derive_key(cls, context: str) -> Hasher:
        context_hasher = cls()
        context_hasher._init(IV, DERIVE_KEY_CONTEXT)
        context_hasher.update(context.encode("utf8"))
        context_key = context_hasher.finalize(KEY_LEN)
        context_key_words = words_from_little_endian_bytes(context_key)
        derive_key_hasher = cls()
        derive_key_hasher._init(context_key_words, DERIVE_KEY_MATERIAL)
        return derive_key_hasher

    # Section 5.1.2 of the BLAKE3 spec explains this algorithm in more detail.
    def add_chunk_chaining_value(self, new_cv: list[int], total_chunks: int) -> None:
        # This chunk might complete some subtrees. For each completed subtree,
        # its left child will be the current top entry in the CV stack, and
        # its right child will be the current value of `new_cv`. Pop each left
        # child off the stack, merge it with `new_cv`, and overwrite `new_cv`
        # with the result. After all these merges, push the final value of
        # `new_cv` onto the stack. The number of completed subtrees is given
        # by the number of trailing 0-bits in the new total number of chunks.
        while total_chunks & 1 == 0:
            new_cv = parent_cv(self.cv_stack.pop(), new_cv, self.key_words, self.flags)
            total_chunks >>= 1
        self.cv_stack.append(new_cv)

    # Add input to the hash state. This can be called any number of times.
    def update(self, input_bytes: bytes) -> None:
        while input_bytes:
            # If the current chunk is complete, finalize it and reset the
            # chunk state. More input is coming, so this chunk is not ROOT.
            if self.chunk_state.len() == CHUNK_LEN:
                chunk_cv = self.chunk_state.output().chaining_value()
                total_chunks = self.chunk_state.chunk_counter + 1
                self.add_chunk_chaining_value(chunk_cv, total_chunks)
                self.chunk_state = ChunkState(self.key_words, total_chunks, self.flags)

            # Compress input bytes into the current chunk state.
            want = CHUNK_LEN - self.chunk_state.len()
            take = min(want, len(input_bytes))
            self.chunk_state.update(input_bytes[:take])
            input_bytes = input_bytes[take:]

    # Finalize the hash and write any number of output bytes.
    def finalize(self, length: int = OUT_LEN) -> bytes:
        # Starting with the Output from the current chunk, compute all the
        # parent chaining values along the right edge of the tree, until we
        # have the root Output.
        output = self.chunk_state.output()
        parent_nodes_remaining = len(self.cv_stack)
        while parent_nodes_remaining > 0:
            parent_nodes_remaining -= 1
            output = parent_output(
                self.cv_stack[parent_nodes_remaining],
                output.chaining_value(),
                self.key_words,
                self.flags,
            )
        return output.root_output_bytes(length)


# If this file is executed directly, hash standard input.
if __name__ == "__main__":
    import sys

    hasher = Hasher()
    while buf := sys.stdin.buffer.read(65536):
        hasher.update(buf)
    print(hasher.finalize().hex())
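

# A minimal smoke test (not part of the reference implementation). The
# expected digest below is the widely published BLAKE3 hash of the empty
# input; verify it against the official test vectors in the BLAKE3 repo
# before relying on it:
#
#     assert (
#         Hasher().finalize().hex()
#         == "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262"
#     )
#
# And a hypothetical key-derivation sketch (`input_key_material` is a
# placeholder; real applications should hardcode their own globally unique
# context string):
#
#     kdf = Hasher.new_derive_key("example.com 2024-01-01 session keys v1")
#     kdf.update(input_key_material)
#     session_key = kdf.finalize(KEY_LEN)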