segment_tree.hpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  #include <functional>
  #include <stdexcept>
  #include <utility>
  #include <vector>
  // Recursive segment tree over the half-open index range [0, n).
  //
  // T is the per-node value type and U is the type of an update argument.
  // Nodes live in `ds`: node 1 is the root and node p has children 2p and
  // 2p + 1. All behaviour is supplied by three callbacks:
  //   update(node, upd)           - applied to every node whose segment is
  //                                 fully covered by the updated range.
  //   merge(left, right)          - combines two child values; used to rebuild
  //                                 a parent after an update and to join
  //                                 partial query results.
  //   push(node, left, right,     - called on a node (segment [b, e), split at
  //        b, m, e)                 m) before the recursion descends into its
  //                                 children.
  //
  // All ranges are half-open: [x, y) covers x, x + 1, ..., y - 1.
  template <typename T, typename U>
  class SegmentTree
  {
      int n;
      std::function<void(T &, U &)> _update;
      std::function<T(T &, T &)> _merge;
      std::function<void(T &, T &, T &, int, int, int)> _push;

      // Applies `upd` to the part of [x, y) that lies inside node p's segment
      // [b, e). Precondition (guaranteed by the public entry points): [x, y)
      // is non-empty and overlaps [b, e); without it the recursion would not
      // terminate on a leaf.
      void update_range(int p, int b, int e, int x, int y, U &upd)
      {
          if (x <= b && e <= y)
          {
              _update(ds[p], upd);
              return;
          }
          const int m = (b + e) >> 1;
          const int l = p << 1;
          const int r = l | 1;
          _push(ds[p], ds[l], ds[r], b, m, e);
          if (x < m)
          {
              update_range(l, b, m, x, y, upd);
          }
          if (m < y)
          {
              update_range(r, m, e, x, y, upd);
          }
          ds[p] = _merge(ds[l], ds[r]);
      }

      // Returns the merged value of the part of [x, y) inside node p's
      // segment [b, e). Same precondition as update_range.
      T query_range(int p, int b, int e, int x, int y)
      {
          if (x <= b && e <= y)
          {
              return ds[p];
          }
          const int m = (b + e) >> 1;
          const int l = p << 1;
          const int r = l | 1;
          _push(ds[p], ds[l], ds[r], b, m, e);
          if (!(x < m))
          {
              // Nothing of the range lies left of the split point.
              return query_range(r, m, e, x, y);
          }
          T left = query_range(l, b, m, x, y);
          if (!(m < y))
          {
              return left;
          }
          T right = query_range(r, m, e, x, y);
          return _merge(left, right);
      }

  public:
      // Node storage, 4 * n value-initialised entries; index 0 is unused by
      // the recursion, which starts at node 1.
      std::vector<T> ds;

      // Builds a tree over [0, n) with value-initialised nodes.
      // The callbacks are taken by value and moved into the tree.
      SegmentTree(int n,
                  std::function<void(T &, U &)> update,
                  std::function<T(T &, T &)> merge,
                  std::function<void(T &, T &, T &, int, int, int)> push)
          : n(n), _update(std::move(update)), _merge(std::move(merge)), _push(std::move(push))
      {
          // Multiply in the vector's size type so that 4 * n cannot overflow int.
          typedef typename std::vector<T>::size_type size_type;
          ds = std::vector<T>(static_cast<size_type>(n) * 4);
      }

      // Update on point
      // Throws std::out_of_range when x is not in [0, n).
      void update(int x, U upd)
      {
          if (x < 0 || x >= n)
          {
              throw std::out_of_range("SegmentTree::update: index out of range");
          }
          update_range(1, 0, n, x, x + 1, upd);
      }

      // Update on range
      // The range [x, y) is clamped to [0, n); an empty range is a no-op.
      void update(int x, int y, U upd)
      {
          if (x < 0)
          {
              x = 0;
          }
          if (y > n)
          {
              y = n;
          }
          if (x >= y)
          {
              return;
          }
          update_range(1, 0, n, x, y, upd);
      }

      // Query on a point
      // Throws std::out_of_range when x is not in [0, n).
      T query(int x)
      {
          if (x < 0 || x >= n)
          {
              throw std::out_of_range("SegmentTree::query: index out of range");
          }
          return query_range(1, 0, n, x, x + 1);
      }

      // Query on a range
      // The range [x, y) is clamped to [0, n). Throws std::out_of_range when
      // the clamped range is empty, because no identity value is available
      // to return for it.
      T query(int x, int y)
      {
          if (x < 0)
          {
              x = 0;
          }
          if (y > n)
          {
              y = n;
          }
          if (x >= y)
          {
              throw std::out_of_range("SegmentTree::query: empty range");
          }
          return query_range(1, 0, n, x, y);
      }
  };
  // Push callback that does nothing: every argument is ignored and left
  // unmodified. SegmentTreeBuilder installs it as its default push.
  template <typename T>
  void noop_push(T & /*p*/, T & /*l*/, T & /*r*/, int /*b*/, int /*m*/, int /*e*/)
  {
  }
  89. template <typename T, typename U>
  90. class SegmentTreeBuilder
  91. {
  92. std::function<void(T &, U &)> _update;
  93. std::function<T(T &, T &)> _merge;
  94. std::function<void(T &, T &, T &, int, int, int)> _push;
  95. public:
  96. SegmentTreeBuilder() : _push(noop_push<T>)
  97. {
  98. }
  99. SegmentTree<T, U> *init(int size)
  100. {
  101. return new SegmentTree<T, U>(size, _update, _merge, _push);
  102. }
  103. SegmentTreeBuilder<T, U> *with_lazy(std::function<void(T &, T &, T &, int, int, int)> push)
  104. {
  105. _push = push;
  106. return this;
  107. }
  108. SegmentTreeBuilder<T, U> *with_lazy(std::function<void(T &, T &, T &)> push)
  109. {
  110. _push = [push](T &p, T &l, T &r, int b, int m, int e) {
  111. push(p, l, r);
  112. };
  113. return this;
  114. }
  115. SegmentTreeBuilder<T, U> *with_update(std::function<void(T &, U &)> update)
  116. {
  117. _update = update;
  118. return this;
  119. }
  120. SegmentTreeBuilder<T, U> *with_merge(std::function<T(T &, T &)> merge)
  121. {
  122. _merge = merge;
  123. return this;
  124. }
  125. };
  126. template <typename T, typename U, typename R = T>
  127. SegmentTreeBuilder<T, U> *segment_tree()
  128. {
  129. return new SegmentTreeBuilder<T, U>();
  130. }
  131. // TODO: Allow custom initializer
  132. // TODO: Use array of length 2 * n (instead of 4 * n)
  133. // TODO: Use iterative segment tree whenever possible.