123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
#include <utility>
#include <vector>
/// Disjoint-set (union-find) forest over the elements 0 .. n-1.
///
/// Encoding of `ds`: a non-negative `ds[i]` is the parent of `i`; a negative
/// `ds[i]` marks `i` as the root of its set, and `-ds[i]` is that set's size.
class DisjointSet
{
public:
    /// Parent links / negated set sizes (see the class comment).
    std::vector<int> ds;

    /// Creates `n` singleton sets (every element is its own root of size 1).
    DisjointSet(int n) : ds(n, -1)
    {
    }

    /// Virtual because instances are used through `DisjointSet *`; without it,
    /// deleting a derived object through a base pointer is undefined behaviour.
    virtual ~DisjointSet() = default;

    // Declaring the destructor suppresses the implicit move operations, so the
    // copy/move set is defaulted explicitly to keep the class as cheap to pass
    // around as before.
    DisjointSet(const DisjointSet &) = default;
    DisjointSet(DisjointSet &&) = default;
    DisjointSet &operator=(const DisjointSet &) = default;
    DisjointSet &operator=(DisjointSet &&) = default;

    /// Returns the representative of the set containing `a`.
    ///
    /// Every visited element is re-linked directly to the representative
    /// (path compression), so this call mutates `ds`.
    virtual int root(int a)
    {
        if (ds[a] < 0)
            return a;
        const int representative = root(ds[a]);
        ds[a] = representative;
        return representative;
    }

    /// Unites the sets containing `u` and `v`.
    ///
    /// @return true if two distinct sets were joined, false if `u` and `v`
    ///         already belonged to the same set.
    virtual bool merge(int u, int v)
    {
        u = root(u), v = root(v);
        if (u == v)
            return false;
        // Sizes are stored negated, so the smaller value is the larger set.
        // After the swap `v` is the larger (or equal) set and becomes the root.
        if (ds[u] < ds[v])
            std::swap(u, v);
        ds[v] += ds[u];
        ds[u] = v;
        return true;
    }

    /// Returns the number of elements in the set containing `a`.
    virtual int size(int a)
    {
        return -ds[root(a)];
    }

    /// Not supported by the plain structure: always throws a string literal
    /// (`const char *`). DisjointSetWithUndo provides the real implementation.
    virtual void undo()
    {
        throw "Abstract method";
    }
};

/// Disjoint set whose merges can be rolled back in LIFO order.
///
/// Each call to merge() records two (index, previous value) entries, one per
/// root involved, and each call to undo() restores the two most recent ones.
/// This holds even when merge() returned false, so merge()/undo() calls always
/// pair up one-to-one.
class DisjointSetWithUndo : public DisjointSet
{
    /// Saved `(index, old ds[index])` pairs, two per merge() call.
    std::vector<std::pair<int, int>> history;

public:
    /// Creates `n` singleton sets with an empty history.
    DisjointSetWithUndo(int n) : DisjointSet(n)
    {
    }

    /// Returns the representative of the set containing `a` WITHOUT path
    /// compression.
    ///
    /// Compression would rewrite parent links that are never recorded in
    /// `history`, which would leave stale links behind after undo().
    int root(int a) override
    {
        while (ds[a] >= 0)
            a = ds[a];
        return a;
    }

    /// Unites the sets containing `u` and `v`, remembering how to revert it.
    ///
    /// @return true if two distinct sets were joined, false otherwise.
    bool merge(int u, int v) override
    {
        // The base merge only modifies the two roots' slots, so those are the
        // values that must be saved -- not the slots of the raw arguments.
        u = root(u);
        v = root(v);
        history.push_back({u, ds[u]});
        history.push_back({v, ds[v]});
        return DisjointSet::merge(u, v);
    }

    /// Reverts the most recent merge() that has not been undone yet.
    /// Does nothing when there is no recorded merge left.
    void undo() override
    {
        if (history.size() < 2)
            return;
        for (int i = 0; i < 2; ++i)
        {
            const std::pair<int, int> saved = history.back();
            history.pop_back();
            ds[saved.first] = saved.second;
        }
    }
};
- /// Disjoint set builder
- class DisjointSetBuilder
- {
- bool has_undo;
- public:
- DisjointSetBuilder() : has_undo(false)
- {
- }
- DisjointSetBuilder *with_undo()
- {
- has_undo = true;
- return this;
- }
- DisjointSet *init(int size)
- {
- if (has_undo)
- {
- return new DisjointSetWithUndo(size);
- }
- else
- {
- return new DisjointSet(size);
- }
- }
- };
- DisjointSetBuilder *disjoint_set()
- {
- return new DisjointSetBuilder();
- }
|