|
@@ -0,0 +1,120 @@
|
|
|
+from io import BytesIO
|
|
|
+from typing import List, Tuple
|
|
|
+
|
|
|
+import shamir
|
|
|
+import utils
|
|
|
+from error import *
|
|
|
+
|
|
|
+CHUNK_SIZE = 64
|
|
|
+
|
|
|
+
|
|
|
+def pad(plain: bytes) -> bytes:
|
|
|
+ size = len(plain)
|
|
|
+ rem = CHUNK_SIZE - size % CHUNK_SIZE
|
|
|
+ return plain + bytes([rem] * rem)
|
|
|
+
|
|
|
+
|
|
|
+def unpad(padded: bytes) -> bytes:
|
|
|
+ assert len(padded) > 0, f"Invalid padding. Length: {len(padded)}"
|
|
|
+ rem = padded[-1]
|
|
|
+ assert padded[-rem:] == bytes([rem] * rem), \
|
|
|
+ f"Invalid padding. Remainder: {rem} Pad: {list(padded[-rem:])}"
|
|
|
+ return padded[:-rem]
|
|
|
+
|
|
|
+
|
|
|
+def secret_to_parts(secret: bytes, k: int, n: int) -> List[bytes]:
|
|
|
+ assert len(secret) == CHUNK_SIZE
|
|
|
+ assert 0 < k <= n
|
|
|
+ secret_num = utils.bytes_to_long(secret, CHUNK_SIZE)
|
|
|
+ parts = shamir.secret_to_parts(secret_num, k, n)
|
|
|
+ return [utils.long_to_bytes(part, CHUNK_SIZE) for part in parts]
|
|
|
+
|
|
|
+
|
|
|
+def parts_to_secret(parts: List[Tuple[int, bytes]]) -> bytes:
|
|
|
+ assert len(set(part_id for part_id, _ in parts)) == len(parts)
|
|
|
+ assert all(len(buffer) == CHUNK_SIZE for _, buffer in parts)
|
|
|
+ parts_num = [(part_id, utils.bytes_to_long(part, CHUNK_SIZE))
|
|
|
+ for part_id, part in parts]
|
|
|
+ secret_num = shamir.parts_to_secret(parts_num)
|
|
|
+ return utils.long_to_bytes(secret_num, CHUNK_SIZE)
|
|
|
+
|
|
|
+
|
|
|
+def part_decode(part: bytes) -> Tuple[int, int, bytes]:
|
|
|
+ """
|
|
|
+ @return: Part Id, Expected num parts (k), Sequence of bytes
|
|
|
+ """
|
|
|
+ divisions = []
|
|
|
+ pointer = 0
|
|
|
+
|
|
|
+ while pointer < len(part) and len(divisions) < 2:
|
|
|
+ if part[pointer] == b'-'[0]:
|
|
|
+ divisions.append(pointer)
|
|
|
+ pointer += 1
|
|
|
+
|
|
|
+ assert len(divisions) == 2
|
|
|
+
|
|
|
+ a, b = divisions
|
|
|
+
|
|
|
+ part_id = int(part[:a].decode())
|
|
|
+ k = int(part[a + 1:b].decode())
|
|
|
+ buffer = part[b + 1:]
|
|
|
+
|
|
|
+ return part_id, k, buffer
|
|
|
+
|
|
|
+
|
|
|
+def shamir_encode(plain: bytes, k: int, n: int) -> List[bytes]:
|
|
|
+ plain = pad(plain)
|
|
|
+
|
|
|
+ parts = [BytesIO() for _ in range(n)]
|
|
|
+
|
|
|
+ for i, part in enumerate(parts):
|
|
|
+ part.write(f'{i+1}-{k}-'.encode())
|
|
|
+
|
|
|
+ for offset in range(0, len(plain), CHUNK_SIZE):
|
|
|
+ partial_parts = secret_to_parts(plain[offset:offset + CHUNK_SIZE],
|
|
|
+ k, n)
|
|
|
+
|
|
|
+ for part, partial_part in zip(parts, partial_parts):
|
|
|
+ part.write(partial_part)
|
|
|
+
|
|
|
+ return [part.getvalue() for part in parts]
|
|
|
+
|
|
|
+
|
|
|
+def shamir_decode(parts: List[bytes]) -> bytes:
|
|
|
+ parts_dic = {}
|
|
|
+ expected_parts = None
|
|
|
+ expected_size = None
|
|
|
+ first_iteration = True
|
|
|
+
|
|
|
+ for part_id, cur_expected_parts, buffer in map(part_decode, parts):
|
|
|
+ assert 1 <= part_id < shamir.PRIME
|
|
|
+
|
|
|
+ if first_iteration:
|
|
|
+ assert len(buffer) > 0
|
|
|
+ assert len(buffer) % CHUNK_SIZE == 0
|
|
|
+ expected_parts = cur_expected_parts
|
|
|
+ expected_size = len(buffer)
|
|
|
+ first_iteration = False
|
|
|
+ parts_dic[part_id] = buffer
|
|
|
+ else:
|
|
|
+ assert expected_parts == cur_expected_parts
|
|
|
+ assert expected_size == len(buffer)
|
|
|
+
|
|
|
+ if part_id in parts_dic:
|
|
|
+ assert parts_dic[part_id] == buffer
|
|
|
+ else:
|
|
|
+ parts_dic[part_id] = buffer
|
|
|
+
|
|
|
+ if len(parts_dic) < expected_parts:
|
|
|
+ raise InsufficientParts()
|
|
|
+
|
|
|
+ buffer = BytesIO()
|
|
|
+
|
|
|
+ for i in range(0, expected_size, CHUNK_SIZE):
|
|
|
+ parts_chunk = [(part_id, part[i:i+CHUNK_SIZE])
|
|
|
+ for part_id, part in parts_dic.items()]
|
|
|
+ secret_chunk = parts_to_secret(parts_chunk)
|
|
|
+ buffer.write(secret_chunk)
|
|
|
+
|
|
|
+ secret = buffer.getvalue()
|
|
|
+ return unpad(secret)
|