from io import BytesIO
from typing import List, Tuple

import shamir
import utils
from error import *

# Number of bytes in each chunk that is split into shares independently.
CHUNK_SIZE = 64


def _require(condition: bool, message: str = None) -> None:
    """Raise AssertionError (optionally with *message*) when *condition* is false.

    Used instead of ``assert`` so that validation of share data still runs
    when Python is started with ``-O`` (which strips ``assert`` statements),
    while callers keep seeing the same exception type as before.
    """
    if not condition:
        if message is None:
            raise AssertionError
        raise AssertionError(message)


def pad(plain: bytes) -> bytes:
    """Pad *plain* to a multiple of CHUNK_SIZE bytes.

    Always appends between 1 and CHUNK_SIZE bytes, each equal to the number
    of bytes appended, so ``unpad`` can remove the padding unambiguously.
    """
    rem = CHUNK_SIZE - len(plain) % CHUNK_SIZE
    return plain + bytes([rem]) * rem


def unpad(padded: bytes) -> bytes:
    """Remove the padding added by ``pad``.

    Raises:
        AssertionError: if *padded* is empty or does not end in ``n`` bytes of
            value ``n`` with ``1 <= n <= CHUNK_SIZE``.
    """
    _require(len(padded) > 0, f"Invalid padding. Length: {len(padded)}")
    rem = padded[-1]
    # ``pad`` never emits a pad length outside 1..CHUNK_SIZE, so anything else is corrupt.
    if not (1 <= rem <= CHUNK_SIZE and padded[-rem:] == bytes([rem]) * rem):
        raise AssertionError(
            f"Invalid padding. Remainder: {rem} Pad: {list(padded[-rem:])}")
    return padded[:-rem]


def secret_to_parts(secret: bytes, k: int, n: int) -> List[bytes]:
    """Split one CHUNK_SIZE-byte chunk into shares with ``shamir.secret_to_parts``.

    @param secret: exactly CHUNK_SIZE bytes.
    @param k: threshold passed to ``shamir.secret_to_parts`` (0 < k <= n).
    @param n: number of shares requested.
    @return: one CHUNK_SIZE-byte encoding per share returned by the shamir module.
    """
    _require(len(secret) == CHUNK_SIZE)
    _require(0 < k <= n)
    secret_num = utils.bytes_to_long(secret, CHUNK_SIZE)
    shares = shamir.secret_to_parts(secret_num, k, n)
    return [utils.long_to_bytes(share, CHUNK_SIZE) for share in shares]


def parts_to_secret(parts: List[Tuple[int, bytes]]) -> bytes:
    """Recombine ``(part_id, share)`` pairs into one CHUNK_SIZE-byte chunk.

    Part ids must be distinct and every share must be CHUNK_SIZE bytes long.
    """
    _require(len({part_id for part_id, _ in parts}) == len(parts))
    _require(all(len(share) == CHUNK_SIZE for _, share in parts))
    parts_num = [(part_id, utils.bytes_to_long(share, CHUNK_SIZE))
                 for part_id, share in parts]
    secret_num = shamir.parts_to_secret(parts_num)
    return utils.long_to_bytes(secret_num, CHUNK_SIZE)


def part_decode(part: bytes) -> Tuple[int, int, bytes]:
    """Parse a serialized part of the form ``b"<part_id>-<k>-<payload>"``.

    Only the first two ``-`` separators are significant; the payload may
    itself contain ``-`` bytes.

    @return: Part Id, Expected num parts (k), Sequence of bytes
    """
    fields = part.split(b'-', 2)
    _require(len(fields) == 3)
    part_id_field, k_field, buffer = fields
    return int(part_id_field.decode()), int(k_field.decode()), buffer


def shamir_encode(plain: bytes, k: int, n: int) -> List[bytes]:
    """Split *plain* into *n* serialized parts with threshold *k*.

    *plain* is padded to a multiple of CHUNK_SIZE and each chunk is split
    independently. Part ``i`` (1-based) is ``b"<i>-<k>-"`` followed by the
    concatenation of its share of every chunk.
    """
    plain = pad(plain)
    parts = [BytesIO() for _ in range(n)]
    # NOTE(review): ids 1..n assume shamir.secret_to_parts returns shares in
    # that x-coordinate order — confirm against the shamir module.
    for part_id, part in enumerate(parts, start=1):
        part.write(f'{part_id}-{k}-'.encode())
    for offset in range(0, len(plain), CHUNK_SIZE):
        chunk_shares = secret_to_parts(plain[offset:offset + CHUNK_SIZE], k, n)
        for part, share in zip(parts, chunk_shares):
            part.write(share)
    return [part.getvalue() for part in parts]


def shamir_decode(parts: List[bytes]) -> bytes:
    """Reconstruct the plaintext from serialized parts made by ``shamir_encode``.

    Repeated copies of the same part id are accepted only if identical.

    Raises:
        InsufficientParts: if fewer than k distinct parts are supplied,
            including when *parts* is empty.
        AssertionError: if a part is malformed or inconsistent with the others.
    """
    parts_dic = {}
    expected_parts = None
    expected_size = None
    for part_id, cur_expected_parts, buffer in map(part_decode, parts):
        _require(1 <= part_id < shamir.PRIME)
        if expected_parts is None:
            # The first part fixes the threshold and share length for all others.
            _require(cur_expected_parts >= 1)
            _require(len(buffer) > 0)
            _require(len(buffer) % CHUNK_SIZE == 0)
            expected_parts = cur_expected_parts
            expected_size = len(buffer)
        else:
            _require(expected_parts == cur_expected_parts)
            _require(expected_size == len(buffer))
        if part_id in parts_dic:
            _require(parts_dic[part_id] == buffer)
        else:
            parts_dic[part_id] = buffer

    # expected_parts is still None when no parts were given at all.
    if expected_parts is None or len(parts_dic) < expected_parts:
        raise InsufficientParts()

    secret = BytesIO()
    for offset in range(0, expected_size, CHUNK_SIZE):
        chunk_parts = [(part_id, share[offset:offset + CHUNK_SIZE])
                       for part_id, share in parts_dic.items()]
        secret.write(parts_to_secret(chunk_parts))
    return unpad(secret.getvalue())