Browse Source

Add segment tree

Marcelo Fornet 4 năm trước cách đây
mục cha
commit
f9515373b2
4 tập tin đã thay đổi với 294 bổ sung1 xóa
  1. 2 1
      .gitignore
  2. 0 0
      disjoint_set.cpp
  3. 132 0
      segment_tree.cpp
  4. 160 0
      segment_tree.hpp

+ 2 - 1
.gitignore

@@ -1,2 +1,3 @@
 .vscode/
-main
+main
+main.*/

main.cpp → disjoint_set.cpp


+ 132 - 0
segment_tree.cpp

@@ -0,0 +1,132 @@
+#include <iostream>
+#include "segment_tree.hpp"
+
+#define endl '\n'
+
+using namespace std;
+
+void example1()
+{
+    cout << "\nExample 1" << endl;
+    // Segment tree that allows setting an element in a range and fetching a single element.
+    auto st = segment_tree<int, int>()
+                  ->with_merge([&](int &a, int &b) {
+                      return 0;
+                  })
+                  ->with_lazy([&](int &a, int &b, int &c) {
+                      if (a)
+                      {
+                          b = a;
+                          c = a;
+                      }
+                  })
+                  ->with_update([&](int &a, int &b) {
+                      a = b;
+                  })
+                  ->init(4);
+
+    st->update(0, 1);
+    cout << "Set value 1 at position 0" << endl;
+
+    st->update(1, 2);
+    cout << "Set value 2 at position 1" << endl;
+
+    cout << "All values" << endl;
+    for (int i = 0; i < 4; ++i)
+        cout << st->query(i) << " ";
+    cout << endl;
+
+    cout << "Set value 4 in the range [1,3)" << endl;
+    st->update(1, 3, 4);
+
+    cout << "All values" << endl;
+    for (int i = 0; i < 4; ++i)
+        cout << st->query(i) << " ";
+    cout << endl;
+}
+
+void example2()
+{
+    cout << "\nExample 2" << endl;
+    // Segment tree of maximum that allows setting adding in a point
+    // It stores a pair on each node:
+    //      - Maximum on the range
+    //      - Lazy to set children
+    auto st = segment_tree<int, int>()
+                  ->with_merge([&](int &a, int &b) {
+                      return max(a, b);
+                  })
+                  ->with_update([&](int &a, int &b) {
+                      a += b;
+                  })
+                  ->init(4);
+
+    st->update(0, 1);
+
+    st->update(1, 5);
+
+    for (int i = 0; i < 4; ++i)
+        cout << st->query(i) << " ";
+    cout << endl;
+
+    st->update(1, 3);
+    st->update(2, 2);
+
+    for (int i = 1; i <= 4; ++i)
+    {
+        for (int j = 0; j + i <= 4; ++j)
+            cout << st->query(j, j + i) << " ";
+        cout << endl;
+    }
+}
+
+void example3()
+{
+    cout << "\nExample 3" << endl;
+    // Segment tree of maximum that allows setting a value on range
+    // It stores a pair on each node:
+    //      - Maximum on the range
+    //      - Lazy to set children
+    typedef pair<int, int> T;
+    auto st = segment_tree<T, int>()
+                  ->with_merge([&](T &a, T &b) {
+                      return make_pair(max(a.first, b.first), 0);
+                  })
+                  ->with_lazy([&](T &a, T &b, T &c) {
+                      if (a.second)
+                      {
+                          b = make_pair(a.second, a.second);
+                          c = make_pair(a.second, a.second);
+                      }
+                  })
+                  ->with_update([&](T &a, int &b) {
+                      a = make_pair(b, b);
+                  })
+                  ->init(4);
+
+    st->update(0, 1);
+
+    st->update(1, 5);
+
+    for (int i = 0; i < 4; ++i)
+        cout << st->query(i).first << " ";
+    cout << endl;
+
+    st->update(1, 4, 3);
+    st->update(2, 4, 2);
+
+    for (int i = 1; i <= 4; ++i)
+    {
+        for (int j = 0; j + i <= 4; ++j)
+            cout << st->query(j, j + i).first << " ";
+        cout << endl;
+    }
+}
+int main()
+{
+    example1();
+    example2();
+    example3();
+
+    return 0;
+}

+ 160 - 0
segment_tree.hpp

@@ -0,0 +1,160 @@
+#include <functional>
+#include <vector>
+
+template <typename T, typename U>
+class SegmentTree
+{
+    int n;
+    std::function<void(T &, U &)> _update;
+    std::function<T(T &, T &)> _merge;
+    std::function<void(T &, T &, T &, int, int, int)> _push;
+
+    void update(int p, int b, int e, int x, int y, U &upd)
+    {
+        if (x <= b && e <= y)
+        {
+            _update(ds[p], upd);
+        }
+        else
+        {
+            int m = (b + e) >> 1, l = p << 1, r = l | 1;
+
+            _push(ds[p], ds[l], ds[r], b, m, e);
+
+            if (x < m)
+                update(l, b, m, x, y, upd);
+
+            if (m < y)
+                update(r, m, e, x, y, upd);
+
+            ds[p] = _merge(ds[l], ds[r]);
+        }
+    }
+
+    T query(int p, int b, int e, int x, int y)
+    {
+        if (x <= b && e <= y)
+        {
+            return ds[p];
+        }
+        else
+        {
+            int m = (b + e) >> 1, l = p << 1, r = l | 1;
+
+            _push(ds[p], ds[l], ds[r], b, m, e);
+
+            if (x < m)
+            {
+                auto le = query(l, b, m, x, y);
+
+                if (m < y)
+                {
+                    auto ri = query(r, m, e, x, y);
+                    return _merge(le, ri);
+                }
+                else
+                {
+                    return le;
+                }
+            }
+            else
+            {
+                return query(r, m, e, x, y);
+            }
+        }
+    }
+
+public:
+    std::vector<T> ds;
+
+    SegmentTree(int n,
+                std::function<void(T &, U &)> update,
+                std::function<T(T &, T &)> merge,
+                std::function<void(T &, T &, T &, int, int, int)> push)
+        : n(n), _update(update), _merge(merge), _push(push)
+    {
+        ds = std::vector<T>(4 * n);
+    }
+
+    // Update on point
+    void update(int x, U upd)
+    {
+        update(1, 0, n, x, x + 1, upd);
+    }
+
+    // Update on range
+    void update(int x, int y, U upd)
+    {
+        update(1, 0, n, x, y, upd);
+    }
+
+    // Query on a point
+    T query(int x)
+    {
+        return query(1, 0, n, x, x + 1);
+    }
+
+    // Query on a range
+    T query(int x, int y)
+    {
+        return query(1, 0, n, x, y);
+    }
+};
+
+template <typename T>
+void noop_push(T &p, T &l, T &r, int b, int m, int e) {}
+
+template <typename T, typename U>
+class SegmentTreeBuilder
+{
+    std::function<void(T &, U &)> _update;
+    std::function<T(T &, T &)> _merge;
+    std::function<void(T &, T &, T &, int, int, int)> _push;
+
+public:
+    SegmentTreeBuilder() : _push(noop_push<T>)
+    {
+    }
+
+    SegmentTree<T, U> *init(int size)
+    {
+        return new SegmentTree<T, U>(size, _update, _merge, _push);
+    }
+
+    SegmentTreeBuilder<T, U> *with_lazy(std::function<void(T &, T &, T &, int, int, int)> push)
+    {
+        _push = push;
+        return this;
+    }
+
+    SegmentTreeBuilder<T, U> *with_lazy(std::function<void(T &, T &, T &)> push)
+    {
+
+        _push = [push](T &p, T &l, T &r, int b, int m, int e) {
+            push(p, l, r);
+        };
+        return this;
+    }
+
+    SegmentTreeBuilder<T, U> *with_update(std::function<void(T &, U &)> update)
+    {
+        _update = update;
+        return this;
+    }
+
+    SegmentTreeBuilder<T, U> *with_merge(std::function<T(T &, T &)> merge)
+    {
+        _merge = merge;
+        return this;
+    }
+};
+
+template <typename T, typename U, typename R = T>
+SegmentTreeBuilder<T, U> *segment_tree()
+{
+    return new SegmentTreeBuilder<T, U>();
+}
+
+// TODO: Allow custom initializer
+// TODO: Use array of length 2 * n (instead of 4 * n)
+// TODO: Use iterative segment tree whenever possible.