node.rs 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. use std::{
  2. cmp::{max, Ordering},
  3. fmt::Debug,
  4. mem::swap,
  5. };
  /// A node of an AVL tree that also answers rank queries.
  ///
  /// Every node caches the height and the number of nodes of the subtree rooted
  /// at it. `insert` keeps both caches and the AVL balance invariant up to date,
  /// which lets [`Node::count_le`] and [`Node::count_ge`] walk a single
  /// root-to-leaf path. Equal values are stored only once.
  ///
  /// The `T: Ord` bound lives on the `impl` block rather than on the struct, so
  /// merely naming `Node<T>` in a type does not require `T: Ord`.
  #[derive(Debug)]
  pub struct Node<T> {
      value: T,
      /// Height of the subtree rooted here; a leaf has height 1.
      height: i32,
      /// Number of nodes in the subtree rooted here, this node included.
      count: i32,
      left: Option<Box<Node<T>>>,
      right: Option<Box<Node<T>>>,
  }

  impl<T: Ord> Node<T> {
      /// Creates a leaf node holding `value` (height 1, count 1).
      pub fn new(value: T) -> Self {
          Node {
              value,
              height: 1,
              count: 1,
              left: None,
              right: None,
          }
      }

      /// Returns the number of values stored in the subtree rooted at this node.
      #[inline]
      pub fn count(&self) -> i32 {
          self.count
      }

      /// Height of the left subtree minus height of the right subtree.
      ///
      /// An absent child counts as height 0.
      fn balance_factor(&self) -> i32 {
          let lheight = self.left.as_ref().map_or(0, |n| n.height);
          let rheight = self.right.as_ref().map_or(0, |n| n.height);
          lheight - rheight
      }

      /// Recomputes the cached `height` and `count` from the children's caches,
      /// which must already be up to date.
      fn update_stats(&mut self) {
          let (lheight, lcount) = self.left.as_ref().map_or((0, 0), |n| (n.height, n.count));
          let (rheight, rcount) = self.right.as_ref().map_or((0, 0), |n| (n.height, n.count));
          self.height = max(lheight, rheight) + 1;
          self.count = lcount + rcount + 1;
      }

      /// Right rotation: the left child becomes the root of this subtree.
      ///
      /// # Panics
      ///
      /// Panics if there is no left child. `balance` only rotates right when the
      /// left side is taller, so a left child always exists there.
      fn rotate_right(&mut self) {
          let mut pivot = self.left.take().expect("rotate_right requires a left child");
          self.left = pivot.right.take();
          self.update_stats();
          // `self` is borrowed from the parent's link, so instead of re-linking
          // the parent we swap node contents: `self` becomes the pivot, and
          // `pivot` now holds the old root, re-attached as the right child.
          swap(self, &mut *pivot);
          self.right = Some(pivot);
          self.update_stats();
      }

      /// Left rotation: the right child becomes the root of this subtree.
      ///
      /// # Panics
      ///
      /// Panics if there is no right child. `balance` only rotates left when the
      /// right side is taller, so a right child always exists there.
      fn rotate_left(&mut self) {
          let mut pivot = self.right.take().expect("rotate_left requires a right child");
          self.right = pivot.left.take();
          self.update_stats();
          // Same content swap as in `rotate_right`, mirrored.
          swap(self, &mut *pivot);
          self.left = Some(pivot);
          self.update_stats();
      }

      /// Restores the AVL invariant at this node, assuming both children are
      /// already balanced and their cached stats are current.
      fn balance(&mut self) {
          let factor = self.balance_factor();
          if factor > 1 {
              let left = self.left.as_mut().expect("left-heavy node has a left child");
              // Left-right case: rotate the child first so that a single right
              // rotation at this node restores balance.
              if left.balance_factor() < 0 {
                  left.rotate_left();
              }
              self.rotate_right();
          } else if factor < -1 {
              let right = self.right.as_mut().expect("right-heavy node has a right child");
              // Right-left case, mirrored.
              if right.balance_factor() > 0 {
                  right.rotate_right();
              }
              self.rotate_left();
          }
      }

      /// Inserts `value` into the subtree rooted at this node and rebalances on
      /// the way back up.
      ///
      /// Returns `true` if the value was added, and `false` if an equal value was
      /// already present (the tree is then left untouched).
      ///
      /// Rotations swap node contents in place, so after an insert `self` may
      /// hold a different value. It is still the root of the same subtree.
      pub fn insert(&mut self, value: T) -> bool {
          let slot = match value.cmp(&self.value) {
              Ordering::Less => &mut self.left,
              Ordering::Greater => &mut self.right,
              Ordering::Equal => return false,
          };
          let inserted = match slot {
              Some(child) => child.insert(value),
              None => {
                  *slot = Some(Box::new(Node::new(value)));
                  true
              }
          };
          // An unchanged subtree keeps its cached stats and its balance, so the
          // bookkeeping is only needed when something was actually added.
          if inserted {
              self.update_stats();
              self.balance();
          }
          inserted
      }

      /// Returns how many values in this subtree are less than or equal to `value`.
      ///
      /// Follows one root-to-leaf path and uses the cached subtree counts to
      /// account for whole subtrees at once.
      pub fn count_le(&self, value: &T) -> i32 {
          let left_count = self.left.as_ref().map_or(0, |n| n.count);
          match value.cmp(&self.value) {
              // This node and its entire right subtree exceed `value`.
              Ordering::Less => self.left.as_ref().map_or(0, |n| n.count_le(value)),
              // The entire left subtree plus this node.
              Ordering::Equal => left_count + 1,
              // The entire left subtree, this node, and part of the right subtree.
              Ordering::Greater => {
                  left_count + 1 + self.right.as_ref().map_or(0, |n| n.count_le(value))
              }
          }
      }

      /// Returns how many values in this subtree are greater than or equal to `value`.
      ///
      /// Mirror image of [`Node::count_le`].
      pub fn count_ge(&self, value: &T) -> i32 {
          let right_count = self.right.as_ref().map_or(0, |n| n.count);
          match value.cmp(&self.value) {
              // This node and its entire left subtree are below `value`.
              Ordering::Greater => self.right.as_ref().map_or(0, |n| n.count_ge(value)),
              // The entire right subtree plus this node.
              Ordering::Equal => right_count + 1,
              // The entire right subtree, this node, and part of the left subtree.
              Ordering::Less => {
                  right_count + 1 + self.left.as_ref().map_or(0, |n| n.count_ge(value))
              }
          }
      }
  }
  #[cfg(test)]
  mod tests {
      use super::Node;

      /// Builds a tree by inserting `values` in order; the first value seeds the root.
      fn build(values: &[i32]) -> Node<i32> {
          let (&first, rest) = values.split_first().expect("at least one value is required");
          let mut root = Node::new(first);
          for &v in rest {
              root.insert(v);
          }
          root
      }

      /// Appends the values of `node`'s subtree to `items` in in-order sequence.
      fn in_order(node: &Node<i32>, items: &mut Vec<i32>) {
          if let Some(ref left) = node.left {
              in_order(left, items);
          }
          items.push(node.value);
          if let Some(ref right) = node.right {
              in_order(right, items);
          }
      }

      /// Checks the AVL balance and the cached `height`/`count` of every node,
      /// and returns the subtree's `(height, count)`.
      fn check_invariants(node: &Node<i32>) -> (i32, i32) {
          let (lheight, lcount) = node.left.as_deref().map_or((0, 0), check_invariants);
          let (rheight, rcount) = node.right.as_deref().map_or((0, 0), check_invariants);
          assert!(node.balance_factor().abs() <= 1, "unbalanced at {}", node.value);
          assert_eq!(node.height, lheight.max(rheight) + 1, "stale height at {}", node.value);
          assert_eq!(node.count, lcount + rcount + 1, "stale count at {}", node.value);
          (node.height, node.count)
      }

      /// Number of entries of `values` satisfying `pred`, as the `i32` the tree reports.
      fn brute_count(values: &[i32], pred: impl Fn(i32) -> bool) -> i32 {
          let n = values.iter().filter(|&&v| pred(v)).count();
          i32::try_from(n).expect("test data is small")
      }

      #[test]
      fn is_sorted() {
          let mut values = vec![5, 3, 7, 1, 4, 6, 9];
          let root = build(&values);
          let mut items = Vec::new();
          in_order(&root, &mut items);
          values.sort();
          assert_eq!(items, values);
      }

      #[test]
      fn is_balanced() {
          // These values form a perfect tree without any rotation; the tests
          // below cover the rebalancing paths.
          let root = build(&[5, 3, 7, 1, 4, 6, 9]);
          check_invariants(&root);
      }

      #[test]
      fn sorted_inserts_trigger_single_rotations() {
          let ascending: Vec<i32> = (0..200).collect();
          let descending: Vec<i32> = ascending.iter().rev().copied().collect();
          for values in [&ascending, &descending].iter() {
              let root = build(values);
              assert_eq!(check_invariants(&root).1, 200);
              let mut items = Vec::new();
              in_order(&root, &mut items);
              assert_eq!(items, ascending);
          }
      }

      #[test]
      fn zig_zag_inserts_trigger_double_rotations() {
          // Left-right case: the middle value becomes the new root.
          let left_right = build(&[30, 10, 20]);
          assert_eq!(left_right.value, 20);
          check_invariants(&left_right);
          // Right-left case, mirrored.
          let right_left = build(&[10, 30, 20]);
          assert_eq!(right_left.value, 20);
          check_invariants(&right_left);
      }

      #[test]
      fn duplicates_are_rejected() {
          let mut root = build(&[5, 3, 7]);
          assert!(root.insert(1));
          assert!(!root.insert(3));
          assert!(!root.insert(5));
          assert_eq!(root.count(), 4);
          check_invariants(&root);
      }

      #[test]
      fn count_le_and_count_ge_match_brute_force() {
          // 7 is coprime with 41, so this visits 0..41 in a scrambled order.
          // Doubling leaves gaps, so the queries also hit values absent from the tree.
          let values: Vec<i32> = (0..41).map(|i| (i * 7) % 41 * 2).collect();
          let root = build(&values);
          check_invariants(&root);
          for q in -2..=84 {
              assert_eq!(root.count_le(&q), brute_count(&values, |v| v <= q), "count_le({})", q);
              assert_eq!(root.count_ge(&q), brute_count(&values, |v| v >= q), "count_ge({})", q);
          }
      }
  }