use std::{
    cmp::{max, Ordering},
    fmt::Debug,
    mem::swap,
};

/// A node of an AVL tree that also records the size of its subtree.
///
/// Values are kept in strictly increasing in-order sequence (left < node <
/// right); [`Node::insert`] rejects duplicates. Because every node stores the
/// number of nodes beneath it, rank-style queries such as [`Node::count_le`]
/// and [`Node::count_ge`] only need to walk a single root-to-leaf path.
#[derive(Debug)]
pub struct Node<T> {
    /// The value stored at this node.
    value: T,
    /// Height of the subtree rooted here; a leaf has height 1.
    height: i32,
    /// Number of nodes in the subtree rooted here, including this node.
    count: i32,
    /// Subtree holding values smaller than `value`.
    left: Option<Box<Node<T>>>,
    /// Subtree holding values greater than `value`.
    right: Option<Box<Node<T>>>,
}

impl<T: Ord + Debug> Node<T> {
    /// Creates a leaf node (height 1, count 1) holding `value`.
    pub fn new(value: T) -> Self {
        Node {
            value,
            height: 1,
            count: 1,
            left: None,
            right: None,
        }
    }

    /// Returns the number of values stored in the subtree rooted at this node.
    #[inline]
    pub fn count(&self) -> i32 {
        self.count
    }

    /// `(height, count)` of an optional subtree; an empty subtree is `(0, 0)`.
    fn stats(child: Option<&Node<T>>) -> (i32, i32) {
        child.map_or((0, 0), |n| (n.height, n.count))
    }

    /// Left subtree height minus right subtree height.
    ///
    /// A node satisfies the AVL invariant when this lies in `-1..=1`.
    fn balance_factor(&self) -> i32 {
        let (lheight, _) = Self::stats(self.left.as_deref());
        let (rheight, _) = Self::stats(self.right.as_deref());
        lheight - rheight
    }

    /// Recomputes `height` and `count` from the children, whose own stats
    /// must already be up to date.
    fn update_stats(&mut self) {
        let (lheight, lcount) = Self::stats(self.left.as_deref());
        let (rheight, rcount) = Self::stats(self.right.as_deref());
        self.height = max(lheight, rheight) + 1;
        self.count = lcount + rcount + 1;
    }

    /// Right rotation: the left child becomes the root of this subtree and the
    /// old root becomes its right child.
    ///
    /// # Panics
    /// Panics if there is no left child; `balance` only rotates toward a
    /// child that exists.
    fn rotate_right(&mut self) {
        let mut pivot = self
            .left
            .take()
            .expect("rotate_right requires a left child");
        // The pivot's right subtree lies between pivot and self, so it moves under self.
        self.left = pivot.right.take();
        self.update_stats();
        // Exchange contents so that `self` (the slot the parent points to)
        // now holds the pivot, and the boxed pivot holds the demoted old root.
        swap(self, &mut *pivot);
        self.right = Some(pivot);
        self.update_stats();
    }

    /// Left rotation: the right child becomes the root of this subtree and the
    /// old root becomes its left child.
    ///
    /// # Panics
    /// Panics if there is no right child; `balance` only rotates toward a
    /// child that exists.
    fn rotate_left(&mut self) {
        let mut pivot = self
            .right
            .take()
            .expect("rotate_left requires a right child");
        // The pivot's left subtree lies between self and pivot, so it moves under self.
        self.right = pivot.left.take();
        self.update_stats();
        swap(self, &mut *pivot);
        self.left = Some(pivot);
        self.update_stats();
    }

    /// Restores the AVL invariant at this node after one of its subtrees grew
    /// by at most one level, using a single or double rotation.
    fn balance(&mut self) {
        let factor = self.balance_factor();
        if factor > 1 {
            // Left-right case: straighten the left child first so that one
            // right rotation suffices.
            if let Some(left) = self.left.as_mut() {
                if left.balance_factor() < 0 {
                    left.rotate_left();
                }
            }
            self.rotate_right();
        } else if factor < -1 {
            // Right-left case: mirror image of the above.
            if let Some(right) = self.right.as_mut() {
                if right.balance_factor() > 0 {
                    right.rotate_right();
                }
            }
            self.rotate_left();
        }
    }

    /// Inserts `value` into this subtree, rebalancing on the way back up.
    ///
    /// Returns `true` if the value was added, or `false` if an equal value was
    /// already present, in which case the tree is left unchanged.
    pub fn insert(&mut self, value: T) -> bool {
        let slot = match value.cmp(&self.value) {
            Ordering::Less => &mut self.left,
            Ordering::Greater => &mut self.right,
            Ordering::Equal => return false,
        };
        let inserted = match slot {
            Some(child) => child.insert(value),
            None => {
                *slot = Some(Box::new(Node::new(value)));
                true
            }
        };
        // A rejected duplicate leaves every subtree untouched, so there is
        // nothing to recompute or rebalance on the way back up.
        if inserted {
            self.update_stats();
            self.balance();
        }
        inserted
    }

    /// Number of nodes in the right subtree, i.e. the values strictly greater
    /// than this node's value.
    fn right_count(&self) -> i32 {
        Self::stats(self.right.as_deref()).1
    }

    /// Number of nodes in the left subtree, i.e. the values strictly smaller
    /// than this node's value.
    fn left_count(&self) -> i32 {
        Self::stats(self.left.as_deref()).1
    }

    /// Returns how many values in this subtree are less than or equal to `value`.
    ///
    /// Follows a single root-to-leaf path, so the cost is proportional to the
    /// tree height.
    pub fn count_le(&self, value: &T) -> i32 {
        // Everything in the left subtree plus this node itself.
        let left_and_self = self.count - self.right_count();
        match value.cmp(&self.value) {
            Ordering::Less => self.left.as_ref().map_or(0, |n| n.count_le(value)),
            Ordering::Equal => left_and_self,
            Ordering::Greater => {
                left_and_self + self.right.as_ref().map_or(0, |n| n.count_le(value))
            }
        }
    }

    /// Returns how many values in this subtree are greater than or equal to `value`.
    ///
    /// Follows a single root-to-leaf path, so the cost is proportional to the
    /// tree height.
    pub fn count_ge(&self, value: &T) -> i32 {
        // Everything in the right subtree plus this node itself.
        let right_and_self = self.count - self.left_count();
        match value.cmp(&self.value) {
            Ordering::Greater => self.right.as_ref().map_or(0, |n| n.count_ge(value)),
            Ordering::Equal => right_and_self,
            Ordering::Less => {
                right_and_self + self.left.as_ref().map_or(0, |n| n.count_ge(value))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::Node;

    /// Builds a tree by inserting `values` in order; the first one becomes the root.
    fn build(values: &[i32]) -> Node<i32> {
        let (&first, rest) = values.split_first().expect("at least one value");
        let mut root = Node::new(first);
        for &v in rest {
            root.insert(v);
        }
        root
    }

    /// Appends the subtree's values to `items` in in-order sequence.
    fn in_order(node: &Node<i32>, items: &mut Vec<i32>) {
        if let Some(left) = node.left.as_deref() {
            in_order(left, items);
        }
        items.push(node.value);
        if let Some(right) = node.right.as_deref() {
            in_order(right, items);
        }
    }

    /// Asserts the AVL balance invariant and the cached `height`/`count` at
    /// every node, returning the recomputed `(height, count)` of `node`.
    fn check_invariants(node: &Node<i32>) -> (i32, i32) {
        let (lh, lc) = node.left.as_deref().map_or((0, 0), check_invariants);
        let (rh, rc) = node.right.as_deref().map_or((0, 0), check_invariants);
        assert!((lh - rh).abs() <= 1, "unbalanced at {}", node.value);
        assert!((node.balance_factor()).abs() <= 1);
        assert_eq!(node.height, lh.max(rh) + 1);
        assert_eq!(node.count, lc + rc + 1);
        (node.height, node.count)
    }

    #[test]
    fn is_sorted() {
        let mut values = vec![5, 3, 7, 1, 4, 6, 9];
        let root = build(&values);
        let mut items = Vec::new();
        in_order(&root, &mut items);
        values.sort_unstable();
        assert_eq!(items, values);
    }

    #[test]
    fn is_balanced() {
        check_invariants(&build(&[5, 3, 7, 1, 4, 6, 9]));
        // Monotonic input forces rotations at every level.
        let ascending: Vec<i32> = (0..64).collect();
        check_invariants(&build(&ascending));
        let descending: Vec<i32> = (0..64).rev().collect();
        check_invariants(&build(&descending));
    }

    #[test]
    fn rejects_duplicates() {
        let mut root = Node::new(1);
        assert!(root.insert(2));
        assert!(!root.insert(2));
        assert!(!root.insert(1));
        assert_eq!(root.count(), 2);
        check_invariants(&root);
    }

    #[test]
    fn counts_match_brute_force() {
        // Distinct values in scrambled order (37 is invertible mod 101).
        let values: Vec<i32> = (0..50).map(|i| (i * 37) % 101).collect();
        let root = build(&values);
        check_invariants(&root);
        for q in -1..=102 {
            let le = values.iter().filter(|&&v| v <= q).count() as i32;
            let ge = values.iter().filter(|&&v| v >= q).count() as i32;
            assert_eq!(root.count_le(&q), le, "count_le({})", q);
            assert_eq!(root.count_ge(&q), ge, "count_ge({})", q);
        }
    }
}