123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- from io import BytesIO
- from typing import List, Tuple
- import shamir
- import utils
- from error import *
# Number of plaintext bytes that are split into shares as one unit.
# - pad() rounds the plaintext up to a multiple of this.
# - shamir_encode() feeds the padded data to secret_to_parts() one
#   CHUNK_SIZE-byte slice at a time.
# - shamir_decode() requires every share payload to be a multiple of it.
# pad() stores the pad length in a single byte (bytes([rem] * rem)), so this
# must stay <= 255.
# NOTE(review): each share is serialized via utils.long_to_bytes(part,
# CHUNK_SIZE); presumably every value mod shamir.PRIME must fit in CHUNK_SIZE
# bytes — confirm against shamir.PRIME.
CHUNK_SIZE = 64
def pad(plain: bytes) -> bytes:
    """Append padding so the result length is a multiple of CHUNK_SIZE.

    Always adds between 1 and CHUNK_SIZE bytes, each equal to the number of
    bytes added. An already-aligned input therefore gains a full extra block,
    which lets ``unpad`` always read the pad length from the last byte.
    """
    fill = CHUNK_SIZE - (len(plain) % CHUNK_SIZE)
    return plain + bytes((fill,)) * fill
def unpad(padded: bytes) -> bytes:
    """Remove the padding added by ``pad``.

    The last byte gives the pad length ``rem``. The final ``rem`` bytes must
    all equal ``rem``, and ``rem`` must lie in ``1..CHUNK_SIZE``, the only
    values ``pad`` can produce.

    Args:
        padded: Data ending in padding produced by ``pad``.

    Returns:
        ``padded`` with the trailing ``rem`` padding bytes removed.

    Raises:
        AssertionError: if ``padded`` is empty or its padding is malformed.
            Raised explicitly rather than via ``assert`` so the validation is
            not stripped when Python runs with ``-O``.
    """
    if len(padded) == 0:
        raise AssertionError(f"Invalid padding. Length: {len(padded)}")
    rem = padded[-1]
    # pad() never produces rem == 0 or rem > CHUNK_SIZE. rem == 0 is especially
    # dangerous: padded[-0:] selects the whole buffer and padded[:-0] is b"".
    if not 1 <= rem <= CHUNK_SIZE:
        raise AssertionError(f"Invalid padding. Remainder: {rem}")
    if padded[-rem:] != bytes([rem] * rem):
        raise AssertionError(
            f"Invalid padding. Remainder: {rem} Pad: {list(padded[-rem:])}")
    return padded[:-rem]
def secret_to_parts(secret: bytes, k: int, n: int) -> List[bytes]:
    """Split one CHUNK_SIZE-byte block into ``n`` encoded shares.

    The block is converted to an integer, handed to ``shamir.secret_to_parts``
    with threshold ``k`` and share count ``n``, and each resulting share is
    serialized back to bytes with width CHUNK_SIZE.

    Raises:
        AssertionError: if ``secret`` is not exactly CHUNK_SIZE bytes or
            ``0 < k <= n`` does not hold.
    """
    width = CHUNK_SIZE
    assert len(secret) == width
    assert 0 < k <= n
    shares = shamir.secret_to_parts(utils.bytes_to_long(secret, width), k, n)
    return [utils.long_to_bytes(share, width) for share in shares]
def parts_to_secret(parts: List[Tuple[int, bytes]]) -> bytes:
    """Recombine ``(part_id, share_bytes)`` pairs into one CHUNK_SIZE block.

    Each share is converted to an integer and passed to
    ``shamir.parts_to_secret``. The recovered integer is serialized back to
    CHUNK_SIZE bytes.

    Raises:
        AssertionError: on duplicate part ids or a share that is not exactly
            CHUNK_SIZE bytes.
    """
    ids = [part_id for part_id, _ in parts]
    assert len(set(ids)) == len(parts)
    assert all(len(chunk) == CHUNK_SIZE for _, chunk in parts)
    numeric = [(part_id, utils.bytes_to_long(chunk, CHUNK_SIZE))
               for part_id, chunk in parts]
    return utils.long_to_bytes(shamir.parts_to_secret(numeric), CHUNK_SIZE)
def part_decode(part: bytes) -> Tuple[int, int, bytes]:
    """
    Split an encoded share into its header fields and payload.

    The layout is ``<part_id>-<k>-<payload>``. Only the first two ``-`` bytes
    are separators, so the payload may itself contain ``-``.

    @return: Part Id, Expected num parts (k), Sequence of bytes
    @raise AssertionError: if fewer than two ``-`` separators are present.
    """
    fields = part.split(b'-', 2)
    assert len(fields) == 3
    id_field, k_field, payload = fields
    return int(id_field.decode()), int(k_field.decode()), payload
def shamir_encode(plain: bytes, k: int, n: int) -> List[bytes]:
    """Encode ``plain`` into ``n`` shares with threshold ``k``.

    Each share is ``b"<id>-<k>-"`` (ids are 1..n) followed by the
    concatenation of that share's CHUNK_SIZE-byte piece for every
    CHUNK_SIZE-byte block of the padded plaintext.
    """
    padded = pad(plain)
    pieces = [[f'{idx}-{k}-'.encode()] for idx in range(1, n + 1)]
    for start in range(0, len(padded), CHUNK_SIZE):
        shares = secret_to_parts(padded[start:start + CHUNK_SIZE], k, n)
        for bucket, share in zip(pieces, shares):
            bucket.append(share)
    return [b''.join(bucket) for bucket in pieces]
def shamir_decode(parts: List[bytes]) -> bytes:
    """Reassemble the plaintext from shares produced by ``shamir_encode``.

    Args:
        parts: Encoded shares, each ``b"<part_id>-<k>-<payload>"``. The same
            share may appear more than once if the copies are byte-identical.

    Returns:
        The original plaintext with padding removed.

    Raises:
        InsufficientParts: if ``parts`` is empty, or it holds fewer distinct
            part ids than the ``k`` recorded in the headers.
        AssertionError: on inconsistent input, namely a part id outside
            ``1..shamir.PRIME-1``, ``k < 1``, mismatched ``k`` or payload
            length, conflicting duplicates, or malformed padding.
    """
    buffers = {}
    expected_parts = None
    expected_size = None
    for part_id, cur_expected_parts, buffer in map(part_decode, parts):
        assert 1 <= part_id < shamir.PRIME
        if expected_parts is None:
            # The first part fixes the threshold and payload length that
            # every other part must agree with. shamir_encode requires 0 < k.
            assert cur_expected_parts >= 1
            assert len(buffer) > 0
            assert len(buffer) % CHUNK_SIZE == 0
            expected_parts = cur_expected_parts
            expected_size = len(buffer)
        else:
            assert expected_parts == cur_expected_parts
            assert expected_size == len(buffer)
        # A repeated id is only acceptable if it carries the same payload.
        assert buffers.setdefault(part_id, buffer) == buffer

    # With no parts, expected_parts is still None. Without this guard,
    # `len(...) < None` would raise TypeError instead of InsufficientParts.
    if expected_parts is None or len(buffers) < expected_parts:
        raise InsufficientParts()

    # Rebuild the padded plaintext one CHUNK_SIZE block at a time, using the
    # matching slice of every distinct share.
    secret = b"".join(
        parts_to_secret([(part_id, payload[offset:offset + CHUNK_SIZE])
                         for part_id, payload in buffers.items()])
        for offset in range(0, expected_size, CHUNK_SIZE))
    return unpad(secret)
|