From 80cc22f07572e9c466a2d1803f1b9d941d16c07e Mon Sep 17 00:00:00 2001
From: Yao-Yuan Mao <yymao.astro@gmail.com>
Date: Sun, 12 Dec 2021 11:26:40 -0500
Subject: [PATCH] add split method

---
 easyquery.py | 41 ++++++++++++++++++++++++++++++++++++++++-
 test_main.py | 16 ++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/easyquery.py b/easyquery.py
index 387d19f..d32d18a 100644
--- a/easyquery.py
+++ b/easyquery.py
@@ -13,7 +13,7 @@
 import numexpr as ne
 
 __all__ = ['Query', 'QueryMaker']
-__version__ = '0.3.0'
+__version__ = '0.4.0'
 
 
 def _is_string_like(obj):
@@ -34,6 +34,7 @@ class Query(object):
     A Query object has three major methods: filter, count, and mask.
     All of them operate on NumPy structured array and astropy Table:
     - `filter` returns a new table that only has entries satisfying the query;
+    - `split` returns two new tables that has entries satisfying and not satisfying the query, respectively;
     - `count` returns the number of entries satisfying the query;
     - `mask` returns a bool array for masking the table;
     - `where` returns a int array for the indices that select satisfying entries.
@@ -303,6 +304,26 @@ def where(self, table):
 
         return np.flatnonzero(self.mask(table))
 
+    def split(self, table, column_slice=None):
+        """
+        Split the `table` into two parts: satisfying and not satisfy the queries.
+        The function will return q.filter(table), (~q).filter(table)
+        where `q` is the current Query object.
+
+        Parameters
+        ----------
+        table : NumPy structured array, astropy Table, etc.
+
+        Returns
+        -------
+        table_true : filtered table, satisfying the queries
+        table_false : filtered table, not satisfying the queries
+        """
+        mask = self.mask(table)
+        if column_slice is not None:
+            table = self._get_table_column(table, column_slice)
+        return self._mask_table(table, mask), self._mask_table(table, ~mask)
+
     def copy(self):
         """
         Create a copy of the current Query object.
@@ -437,6 +458,24 @@ def where(table, *queries):
     return _query_class(*queries).where(table)
 
 
+def split(table, *queries):
+    """
+    A convenient function to split `table` into satisfying and non-satisfying parts.
+    Equivalent to `Query(*queries).split(table)`
+
+    Parameters
+    ----------
+    table : NumPy structured array, astropy Table, etc.
+    queries : string, tuple, callable
+
+    Returns
+    -------
+    table_true : filtered table, satisfying the queries
+    table_false : filtered table, not satisfying the queries
+    """
+    return _query_class(*queries).split(table)
+
+
 class QueryMaker():
     """
     provides convenience functions to generate query objects
diff --git a/test_main.py b/test_main.py
index 6d8d37d..28db76f 100644
--- a/test_main.py
+++ b/test_main.py
@@ -50,7 +50,11 @@ def check_query_on_table(table, query_object, true_mask=None):
     if true_mask is None:
         true_mask = np.ones(len(table), bool)
 
+    stable1, stable2 = query_object.split(table)
+
     assert (query_object.filter(table) == table[true_mask]).all(), 'filter not correct'
+    assert (stable1 == table[true_mask]).all(), 'split not correct'
+    assert (stable2 == table[~true_mask]).all(), 'split not correct'
     assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct'
     assert (query_object.mask(table) == true_mask).all(), 'mask not correct'
     assert (query_object.where(table) == np.flatnonzero(true_mask)).all(), 'where not correct'
@@ -62,8 +66,17 @@ def check_query_on_dict_table(table, query_object, true_mask=None):
 
     ftable = query_object.filter(table)
     ftable_true = {k: table[k][true_mask] for k in table}
+
+    stable1, stable2 = query_object.split(table)
+    stable1_true = ftable_true
+    stable2_true = {k: table[k][~true_mask] for k in table}
+
     assert set(ftable) == set(ftable_true), 'filter not correct'
     assert all((ftable[k] == ftable_true[k]).all() for k in ftable), 'filter not correct'
+    assert set(stable1) == set(stable1_true), 'split not correct'
+    assert all((stable1[k] == stable1_true[k]).all() for k in ftable), 'split not correct'
+    assert set(stable2) == set(stable2_true), 'split not correct'
+    assert all((stable2[k] == stable2_true[k]).all() for k in ftable), 'split not correct'
     assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct'
     assert (query_object.mask(table) == true_mask).all(), 'mask not correct'
     assert (query_object.where(table) == np.flatnonzero(true_mask)).all(), 'where not correct'
@@ -159,10 +172,13 @@ def test_filter_column_slice():
     t = gen_test_table()
     q = Query('a > 2')
     assert (q.filter(t, 'b') == t['b'][t['a'] > 2]).all()
+    assert (q.split(t, 'b')[1] == t['b'][~(t['a'] > 2)]).all()
     q = Query('a > 2', 'b < 2')
     assert (q.filter(t, 'c') == t['c'][(t['a'] > 2) & (t['b'] < 2)]).all()
+    assert (q.split(t, 'c')[1] == t['c'][~((t['a'] > 2) & (t['b'] < 2))]).all()
     q = Query(None)
     assert (q.filter(t, 'a') == t['a']).all()
+    assert len(q.split(t, 'a')[1]) == 0
 
 
 def test_query_maker():