diff --git a/src/mwparserfromhell/wikicode.py b/src/mwparserfromhell/wikicode.py index 4d4f9b3..b17f0ba 100644 --- a/src/mwparserfromhell/wikicode.py +++ b/src/mwparserfromhell/wikicode.py @@ -516,19 +516,43 @@ def matches(self, other): adjusted. Specifically, whitespace and markup is stripped and the first letter's case is normalized. Typical usage is ``if template.name.matches("stub"): ...``. + + If either side has any colons, everything before the last colon is taken to be + a namespace and/or interwiki prefix. The parts before and after the colon are + normalized and compared separately; both must match for the result to be True. """ - normalize = lambda s: (s[0].upper() + s[1:]).replace("_", " ") if s else s - this = normalize(self.strip_code().strip()) + this = self.strip_code().strip() + this_prefix, this_postfix = self._split_and_normalize(this) if isinstance(other, (str, bytes, Wikicode, Node)): that = parse_anything(other).strip_code().strip() - return this == normalize(that) + that_prefix, that_postfix = self._split_and_normalize(that) + return (this_prefix, this_postfix) == (that_prefix, that_postfix) for obj in other: that = parse_anything(obj).strip_code().strip() - if this == normalize(that): + that_prefix, that_postfix = self._split_and_normalize(that) + if (this_prefix, this_postfix) == (that_prefix, that_postfix): return True return False + + def _split_and_normalize(self, s): + """Split a page title into a prefix (everything before the last colon) + and a postfix (everything after the last colon). Both parts are normalized + according to the rules specific to that part (the prefix is case-insensitive, + while the postfix is only case insensitive in the first character) before being + returned. + + If there is no prefix, the returned prefix is an empty string. + """ + normalize = lambda s: (s[0].upper() + s[1:]).replace("_", " ") if s else s + m = re.match(r'(.*):(.*)', s) + if m: + return normalize(m[1]).lower(), normalize(m[2]) + else: + return "", normalize(s) + + def ifilter(self, recursive=True, matches=None, flags=FLAGS, forcetype=None): """Iterate over nodes in our list matching certain conditions. diff --git a/tests/test_wikicode.py b/tests/test_wikicode.py index 16c7ebc..0586f0e 100644 --- a/tests/test_wikicode.py +++ b/tests/test_wikicode.py @@ -363,6 +363,8 @@ def test_matches(): code3 = parse("Hello world!") code4 = parse("World,_hello?") code5 = parse("") + code6 = parse("File:Foo") + code7 = parse("Talk:foo") assert code1.matches("Cleanup") is True assert code1.matches("cleanup") is True assert code1.matches(" cleanup\n") is True @@ -386,6 +388,13 @@ def test_matches(): assert code5.matches("") is True assert code5.matches("") is True assert code5.matches(("a", "b", "")) is True + assert code6.matches("File:Foo") is True + assert code6.matches("File:foo") is True + assert code6.matches("FILE:FOO") is False + assert code6.matches("file:foo") is True + assert code6.matches("FiLe:foo") is True + assert code6.matches("FiLE:Foo") is True + assert code7.matches("Talk:Foo") is True def test_filter_family():