From aca1e3ff383fe4559f1fce371e10ddd5563af3f2 Mon Sep 17 00:00:00 2001 From: Forest Gregg Date: Fri, 28 Jun 2024 11:57:25 -0400 Subject: [PATCH 1/2] fix exists variables, closes #1196 --- dedupe/variables/exists.py | 7 +++---- tests/test_exists.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 tests/test_exists.py diff --git a/dedupe/variables/exists.py b/dedupe/variables/exists.py index 00ca7eb4..64b9f287 100644 --- a/dedupe/variables/exists.py +++ b/dedupe/variables/exists.py @@ -5,16 +5,15 @@ from categorical import CategoricalComparator from dedupe._typing import PredicateFunction -from dedupe.variables.base import DerivedType -from dedupe.variables.categorical_type import CategoricalType +from dedupe.variables.base import DerivedType, FieldType -class ExistsType(CategoricalType): +class ExistsType(FieldType): type = "Exists" _predicate_functions: list[PredicateFunction] = [] def __init__(self, field: str, **kwargs): - super().__init__(field, **kwargs) + super().__init__(field, *kwargs) self.cat_comparator = CategoricalComparator([0, 1]) diff --git a/tests/test_exists.py b/tests/test_exists.py new file mode 100644 index 00000000..f4129ef3 --- /dev/null +++ b/tests/test_exists.py @@ -0,0 +1,13 @@ +import unittest + +import numpy + +from dedupe.variables.exists import ExistsType + + +class TestExists(unittest.TestCase): + def test_comparator(self): + var = ExistsType("foo") + assert numpy.array_equal(var.comparator(None, None), [0, 0]) + assert numpy.array_equal(var.comparator(1, 1), [1, 0]) + assert numpy.array_equal(var.comparator(1, 0), [0, 1]) From f82b815b841d0a28d4a103f0ec59d973cb902a5d Mon Sep 17 00:00:00 2001 From: Forest Gregg Date: Fri, 28 Jun 2024 12:04:41 -0400 Subject: [PATCH 2/2] fix signature for super.__init__ of exists --- dedupe/variables/exists.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedupe/variables/exists.py b/dedupe/variables/exists.py index 64b9f287..72887255 100644 --- a/dedupe/variables/exists.py +++ b/dedupe/variables/exists.py @@ -13,7 +13,7 @@ class ExistsType(FieldType): _predicate_functions: list[PredicateFunction] = [] def __init__(self, field: str, **kwargs): - super().__init__(field, *kwargs) + super().__init__(field, **kwargs) self.cat_comparator = CategoricalComparator([0, 1])