123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- use std::{
- cmp::{max, Ordering},
- fmt::Debug,
- mem::swap,
- };
/// A node of a self-balancing AVL tree. Every node is also the root of its own
/// subtree, so a whole tree is represented by its root `Node`.
///
/// Besides its value, each node caches the height and the node count of the
/// subtree rooted at it. The cached counts let [`Node::count_le`] and
/// [`Node::count_ge`] answer rank queries by following a single root-to-leaf
/// path instead of visiting every node.
///
/// Values are kept unique: inserting a value equal to one already present
/// leaves the tree unchanged.
#[derive(Debug)]
pub struct Node<T> {
    value: T,
    // Height of the subtree rooted here: a leaf is 1, a missing child counts as 0.
    height: i32,
    // Number of nodes in the subtree rooted here, this node included.
    count: i32,
    left: Option<Box<Node<T>>>,
    right: Option<Box<Node<T>>>,
}

// The `Ord` bound lives on the impl rather than on the struct definition, so it
// is only demanded where ordering is actually used.
impl<T: Ord> Node<T> {
    /// Creates a leaf node holding `value`.
    pub fn new(value: T) -> Self {
        Node {
            value,
            height: 1,
            count: 1,
            left: None,
            right: None,
        }
    }

    /// Returns the number of values stored in the subtree rooted at this node.
    #[inline]
    pub fn count(&self) -> i32 {
        self.count
    }

    /// Cached height of an optional child; an absent child has height 0.
    fn height_of(node: Option<&Node<T>>) -> i32 {
        node.map_or(0, |n| n.height)
    }

    /// Cached node count of an optional child; an absent child counts as 0.
    fn count_of(node: Option<&Node<T>>) -> i32 {
        node.map_or(0, |n| n.count)
    }

    /// Height of the left subtree minus height of the right subtree.
    /// Positive means left-heavy; `balance` keeps it within `-1..=1`.
    fn balance_factor(&self) -> i32 {
        Self::height_of(self.left.as_deref()) - Self::height_of(self.right.as_deref())
    }

    /// Recomputes this node's cached `height` and `count` from its children,
    /// whose own caches must already be correct.
    fn update_stats(&mut self) {
        let (left, right) = (self.left.as_deref(), self.right.as_deref());
        let height = max(Self::height_of(left), Self::height_of(right)) + 1;
        let count = Self::count_of(left) + Self::count_of(right) + 1;
        self.height = height;
        self.count = count;
    }

    /// Rotates this subtree to the right, promoting the left child to root:
    ///
    /// ```text
    ///       self              p
    ///      /    \           /   \
    ///     p      c   =>    a    self
    ///    / \                    /  \
    ///   a   b                  b    c
    /// ```
    ///
    /// # Panics
    /// Panics if this node has no left child.
    fn rotate_right(&mut self) {
        let mut pivot = self
            .left
            .take()
            .expect("rotate_right requires a left child");
        // The pivot's right subtree `b` becomes the demoted root's left subtree.
        self.left = pivot.right.take();
        self.update_stats();
        // Exchange contents so that `self` (the slot the parent points at) now
        // holds the pivot, and the pivot's box holds the demoted old root.
        swap(self, &mut *pivot);
        self.right = Some(pivot);
        self.update_stats();
    }

    /// Mirror image of [`Node::rotate_right`]: promotes the right child to root.
    ///
    /// # Panics
    /// Panics if this node has no right child.
    fn rotate_left(&mut self) {
        let mut pivot = self
            .right
            .take()
            .expect("rotate_left requires a right child");
        // The pivot's left subtree becomes the demoted root's right subtree.
        self.right = pivot.left.take();
        self.update_stats();
        swap(self, &mut *pivot);
        self.left = Some(pivot);
        self.update_stats();
    }

    /// Restores the AVL condition at this node with one or two rotations.
    /// Expects both subtrees to be valid AVL trees and this node's cached
    /// stats to be up to date.
    fn balance(&mut self) {
        let factor = self.balance_factor();
        if factor > 1 {
            let left = self
                .left
                .as_mut()
                .expect("a left-heavy node has a left child");
            if left.balance_factor() < 0 {
                // Left-right case: straighten into the left-left case first.
                left.rotate_left();
            }
            self.rotate_right();
        } else if factor < -1 {
            let right = self
                .right
                .as_mut()
                .expect("a right-heavy node has a right child");
            if right.balance_factor() > 0 {
                // Right-left case: straighten into the right-right case first.
                right.rotate_right();
            }
            self.rotate_left();
        }
    }

    /// Inserts `value` into the subtree rooted at this node. Each node on the
    /// insertion path refreshes its stats and rebalances on the way back up.
    ///
    /// Returns `true` if the value was inserted. Returns `false` if an equal
    /// value was already present, in which case the tree is left unchanged.
    pub fn insert(&mut self, value: T) -> bool {
        let child = match value.cmp(&self.value) {
            Ordering::Less => &mut self.left,
            Ordering::Greater => &mut self.right,
            Ordering::Equal => return false,
        };
        let inserted = match child {
            Some(node) => node.insert(value),
            None => {
                *child = Some(Box::new(Node::new(value)));
                true
            }
        };
        // Nothing changed below us on a duplicate, so stats and balance still hold.
        if inserted {
            self.update_stats();
            self.balance();
        }
        inserted
    }

    /// Counts the values in this subtree that are less than or equal to `value`.
    ///
    /// Recurses into at most one child per level; whole subtrees that lie
    /// entirely below `value` are counted via their cached `count`.
    pub fn count_le(&self, value: &T) -> i32 {
        let left_count = Self::count_of(self.left.as_deref());
        match value.cmp(&self.value) {
            Ordering::Less => self.left.as_deref().map_or(0, |n| n.count_le(value)),
            // This node and its entire left subtree are <= value.
            Ordering::Equal => left_count + 1,
            Ordering::Greater => {
                left_count + 1 + self.right.as_deref().map_or(0, |n| n.count_le(value))
            }
        }
    }

    /// Counts the values in this subtree that are greater than or equal to `value`.
    ///
    /// Recurses into at most one child per level; whole subtrees that lie
    /// entirely above `value` are counted via their cached `count`.
    pub fn count_ge(&self, value: &T) -> i32 {
        let right_count = Self::count_of(self.right.as_deref());
        match value.cmp(&self.value) {
            Ordering::Greater => self.right.as_deref().map_or(0, |n| n.count_ge(value)),
            // This node and its entire right subtree are >= value.
            Ordering::Equal => right_count + 1,
            Ordering::Less => {
                right_count + 1 + self.left.as_deref().map_or(0, |n| n.count_ge(value))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::Node;

    /// Builds a tree by inserting `values` in order; the first value becomes
    /// the initial root.
    ///
    /// # Panics
    /// Panics if `values` is empty.
    fn build(values: &[i32]) -> Node<i32> {
        let (first, rest) = values
            .split_first()
            .expect("a tree needs at least one value");
        let mut root = Node::new(*first);
        for &v in rest {
            root.insert(v);
        }
        root
    }

    /// Appends the subtree's values to `items` in in-order (left, node, right) sequence.
    fn in_order(node: &Node<i32>, items: &mut Vec<i32>) {
        if let Some(left) = node.left.as_deref() {
            in_order(left, items);
        }
        items.push(node.value);
        if let Some(right) = node.right.as_deref() {
            in_order(right, items);
        }
    }

    /// Recursively asserts the AVL balance condition and the consistency of the
    /// cached `height` and `count` at every node; returns `(height, count)`.
    fn check_invariants(node: &Node<i32>) -> (i32, i32) {
        let (lh, lc) = node.left.as_deref().map_or((0, 0), check_invariants);
        let (rh, rc) = node.right.as_deref().map_or((0, 0), check_invariants);
        assert!((lh - rh).abs() <= 1, "node {} is unbalanced", node.value);
        assert_eq!(node.height, lh.max(rh) + 1, "stale height at {}", node.value);
        assert_eq!(node.count, lc + rc + 1, "stale count at {}", node.value);
        (node.height, node.count)
    }

    /// Inputs with distinct values. The first sample alone never triggers a
    /// rotation. The ascending and descending runs force single rotations, and
    /// `[3, 1, 2]` / `[1, 3, 2]` force left-right / right-left double rotations.
    fn sample_inputs() -> Vec<Vec<i32>> {
        vec![
            vec![5, 3, 7, 1, 4, 6, 9],
            (0..64).collect(),
            (0..64).rev().collect(),
            vec![3, 1, 2],
            vec![1, 3, 2],
            vec![50, 20, 80, 10, 30, 70, 90, 25, 35, 65, 5, 95, 60],
        ]
    }

    #[test]
    fn is_sorted() {
        for values in sample_inputs() {
            let root = build(&values);
            let mut items = Vec::new();
            in_order(&root, &mut items);
            let mut expected = values;
            expected.sort_unstable();
            assert_eq!(items, expected);
        }
    }

    #[test]
    fn is_balanced() {
        for values in sample_inputs() {
            let root = build(&values);
            let (_, count) = check_invariants(&root);
            // Sample sizes are tiny, so the usize -> i32 cast cannot truncate.
            assert_eq!(count, values.len() as i32);
            assert_eq!(root.count(), count);
        }
    }

    #[test]
    fn double_rotations_promote_middle_value() {
        for values in [[3, 1, 2], [1, 3, 2]] {
            let root = build(&values);
            assert_eq!(root.value, 2);
            assert_eq!(root.left.as_ref().map(|n| n.value), Some(1));
            assert_eq!(root.right.as_ref().map(|n| n.value), Some(3));
        }
    }

    #[test]
    fn duplicate_insert_is_rejected() {
        let mut root = build(&[5, 3, 7]);
        assert!(!root.insert(3));
        assert!(!root.insert(5));
        assert_eq!(root.count(), 3);
        assert!(root.insert(4));
        assert_eq!(root.count(), 4);
        check_invariants(&root);
    }

    #[test]
    fn count_le_and_ge_match_brute_force() {
        for values in sample_inputs() {
            let root = build(&values);
            for q in -1..=100 {
                // Sample sizes are tiny, so the usize -> i32 casts cannot truncate.
                let le = values.iter().filter(|&&v| v <= q).count() as i32;
                let ge = values.iter().filter(|&&v| v >= q).count() as i32;
                assert_eq!(root.count_le(&q), le, "count_le({})", q);
                assert_eq!(root.count_ge(&q), ge, "count_ge({})", q);
            }
        }
    }
}
|