Quellcode durchsuchen

Shamir's Secret Sharing

Marcelo Fornet vor 4 Jahren
Commit
6183cf2ba0
7 geänderte Dateien mit 307 neuen und 0 gelöschten Zeilen
  1. 2 0
      .gitignore
  2. 120 0
      core.py
  3. 2 0
      error.py
  4. 51 0
      main.py
  5. 57 0
      shamir.py
  6. 52 0
      tests.py
  7. 23 0
      utils.py

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+__pycache__/
+.vscode/

+ 120 - 0
core.py

@@ -0,0 +1,120 @@
+from io import BytesIO
+from typing import List, Tuple
+
+import shamir
+import utils
+from error import *
+
+# Size in bytes of every secret chunk and of every per-chunk share buffer.
+CHUNK_SIZE = 64
+
+
+def pad(plain: bytes) -> bytes:
+    """
+    Extend `plain` so its length is a non-zero multiple of CHUNK_SIZE.
+
+    Each padding byte stores the number of padding bytes that were added.
+    Input whose length is already aligned receives one full extra chunk.
+
+    @return: `plain` followed by the padding bytes
+    """
+    missing = CHUNK_SIZE - len(plain) % CHUNK_SIZE
+    filler = bytes([missing]) * missing
+    return plain + filler
+
+
+def unpad(padded: bytes) -> bytes:
+    """
+    Strip trailing padding whose last byte states the padding length.
+
+    The final `rem` bytes must all be equal to `rem`; empty input, a
+    length byte of zero or an inconsistent tail fail the assertions.
+
+    @return: `padded` without its trailing padding bytes
+    """
+    assert len(padded) > 0, f"Invalid padding. Length: {len(padded)}"
+    rem = padded[-1]
+    expected_tail = bytes([rem]) * rem
+    # For rem == 0 the slice below is the whole buffer, which never equals
+    # the empty expected tail, so a zero length byte is rejected as well.
+    assert padded[-rem:] == expected_tail, \
+        f"Invalid padding. Remainder: {rem} Pad: {list(padded[-rem:])}"
+    return padded[:-rem]
+
+
+def secret_to_parts(secret: bytes, k: int, n: int) -> List[bytes]:
+    """
+    Split one CHUNK_SIZE-byte secret into `n` byte shares.
+
+    The chunk is converted to an integer, shared by
+    `shamir.secret_to_parts` and every share is serialised back to
+    CHUNK_SIZE bytes.
+
+    @return: List with `n` buffers; position i holds the share of part i + 1
+    """
+    assert len(secret) == CHUNK_SIZE
+    assert 0 < k <= n
+    numeric_shares = shamir.secret_to_parts(
+        utils.bytes_to_long(secret, CHUNK_SIZE), k, n)
+    encoded = []
+    for share in numeric_shares:
+        encoded.append(utils.long_to_bytes(share, CHUNK_SIZE))
+    return encoded
+
+
+def parts_to_secret(parts: List[Tuple[int, bytes]]) -> bytes:
+    """
+    Rebuild one CHUNK_SIZE-byte secret chunk from (part id, share) pairs.
+
+    Part ids must be pairwise distinct and every share buffer must be
+    exactly CHUNK_SIZE bytes long.
+
+    @return: The recovered chunk as CHUNK_SIZE bytes
+    """
+    seen_ids = {part_id for part_id, _ in parts}
+    assert len(seen_ids) == len(parts)
+    assert all(len(buffer) == CHUNK_SIZE for _, buffer in parts)
+    numeric_parts = []
+    for part_id, buffer in parts:
+        numeric_parts.append(
+            (part_id, utils.bytes_to_long(buffer, CHUNK_SIZE)))
+    recovered = shamir.parts_to_secret(numeric_parts)
+    return utils.long_to_bytes(recovered, CHUNK_SIZE)
+
+
+def part_decode(part: bytes) -> Tuple[int, int, bytes]:
+    """
+    Parse an encoded part of the form b'<part id>-<k>-<payload>'.
+
+    Only the first two dashes act as separators, so the payload may
+    contain any byte value, including further dashes.
+
+    @return: Part Id, Expected num parts (k), Sequence of bytes
+    """
+    fields = part.split(b'-', 2)
+    # Fewer than three fields means fewer than two separators were found.
+    assert len(fields) == 3
+
+    raw_id, raw_k, buffer = fields
+    part_id = int(raw_id.decode())
+    k = int(raw_k.decode())
+
+    return part_id, k, buffer
+
+
+def shamir_encode(plain: bytes, k: int, n: int) -> List[bytes]:
+    """
+    Split `plain` into `n` encoded parts with threshold `k`.
+
+    The input is padded and shared chunk by chunk. Every encoded part
+    starts with the header b'<part id>-<k>-' (part ids start at 1),
+    followed by that part's share of each chunk in order.
+
+    @return: List with the `n` encoded parts, ordered by part id
+    """
+    padded = pad(plain)
+
+    # One list of byte fragments per part, seeded with the part header.
+    fragments = [[f'{i+1}-{k}-'.encode()] for i in range(n)]
+
+    for start in range(0, len(padded), CHUNK_SIZE):
+        chunk = padded[start:start + CHUNK_SIZE]
+        for pieces, share in zip(fragments, secret_to_parts(chunk, k, n)):
+            pieces.append(share)
+
+    return [b''.join(pieces) for pieces in fragments]
+
+
+def shamir_decode(parts: List[bytes]) -> bytes:
+    """
+    Recover the original data from encoded parts produced by
+    `shamir_encode`.
+
+    All parts must declare the same threshold and carry payloads of the
+    same size, which must be a non-zero multiple of CHUNK_SIZE. A part id
+    may appear more than once only if the repeated payload is identical.
+
+    Raises InsufficientParts when no parts are given or when fewer
+    distinct part ids are present than the declared threshold. Malformed
+    or inconsistent parts fail an assertion.
+
+    @return: The recovered data with the padding removed
+    """
+    parts_dic = {}
+    expected_parts = None
+    expected_size = None
+
+    for part_id, cur_expected_parts, buffer in map(part_decode, parts):
+        assert 1 <= part_id < shamir.PRIME
+
+        if expected_parts is None:
+            # The first part fixes the threshold and the payload size
+            # that every following part has to match.
+            assert len(buffer) > 0
+            assert len(buffer) % CHUNK_SIZE == 0
+            expected_parts = cur_expected_parts
+            expected_size = len(buffer)
+        else:
+            assert expected_parts == cur_expected_parts
+            assert expected_size == len(buffer)
+
+        if part_id in parts_dic:
+            assert parts_dic[part_id] == buffer
+        else:
+            parts_dic[part_id] = buffer
+
+    # Without any part there is no threshold to compare against; report
+    # it as insufficient parts instead of comparing an int with None.
+    if expected_parts is None or len(parts_dic) < expected_parts:
+        raise InsufficientParts()
+
+    buffer = BytesIO()
+
+    for i in range(0, expected_size, CHUNK_SIZE):
+        parts_chunk = [(part_id, part[i:i+CHUNK_SIZE])
+                       for part_id, part in parts_dic.items()]
+        secret_chunk = parts_to_secret(parts_chunk)
+        buffer.write(secret_chunk)
+
+    secret = buffer.getvalue()
+    return unpad(secret)

+ 2 - 0
error.py

@@ -0,0 +1,2 @@
+class InsufficientParts(Exception):
+    """Raised when fewer parts are available than needed for recovery."""

+ 51 - 0
main.py

@@ -0,0 +1,51 @@
+import argparse
+from os.path import join
+
+from core import shamir_decode, shamir_encode
+
+
+def main():
+    """
+    Command line entry point.
+
+    `split` reads a file, shares it with `shamir_encode` and writes one
+    '<file name>.<part id>.shamir' file per part into the output
+    directory. `recover` is declared but not implemented yet and reports
+    that to the user. Without a sub-command the help text is printed.
+    """
+    # Imported here because the module only imports `join` from os.path.
+    from os.path import basename
+
+    # The title is the description; the first positional parameter of
+    # ArgumentParser is the program name shown in the usage line.
+    parser = argparse.ArgumentParser(description="Shamir's Secret Sharing")
+    parser.set_defaults(mode='None')
+
+    subparsers = parser.add_subparsers()
+
+    split = subparsers.add_parser("split")
+    split.add_argument("file", help="File to split into multiple files")
+    split.add_argument("-n", type=int, required=True,
+                       help="Number of parts to split the file.")
+    split.add_argument("-k", type=int, required=True,
+                       help="Expected number of parts needed to recover the original file.")
+    split.add_argument("-o", "--output", default=".",
+                       help="Output directory of all the files.")
+    split.set_defaults(mode='split')
+
+    recover = subparsers.add_parser("recover")
+    recover.add_argument("directory")
+    recover.add_argument("-p", "--prefix",
+                         help="Read parts from files in directory with this prefix.")
+    recover.add_argument("-o", "--output",
+                         help="Output file to save the recovered file.")
+    recover.set_defaults(mode='recover')
+
+    args = parser.parse_args()
+
+    if args.mode == 'split':
+        # Report an invalid threshold as a usage error instead of letting
+        # it surface as a bare assertion failure from the encoder.
+        if not 0 < args.k <= args.n:
+            parser.error("-k must satisfy 0 < k <= n")
+
+        with open(args.file, 'rb') as f:
+            data = f.read()
+        parts = shamir_encode(data, args.k, args.n)
+
+        # Only the file name is used for the part files: joining the
+        # output directory with an absolute input path would discard the
+        # output directory altogether.
+        base_name = basename(args.file)
+        for part_id, part in enumerate(parts, start=1):
+            path = join(args.output, f'{base_name}.{part_id}.shamir')
+            with open(path, 'wb') as f:
+                f.write(part)
+
+    elif args.mode == 'recover':
+        # Fail loudly rather than exiting successfully without output.
+        parser.error("recover is not implemented yet")
+    else:
+        parser.print_help()
+
+
+# Run the command line interface only when executed as a script.
+if __name__ == '__main__':
+    main()

+ 57 - 0
shamir.py

@@ -0,0 +1,57 @@
+from typing import List, Tuple
+
+import utils
+
+# Exclusive upper bound for secrets and shares: values below it fit in
+# 64 bytes.
+UPPER = (2**8)**64
+# Modulus for all share arithmetic, slightly larger than UPPER.
+# NOTE(review): assumed to be prime, as `inverse` depends on it -- confirm.
+PRIME = (2**8)**64 + 75
+
+
+def inverse(num: int) -> int:
+    """
+    Return `num` raised to the power PRIME - 2, modulo PRIME.
+
+    NOTE(review): this is the multiplicative inverse of `num` only if
+    PRIME is prime and `num` is not a multiple of it -- confirm. For a
+    multiple of PRIME the result is 0.
+    """
+    exponent = PRIME - 2
+    return pow(num, exponent, PRIME)
+
+
+def mod(num: int) -> int:
+    """
+    Reduce `num` into the range [0, PRIME).
+    """
+    _, remainder = divmod(num, PRIME)
+    return remainder
+
+
+def rand_coefficient() -> int:
+    """
+    Draw a random polynomial coefficient from [0, PRIME).
+
+    The value comes from the `secrets` module, i.e. the operating
+    system's cryptographic random source, so it works on platforms
+    without /dev/random and covers the whole range of residues instead
+    of only the values that fit in 64 bytes.
+    """
+    # Imported here because the module does not import `secrets` itself.
+    import secrets
+
+    return secrets.randbelow(PRIME)
+
+
+def generate_polynomial(num_coeff: int, intercept: int) -> List[int]:
+    """
+    Build a polynomial with `num_coeff` coefficients.
+
+    The first `num_coeff - 1` coefficients are random and `intercept` is
+    stored last.
+
+    @return: List of coefficients ending with `intercept`
+    """
+    assert num_coeff >= 1
+    coefficients = []
+    for _ in range(num_coeff - 1):
+        coefficients.append(rand_coefficient())
+    coefficients.append(intercept)
+    return coefficients
+
+
+def eval_polynomial(polynomial: List[int], point: int):
+    """
+    Evaluate `polynomial` at `point` modulo PRIME using Horner's rule.
+
+    Coefficients are consumed from the highest degree down, so the last
+    list entry is the constant term.
+
+    @return: Value of the polynomial at `point`, in [0, PRIME)
+    """
+    accumulator = 0
+    for coefficient in polynomial:
+        accumulator *= point
+        accumulator += coefficient
+        accumulator %= PRIME
+    return accumulator
+
+
+def secret_to_parts(secret: int, k: int, n: int) -> List[int]:
+    """
+    Produce `n` numeric shares of `secret`.
+
+    A polynomial with `k` coefficients and `secret` as constant term is
+    evaluated at the points 1..n.
+
+    @return: List where position i holds the share for point i + 1
+    """
+    assert 0 <= secret < UPPER
+    assert 0 < k <= n
+
+    points = range(1, n + 1)
+    while True:
+        polynomial = generate_polynomial(k, secret)
+        parts = [eval_polynomial(polynomial, x) for x in points]
+        # Every share has to fit in 64 bytes; if the largest one does
+        # not, start over with a fresh polynomial.
+        if max(parts) < UPPER:
+            return parts
+
+
+def parts_to_secret(parts: List[Tuple[int, int]]) -> int:
+    """
+    Combine (part id, share) pairs into the value at point 0.
+
+    Each share is weighted by the product of (-other id) divided by
+    (own id - other id) over all other part ids, everything modulo PRIME,
+    and the weighted shares are summed up.
+
+    @return: The combined value in [0, PRIME); 0 for an empty list
+    """
+    total = 0
+    for x_i, y_i in parts:
+        numerator = 1
+        denominator = 1
+        for x_j, _ in parts:
+            if x_j != x_i:
+                numerator = mod(numerator * mod(-x_j))
+                denominator = mod(denominator * mod(x_i - x_j))
+        # Inverting the accumulated denominator once gives the same
+        # residue as multiplying by the inverse of every single factor.
+        term = mod(mod(y_i * numerator) * inverse(denominator))
+        total = mod(total + term)
+    return total

+ 52 - 0
tests.py

@@ -0,0 +1,52 @@
+import random
+import unittest
+
+import shamir
+import core
+from error import *
+
+
+class TestShamir(unittest.TestCase):
+    """
+    Round-trip tests for the numeric sharing layer and the byte codec.
+    """
+
+    def setUp(self):
+        # Fixed seed so that every run exercises the same inputs.
+        self.rnd = random.Random(0)
+
+    def test_encode_decode(self):
+        """
+        Any `k` or more numeric shares recover the secret; fewer do not.
+        """
+        for _ in range(3):
+            for n in range(1, 8):
+                for k in range(1, n + 1):
+                    # Secrets have to be below UPPER: secret_to_parts
+                    # asserts that, so drawing up to PRIME - 1 could
+                    # produce a value it rejects.
+                    secret = self.rnd.randint(0, shamir.UPPER - 1)
+                    parts = shamir.secret_to_parts(secret, k, n)
+                    parts_i = [(i + 1, part) for i, part in enumerate(parts)]
+                    self.rnd.shuffle(parts_i)
+
+                    for sub in range(1, n + 1):
+                        sub_parts_i = parts_i[:sub]
+                        recovered_secret = shamir.parts_to_secret(sub_parts_i)
+
+                        if sub >= k:
+                            self.assertEqual(recovered_secret, secret)
+                        else:
+                            self.assertNotEqual(recovered_secret, secret)
+
+    def test_full_encode_decode(self):
+        """
+        Encoded parts decode to the plain bytes once `k` are supplied and
+        raise InsufficientParts below that.
+        """
+        K = [2, 2, 2, 3, 3, 4, 1, 2, 3, 4, 5, 6, 7]
+        N = [2, 3, 4, 4, 5, 6, 7, 7, 7, 4, 5, 6, 7]
+
+        for _ in range(10):
+            for k, n in zip(K, N):
+                plain = bytes([self.rnd.getrandbits(8)
+                               for _ in range(self.rnd.randint(1, 100))])
+                parts = core.shamir_encode(plain, k, n)
+                self.rnd.shuffle(parts)
+                for sub in range(1, n + 1):
+                    sub_parts = parts[:sub]
+                    try:
+                        recovered_plain = core.shamir_decode(sub_parts)
+                        self.assertGreaterEqual(sub, k)
+                        self.assertEqual(recovered_plain, plain)
+                    except InsufficientParts:
+                        self.assertLess(sub, k)
+
+
+# Run the test suite when this file is executed as a script.
+if __name__ == "__main__":
+    unittest.main()

+ 23 - 0
utils.py

@@ -0,0 +1,23 @@
+from typing import Optional
+
+
+def bytes_to_long(buffer: bytes, buffer_size: Optional[int] = None) -> int:
+    """
+    Interpret `buffer` as an unsigned big-endian integer.
+
+    When `buffer_size` is given, the buffer must have exactly that
+    length. An empty buffer yields 0.
+
+    @return: The integer value of `buffer`
+    """
+    assert buffer_size is None or len(buffer) == buffer_size
+
+    value = 0
+    for octet in buffer:
+        # Shift the previous digits up by one byte and append the new one.
+        value = (value << 8) | octet
+    return value
+
+
+def long_to_bytes(num: int, buffer_size: int) -> bytes:
+    """
+    Serialise a non-negative integer as `buffer_size` big-endian bytes.
+
+    Smaller values are left-padded with zero bytes. Only the lowest
+    `buffer_size` bytes of `num` are kept, so larger values are cut off
+    without an error.
+
+    @return: Buffer of `buffer_size` bytes
+    """
+    assert 0 <= num
+
+    # Collect the digits least significant first, then flip the order.
+    little_endian = []
+    for _ in range(buffer_size):
+        little_endian.append(num & 255)
+        num >>= 8
+
+    return bytes(reversed(little_endian))