"""Round-trip tests for the Shamir secret-sharing primitives and the byte-level codec."""

import random
import unittest

import core
import shamir
from error import InsufficientParts


class TestShamir(unittest.TestCase):
    """Encode/decode round trips for ``shamir`` (integer secrets) and ``core`` (byte payloads)."""

    def setUp(self):
        # A fixed seed makes every run reproducible, including the
        # "fewer than k parts" checks whose outcome depends on the random data.
        self.rnd = random.Random(0)

    def test_encode_decode(self):
        """Any ``sub >= k`` of the n integer shares recover the secret; fewer do not."""
        for _ in range(3):
            for n in range(1, 8):
                for k in range(1, n + 1):
                    secret = self.rnd.randint(0, shamir.PRIME - 1)
                    parts = shamir.secret_to_parts(secret, k, n)
                    # Tag each share with its 1-based position before shuffling,
                    # since parts_to_secret consumes (index, part) pairs.
                    indexed_parts = list(enumerate(parts, start=1))
                    self.rnd.shuffle(indexed_parts)
                    for sub in range(1, n + 1):
                        with self.subTest(n=n, k=k, sub=sub):
                            recovered = shamir.parts_to_secret(indexed_parts[:sub])
                            if sub >= k:
                                self.assertEqual(recovered, secret)
                            else:
                                # With too few shares a coincidental match is
                                # unlikely but not impossible; the fixed seed
                                # keeps this check deterministic.
                                self.assertNotEqual(recovered, secret)

    def test_full_encode_decode(self):
        """Byte payloads survive a k-of-n round trip; too few parts raise InsufficientParts."""
        # (k, n) threshold configurations to exercise.
        configs = [
            (2, 2), (2, 3), (2, 4), (3, 4), (3, 5), (4, 6),
            (1, 7), (2, 7), (3, 7), (4, 7), (5, 7), (6, 7), (7, 7),
        ]
        for _ in range(10):
            for k, n in configs:
                length = self.rnd.randint(1, 100)
                plain = bytes(self.rnd.getrandbits(8) for _ in range(length))
                parts = core.shamir_encode(plain, k, n)
                self.rnd.shuffle(parts)
                for sub in range(1, n + 1):
                    with self.subTest(k=k, n=n, sub=sub):
                        # Guard only the decode call so an assertion failure is
                        # never confused with the expected exception path.
                        try:
                            recovered = core.shamir_decode(parts[:sub])
                        except InsufficientParts:
                            self.assertLess(sub, k)
                        else:
                            self.assertGreaterEqual(sub, k)
                            self.assertEqual(recovered, plain)


if __name__ == "__main__":
    unittest.main()