12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import random
- import unittest
- import shamir
- import core
- from error import *
class TestShamir(unittest.TestCase):
    """Round-trip tests for Shamir secret sharing.

    Exercises both the integer-level API in ``shamir`` and the byte-level
    API in ``core``. All test randomness comes from a ``random.Random``
    seeded with 0 in ``setUp``, so every run checks the same cases.
    """

    def setUp(self):
        # Fixed seed so failures are reproducible run to run.
        self.rnd = random.Random(0)

    def test_encode_decode(self):
        """Split a random integer secret into ``n`` parts with threshold ``k``.

        Then try to recover it from every prefix length ``sub`` of a
        shuffled part list. With ``sub >= k`` parts the secret must come
        back exactly. With fewer parts the reconstructed value must differ
        from the secret. That check presumably holds only with overwhelming
        probability rather than with certainty, but the fixed seed makes it
        deterministic here.
        """
        for iteration in range(3):
            for n in range(1, 8):
                for k in range(1, n + 1):
                    secret = self.rnd.randint(0, shamir.PRIME - 1)
                    parts = shamir.secret_to_parts(secret, k, n)
                    # parts_to_secret is fed (index, part) pairs with 1-based indices.
                    indexed_parts = [(i + 1, part) for i, part in enumerate(parts)]
                    self.rnd.shuffle(indexed_parts)
                    for sub in range(1, n + 1):
                        with self.subTest(iteration=iteration, k=k, n=n, sub=sub):
                            recovered_secret = shamir.parts_to_secret(
                                indexed_parts[:sub]
                            )
                            if sub >= k:
                                self.assertEqual(recovered_secret, secret)
                            else:
                                self.assertNotEqual(recovered_secret, secret)

    def test_full_encode_decode(self):
        """Round-trip random byte strings through ``core.shamir_encode``/``shamir_decode``.

        For each ``(k, n)`` pair, a random plaintext of 1-100 bytes is split
        into ``n`` parts, which are then shuffled. Decoding any ``sub >= k``
        of them must return the original bytes. Decoding fewer must raise
        ``InsufficientParts``.
        """
        # (threshold k, total parts n) combinations to exercise.
        cases = [
            (2, 2), (2, 3), (2, 4), (3, 4), (3, 5), (4, 6),
            (1, 7), (2, 7), (3, 7), (4, 7), (5, 7), (6, 7), (7, 7),
        ]
        for iteration in range(10):
            for k, n in cases:
                plain = bytes([self.rnd.getrandbits(8)
                               for _ in range(self.rnd.randint(1, 100))])
                parts = core.shamir_encode(plain, k, n)
                self.rnd.shuffle(parts)
                for sub in range(1, n + 1):
                    with self.subTest(iteration=iteration, k=k, n=n, sub=sub):
                        # Guard only the decode call. Assertions belong in
                        # the else/except branches, not inside the try.
                        try:
                            recovered_plain = core.shamir_decode(parts[:sub])
                        except InsufficientParts:
                            # Refusing to decode is correct only below the threshold.
                            self.assertLess(sub, k)
                        else:
                            # A successful decode must have had enough parts
                            # and must reproduce the plaintext exactly.
                            self.assertGreaterEqual(sub, k)
                            self.assertEqual(recovered_plain, plain)
if __name__ == "__main__":
    # When executed directly as a script, hand control to unittest's CLI
    # runner, which discovers and runs the TestCase classes in this module.
    unittest.main()
|