From 80cc22f07572e9c466a2d1803f1b9d941d16c07e Mon Sep 17 00:00:00 2001 From: Yao-Yuan Mao <yymao.astro@gmail.com> Date: Sun, 12 Dec 2021 11:26:40 -0500 Subject: [PATCH] add split method --- easyquery.py | 41 ++++++++++++++++++++++++++++++++++++++++- test_main.py | 16 ++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/easyquery.py b/easyquery.py index 387d19f..d32d18a 100644 --- a/easyquery.py +++ b/easyquery.py @@ -13,7 +13,7 @@ import numexpr as ne __all__ = ['Query', 'QueryMaker'] -__version__ = '0.3.0' +__version__ = '0.4.0' def _is_string_like(obj): @@ -34,6 +34,7 @@ class Query(object): A Query object has three major methods: filter, count, and mask. All of them operate on NumPy structured array and astropy Table: - `filter` returns a new table that only has entries satisfying the query; + - `split` returns two new tables that have entries satisfying and not satisfying the query, respectively; - `count` returns the number of entries satisfying the query; - `mask` returns a bool array for masking the table; - `where` returns a int array for the indices that select satisfying entries. @@ -303,6 +304,26 @@ def where(self, table): return np.flatnonzero(self.mask(table)) + def split(self, table, column_slice=None): + """ + Split the `table` into two parts: satisfying and not satisfying the queries. + The function will return q.filter(table), (~q).filter(table) + where `q` is the current Query object. + + Parameters + ---------- + table : NumPy structured array, astropy Table, etc. + + Returns + ------- + table_true : filtered table, satisfying the queries + table_false : filtered table, not satisfying the queries + """ + mask = self.mask(table) + if column_slice is not None: + table = self._get_table_column(table, column_slice) + return self._mask_table(table, mask), self._mask_table(table, ~mask) + def copy(self): """ Create a copy of the current Query object. 
@@ -437,6 +458,24 @@ def where(table, *queries): return _query_class(*queries).where(table) +def split(table, *queries): + """ + A convenient function to split `table` into satisfying and non-satisfying parts. + Equivalent to `Query(*queries).split(table)` + + Parameters + ---------- + table : NumPy structured array, astropy Table, etc. + queries : string, tuple, callable + + Returns + ------- + table_true : filtered table, satisfying the queries + table_false : filtered table, not satisfying the queries + """ + return _query_class(*queries).split(table) + + class QueryMaker(): """ provides convenience functions to generate query objects diff --git a/test_main.py b/test_main.py index 6d8d37d..28db76f 100644 --- a/test_main.py +++ b/test_main.py @@ -50,7 +50,11 @@ def check_query_on_table(table, query_object, true_mask=None): if true_mask is None: true_mask = np.ones(len(table), bool) + stable1, stable2 = query_object.split(table) + assert (query_object.filter(table) == table[true_mask]).all(), 'filter not correct' + assert (stable1 == table[true_mask]).all(), 'split not correct' + assert (stable2 == table[~true_mask]).all(), 'split not correct' assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct' assert (query_object.mask(table) == true_mask).all(), 'mask not correct' assert (query_object.where(table) == np.flatnonzero(true_mask)).all(), 'where not correct' @@ -62,8 +66,17 @@ def check_query_on_dict_table(table, query_object, true_mask=None): ftable = query_object.filter(table) ftable_true = {k: table[k][true_mask] for k in table} + + stable1, stable2 = query_object.split(table) + stable1_true = ftable_true + stable2_true = {k: table[k][~true_mask] for k in table} + assert set(ftable) == set(ftable_true), 'filter not correct' assert all((ftable[k] == ftable_true[k]).all() for k in ftable), 'filter not correct' + assert set(stable1) == set(stable1_true), 'split not correct' + assert all((stable1[k] == stable1_true[k]).all() for k in 
ftable), 'split not correct' + assert set(stable2) == set(stable2_true), 'split not correct' + assert all((stable2[k] == stable2_true[k]).all() for k in ftable), 'split not correct' assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct' assert (query_object.mask(table) == true_mask).all(), 'mask not correct' assert (query_object.where(table) == np.flatnonzero(true_mask)).all(), 'where not correct' @@ -159,10 +172,13 @@ def test_filter_column_slice(): t = gen_test_table() q = Query('a > 2') assert (q.filter(t, 'b') == t['b'][t['a'] > 2]).all() + assert (q.split(t, 'b')[1] == t['b'][~(t['a'] > 2)]).all() q = Query('a > 2', 'b < 2') assert (q.filter(t, 'c') == t['c'][(t['a'] > 2) & (t['b'] < 2)]).all() + assert (q.split(t, 'c')[1] == t['c'][~((t['a'] > 2) & (t['b'] < 2))]).all() q = Query(None) assert (q.filter(t, 'a') == t['a']).all() + assert len(q.split(t, 'a')[1]) == 0 def test_query_maker():