disjoint_set.hpp 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  #pragma once

  #include <stdexcept>
  #include <tuple>
  #include <utility>
  #include <vector>
  /// Disjoint-set (union-find) over the integers 0 .. n-1.
  ///
  /// Representation: for a root r, ds[r] < 0 and -ds[r] is the size of r's set;
  /// for any other element i, ds[i] >= 0 is i's parent. merge() uses union by size
  /// and root() performs path compression.
  ///
  /// Indices are not bounds-checked: callers must pass values in [0, n).
  class DisjointSet
  {
  public:
      /// Parent links / negated set sizes (see the class comment). Writing to it
      /// directly can break the invariants the other methods rely on.
      std::vector<int> ds;

      /// Creates n singleton sets {0}, {1}, ..., {n-1}.
      DisjointSet(int n) : ds(n, -1)
      {
      }

      // Derived implementations (e.g. DisjointSetWithUndo) are allocated and handed out
      // as DisjointSet*, so deleting through a base pointer must be well defined.
      virtual ~DisjointSet() = default;
      // Declaring the destructor suppresses the implicit move operations; keep the class
      // copyable and movable exactly as before (Rule of Five).
      DisjointSet(const DisjointSet &) = default;
      DisjointSet(DisjointSet &&) = default;
      DisjointSet &operator=(const DisjointSet &) = default;
      DisjointSet &operator=(DisjointSet &&) = default;

      /// Returns the representative (root) of a's set and re-points every node on
      /// the path from a directly at that root (path compression).
      virtual int root(int a)
      {
          int r = a;
          while (ds[r] >= 0)
              r = ds[r];
          // Second pass: compress the path walked above.
          while (ds[a] >= 0)
          {
              const int parent = ds[a];
              ds[a] = r;
              a = parent;
          }
          return r;
      }

      /// Unites the sets containing u and v.
      /// @return true if two distinct sets were joined, false if u and v were already together.
      virtual bool merge(int u, int v)
      {
          u = root(u);
          v = root(v);
          if (u == v)
              return false;
          // ds holds negated sizes: make u the smaller set, then hang it under v.
          if (ds[u] < ds[v])
              std::swap(u, v);
          ds[v] += ds[u];
          ds[u] = v;
          return true;
      }

      /// Number of elements in the set containing a.
      virtual int size(int a)
      {
          return -ds[root(a)];
      }

      /// The plain disjoint set keeps no history: always throws a `const char*`.
      virtual void undo()
      {
          throw "Abstract method";
      }
  };

  /// Disjoint set whose merge() calls can be rolled back in LIFO order with undo().
  ///
  /// Path compression is disabled here: a rollback can only restore entries it
  /// recorded, and compression would rewrite arbitrary nodes behind its back.
  /// Union by size alone still keeps trees O(log n) deep.
  class DisjointSetWithUndo : public DisjointSet
  {
      /// (index, previous ds[index]) pairs; every merge() call pushes exactly two entries,
      /// including merges that return false, so each undo() reverts one merge() call.
      std::vector<std::pair<int, int>> history;

  public:
      /// Creates n singleton sets with an empty undo history.
      DisjointSetWithUndo(int n) : DisjointSet(n)
      {
      }

      /// Returns the root of a's set without modifying ds (no path compression).
      int root(int a) override
      {
          while (ds[a] >= 0)
              a = ds[a];
          return a;
      }

      /// Unites the sets containing u and v, recording enough state to undo it.
      bool merge(int u, int v) override
      {
          u = root(u);
          v = root(v);
          // DisjointSet::merge only writes ds at these two roots, so saving them suffices.
          history.push_back({u, ds[u]});
          history.push_back({v, ds[v]});
          return DisjointSet::merge(u, v);
      }

      /// Reverts the most recent merge() call that has not been undone yet.
      /// @throws std::logic_error if there is no merge left to undo.
      void undo() override
      {
          if (history.size() < 2)
              throw std::logic_error("DisjointSetWithUndo::undo: no merge to undo");
          for (int i = 0; i < 2; ++i)
          {
              const std::pair<int, int> entry = history.back();
              history.pop_back();
              ds[entry.first] = entry.second;
          }
      }
  };
  60. /// Disjoint set builder
  61. class DisjointSetBuilder
  62. {
  63. bool has_undo;
  64. public:
  65. DisjointSetBuilder() : has_undo(false)
  66. {
  67. }
  68. DisjointSetBuilder *with_undo()
  69. {
  70. has_undo = true;
  71. return this;
  72. }
  73. DisjointSet *init(int size)
  74. {
  75. if (has_undo)
  76. {
  77. return new DisjointSetWithUndo(size);
  78. }
  79. else
  80. {
  81. return new DisjointSet(size);
  82. }
  83. }
  84. };
  85. DisjointSetBuilder *disjoint_set()
  86. {
  87. return new DisjointSetBuilder();
  88. }