# tests.py
  1. import random
  2. import unittest
  3. import shamir
  4. import core
  5. from error import *
  6. class TestShamir(unittest.TestCase):
  7. def setUp(self):
  8. self.rnd = random.Random(0)
  9. def test_encode_decode(self):
  10. for it in range(3):
  11. for n in range(1, 8):
  12. for k in range(1, n + 1):
  13. secret = self.rnd.randint(0, shamir.PRIME - 1)
  14. parts = shamir.secret_to_parts(secret, k, n)
  15. parts_i = [(i + 1, part) for i, part in enumerate(parts)]
  16. self.rnd.shuffle(parts_i)
  17. for sub in range(1, n + 1):
  18. sub_parts_i = parts_i[:sub]
  19. recovered_secret = shamir.parts_to_secret(sub_parts_i)
  20. if sub >= k:
  21. self.assertEqual(recovered_secret, secret)
  22. else:
  23. self.assertNotEqual(recovered_secret, secret)
  24. def test_full_encode_decode(self):
  25. K = [2, 2, 2, 3, 3, 4, 1, 2, 3, 4, 5, 6, 7]
  26. N = [2, 3, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7]
  27. for it in range(10):
  28. for k, n in zip(K, N):
  29. plain = bytes([self.rnd.getrandbits(8)
  30. for _ in range(self.rnd.randint(1, 100))])
  31. parts = core.shamir_encode(plain, k, n)
  32. self.rnd.shuffle(parts)
  33. for sub in range(1, n + 1):
  34. sub_parts = parts[:sub]
  35. try:
  36. recovered_plain = core.shamir_decode(sub_parts)
  37. self.assertGreaterEqual(sub, k)
  38. self.assertEqual(recovered_plain, plain)
  39. except InsufficientParts:
  40. self.assertLess(sub, k)
  41. if __name__ == "__main__":
  42. unittest.main()