core.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from io import BytesIO
  2. from typing import List, Tuple
  3. import shamir
  4. import utils
  5. from error import *
# Size in bytes of one secret chunk: plaintext is padded to a multiple of this
# value and each chunk of this length is split into shares independently.
CHUNK_SIZE = 64
  7. def pad(plain: bytes) -> bytes:
  8. size = len(plain)
  9. rem = CHUNK_SIZE - size % CHUNK_SIZE
  10. return plain + bytes([rem] * rem)
  11. def unpad(padded: bytes) -> bytes:
  12. assert len(padded) > 0, f"Invalid padding. Length: {len(padded)}"
  13. rem = padded[-1]
  14. assert padded[-rem:] == bytes([rem] * rem), \
  15. f"Invalid padding. Remainder: {rem} Pad: {list(padded[-rem:])}"
  16. return padded[:-rem]
  17. def secret_to_parts(secret: bytes, k: int, n: int) -> List[bytes]:
  18. assert len(secret) == CHUNK_SIZE
  19. assert 0 < k <= n
  20. secret_num = utils.bytes_to_long(secret, CHUNK_SIZE)
  21. parts = shamir.secret_to_parts(secret_num, k, n)
  22. return [utils.long_to_bytes(part, CHUNK_SIZE) for part in parts]
  23. def parts_to_secret(parts: List[Tuple[int, bytes]]) -> bytes:
  24. assert len(set(part_id for part_id, _ in parts)) == len(parts)
  25. assert all(len(buffer) == CHUNK_SIZE for _, buffer in parts)
  26. parts_num = [(part_id, utils.bytes_to_long(part, CHUNK_SIZE))
  27. for part_id, part in parts]
  28. secret_num = shamir.parts_to_secret(parts_num)
  29. return utils.long_to_bytes(secret_num, CHUNK_SIZE)
  30. def part_decode(part: bytes) -> Tuple[int, int, bytes]:
  31. """
  32. @return: Part Id, Expected num parts (k), Sequence of bytes
  33. """
  34. divisions = []
  35. pointer = 0
  36. while pointer < len(part) and len(divisions) < 2:
  37. if part[pointer] == b'-'[0]:
  38. divisions.append(pointer)
  39. pointer += 1
  40. assert len(divisions) == 2
  41. a, b = divisions
  42. part_id = int(part[:a].decode())
  43. k = int(part[a + 1:b].decode())
  44. buffer = part[b + 1:]
  45. return part_id, k, buffer
  46. def shamir_encode(plain: bytes, k: int, n: int) -> List[bytes]:
  47. plain = pad(plain)
  48. parts = [BytesIO() for _ in range(n)]
  49. for i, part in enumerate(parts):
  50. part.write(f'{i+1}-{k}-'.encode())
  51. for offset in range(0, len(plain), CHUNK_SIZE):
  52. partial_parts = secret_to_parts(plain[offset:offset + CHUNK_SIZE],
  53. k, n)
  54. for part, partial_part in zip(parts, partial_parts):
  55. part.write(partial_part)
  56. return [part.getvalue() for part in parts]
  57. def shamir_decode(parts: List[bytes]) -> bytes:
  58. parts_dic = {}
  59. expected_parts = None
  60. expected_size = None
  61. first_iteration = True
  62. for part_id, cur_expected_parts, buffer in map(part_decode, parts):
  63. assert 1 <= part_id < shamir.PRIME
  64. if first_iteration:
  65. assert len(buffer) > 0
  66. assert len(buffer) % CHUNK_SIZE == 0
  67. expected_parts = cur_expected_parts
  68. expected_size = len(buffer)
  69. first_iteration = False
  70. parts_dic[part_id] = buffer
  71. else:
  72. assert expected_parts == cur_expected_parts
  73. assert expected_size == len(buffer)
  74. if part_id in parts_dic:
  75. assert parts_dic[part_id] == buffer
  76. else:
  77. parts_dic[part_id] = buffer
  78. if len(parts_dic) < expected_parts:
  79. raise InsufficientParts()
  80. buffer = BytesIO()
  81. for i in range(0, expected_size, CHUNK_SIZE):
  82. parts_chunk = [(part_id, part[i:i+CHUNK_SIZE])
  83. for part_id, part in parts_dic.items()]
  84. secret_chunk = parts_to_secret(parts_chunk)
  85. buffer.write(secret_chunk)
  86. secret = buffer.getvalue()
  87. return unpad(secret)