|
@@ -8,6 +8,7 @@ use std::{
|
|
|
pub struct Node<T: Ord> {
|
|
|
value: T,
|
|
|
height: i32,
|
|
|
+ count: i32,
|
|
|
left: Option<Box<Node<T>>>,
|
|
|
right: Option<Box<Node<T>>>,
|
|
|
}
|
|
@@ -17,39 +18,47 @@ impl<T: Ord> Node<T> {
|
|
|
Node {
|
|
|
value,
|
|
|
height: 1,
|
|
|
+ count: 1,
|
|
|
left: None,
|
|
|
right: None,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ #[inline]
|
|
|
+ pub fn count(&self) -> i32 {
|
|
|
+ self.count
|
|
|
+ }
|
|
|
+
|
|
|
fn balance_factor(&self) -> i32 {
|
|
|
let lheight = self.left.as_ref().map_or(0, |n| n.height);
|
|
|
let rheight = self.right.as_ref().map_or(0, |n| n.height);
|
|
|
lheight - rheight
|
|
|
}
|
|
|
|
|
|
- fn update_height(&mut self) {
|
|
|
- let lheight = self.left.as_ref().map_or(0, |n| n.height);
|
|
|
- let rheight = self.right.as_ref().map_or(0, |n| n.height);
|
|
|
+ fn update_stats(&mut self) {
|
|
|
+ let (lheight, lcount) = self.left.as_ref().map_or((0, 0), |n| (n.height, n.count));
|
|
|
+ let (rheight, rcount) = self.right.as_ref().map_or((0, 0), |n| (n.height, n.count));
|
|
|
+
|
|
|
self.height = max(lheight, rheight) + 1;
|
|
|
+ self.count = lcount + rcount + 1;
|
|
|
}
|
|
|
|
|
|
fn rotate_right(&mut self) {
|
|
|
let mut x = self.left.take().unwrap();
|
|
|
self.left = x.right.take();
|
|
|
- self.update_height();
|
|
|
+ self.update_stats();
|
|
|
swap(self, &mut x);
|
|
|
self.right = Some(x);
|
|
|
- self.update_height();
|
|
|
+ self.update_stats();
|
|
|
}
|
|
|
|
|
|
fn rotate_left(&mut self) {
|
|
|
let mut x = self.right.take().unwrap();
|
|
|
self.right = x.left.take();
|
|
|
- self.update_height();
|
|
|
+ self.update_stats();
|
|
|
swap(self, &mut x);
|
|
|
self.left = Some(x);
|
|
|
- self.update_height();
|
|
|
+ self.update_stats();
|
|
|
}
|
|
|
|
|
|
fn balance(&mut self) {
|
|
@@ -91,15 +100,43 @@ impl<T: Ord> Node<T> {
|
|
|
Ordering::Equal => false,
|
|
|
};
|
|
|
|
|
|
- self.update_height();
|
|
|
+ self.update_stats();
|
|
|
self.balance();
|
|
|
ret_val
|
|
|
}
|
|
|
+
|
|
|
+ pub fn count_le(&self, value: &T) -> i32 {
|
|
|
+ match value.cmp(&self.value) {
|
|
|
+ Ordering::Less => self.left.as_ref().map_or(0, |n| n.count_le(value)),
|
|
|
+ Ordering::Greater => {
|
|
|
+ self.count
|
|
|
+ + self
|
|
|
+ .right
|
|
|
+ .as_ref()
|
|
|
+ .map_or(0, |n| n.count_le(value) - n.count)
|
|
|
+ }
|
|
|
+ Ordering::Equal => self.count - self.right.as_ref().map_or(0, |n| n.count),
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn count_ge(&self, value: &T) -> i32 {
|
|
|
+ match value.cmp(&self.value) {
|
|
|
+ Ordering::Less => {
|
|
|
+ self.count
|
|
|
+ + self
|
|
|
+ .left
|
|
|
+ .as_ref()
|
|
|
+ .map_or(0, |n| n.count_ge(value) - n.count)
|
|
|
+ }
|
|
|
+ Ordering::Greater => self.right.as_ref().map_or(0, |n| n.count_ge(value)),
|
|
|
+ Ordering::Equal => self.count - self.left.as_ref().map_or(0, |n| n.count),
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#[cfg(test)]
|
|
|
mod tests {
|
|
|
- use super::*;
|
|
|
+ use super::Node;
|
|
|
|
|
|
#[test]
|
|
|
fn is_sorted() {
|