Prechádzať zdrojové kódy

feat(count): add logic for counting values

Danilo Gómez 1 rok pred
rodič
commit
14637a17ce
2 zmenil súbory, kde vykonal 83 pridanie a 9 odobranie
  1. 46 9
      src/node.rs
  2. 37 0
      src/set.rs

+ 46 - 9
src/node.rs

@@ -8,6 +8,7 @@ use std::{
 pub struct Node<T: Ord> {
     value: T,
     height: i32,
+    count: i32,
     left: Option<Box<Node<T>>>,
     right: Option<Box<Node<T>>>,
 }
@@ -17,39 +18,47 @@ impl<T: Ord> Node<T> {
         Node {
             value,
             height: 1,
+            count: 1,
             left: None,
             right: None,
         }
     }
 
+    #[inline]
+    pub fn count(&self) -> i32 {
+        self.count
+    }
+
     fn balance_factor(&self) -> i32 {
         let lheight = self.left.as_ref().map_or(0, |n| n.height);
         let rheight = self.right.as_ref().map_or(0, |n| n.height);
         lheight - rheight
     }
 
-    fn update_height(&mut self) {
-        let lheight = self.left.as_ref().map_or(0, |n| n.height);
-        let rheight = self.right.as_ref().map_or(0, |n| n.height);
+    fn update_stats(&mut self) {
+        let (lheight, lcount) = self.left.as_ref().map_or((0, 0), |n| (n.height, n.count));
+        let (rheight, rcount) = self.right.as_ref().map_or((0, 0), |n| (n.height, n.count));
+
         self.height = max(lheight, rheight) + 1;
+        self.count = lcount + rcount + 1;
     }
 
     fn rotate_right(&mut self) {
         let mut x = self.left.take().unwrap();
         self.left = x.right.take();
-        self.update_height();
+        self.update_stats();
         swap(self, &mut x);
         self.right = Some(x);
-        self.update_height();
+        self.update_stats();
     }
 
     fn rotate_left(&mut self) {
         let mut x = self.right.take().unwrap();
         self.right = x.left.take();
-        self.update_height();
+        self.update_stats();
         swap(self, &mut x);
         self.left = Some(x);
-        self.update_height();
+        self.update_stats();
     }
 
     fn balance(&mut self) {
@@ -91,15 +100,43 @@ impl<T: Ord> Node<T> {
             Ordering::Equal => false,
         };
 
-        self.update_height();
+        self.update_stats();
         self.balance();
         ret_val
     }
+
+    pub fn count_le(&self, value: &T) -> i32 {
+        match value.cmp(&self.value) {
+            Ordering::Less => self.left.as_ref().map_or(0, |n| n.count_le(value)),
+            Ordering::Greater => {
+                self.count
+                    + self
+                        .right
+                        .as_ref()
+                        .map_or(0, |n| n.count_le(value) - n.count)
+            }
+            Ordering::Equal => self.count - self.right.as_ref().map_or(0, |n| n.count),
+        }
+    }
+
+    pub fn count_ge(&self, value: &T) -> i32 {
+        match value.cmp(&self.value) {
+            Ordering::Less => {
+                self.count
+                    + self
+                        .left
+                        .as_ref()
+                        .map_or(0, |n| n.count_ge(value) - n.count)
+            }
+            Ordering::Greater => self.right.as_ref().map_or(0, |n| n.count_ge(value)),
+            Ordering::Equal => self.count - self.left.as_ref().map_or(0, |n| n.count),
+        }
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use super::Node;
 
     #[test]
     fn is_sorted() {

+ 37 - 0
src/set.rs

@@ -24,4 +24,41 @@ impl<T: Ord> Set<T> {
             true
         }
     }
+
+    pub fn count(&self) -> i32 {
+        self.root.as_ref().map_or(0, |n| n.count())
+    }
+
+    pub fn count_lt(&self, value: &T) -> i32 {
+        self.count() - self.count_ge(value)
+    }
+
+    pub fn count_le(&self, value: &T) -> i32 {
+        self.root.as_ref().map_or(0, |n| n.count_le(value))
+    }
+
+    pub fn count_ge(&self, value: &T) -> i32 {
+        self.root.as_ref().map_or(0, |n| n.count_ge(value))
+    }
+
+    pub fn count_gt(&self, value: &T) -> i32 {
+        self.count() - self.count_le(value)
+    }
+}
+
+mod tests {
+    #[test]
+    fn test_count() {
+        let mut tree = super::Set::new();
+        for i in 1..11 {
+            assert!(tree.insert(i));
+            assert_eq!(tree.count(), i);
+        }
+        for i in 1..11 {
+            assert_eq!(tree.count_lt(&i), i - 1);
+            assert_eq!(tree.count_le(&i), i);
+            assert_eq!(tree.count_ge(&i), 11 - i);
+            assert_eq!(tree.count_gt(&i), 10 - i);
+        }
+    }
 }