From 35dee7989d47892bdd552f35398959f345ded520 Mon Sep 17 00:00:00 2001 From: Casper Welzel Andersen Date: Thu, 25 Mar 2021 12:22:59 +0100 Subject: [PATCH 1/3] Fix CheckWronglyVersionedBaseUrls middleware Ensure the version part is properly checked. --- optimade/server/middleware.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optimade/server/middleware.py b/optimade/server/middleware.py index e7737925b..cdb35650c 100644 --- a/optimade/server/middleware.py +++ b/optimade/server/middleware.py @@ -92,15 +92,15 @@ def check_url(url: StarletteURL): """ base_url = get_base_url(url) optimade_path = f"{url.scheme}://{url.netloc}{url.path}"[len(base_url) :] - if re.match(r"^/v[0-9]+", optimade_path): + match = re.match(r"^(?P/v[0-9]+(\.[0-9]+){0,2}).*", optimade_path) + if match is not None: for version_prefix in BASE_URL_PREFIXES.values(): - if optimade_path.startswith(f"{version_prefix}/"): + if match.group("version") == version_prefix: break else: - version_prefix = re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", optimade_path) raise VersionNotSupported( detail=( - f"The parsed versioned base URL {version_prefix[0][0]!r} from " + f"The parsed versioned base URL {match.group('version')!r} from " f"{url} is not supported by this implementation. " f"Supported versioned base URLs are: {', '.join(BASE_URL_PREFIXES.values())}" ) From 99d8883211aaf5c3bdcc850bc51b799953afb643 Mon Sep 17 00:00:00 2001 From: Matthew Evans Date: Thu, 25 Mar 2021 15:24:08 +0100 Subject: [PATCH 2/3] Don't use for-else for middleware --- optimade/server/middleware.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/optimade/server/middleware.py b/optimade/server/middleware.py index cdb35650c..56ce24701 100644 --- a/optimade/server/middleware.py +++ b/optimade/server/middleware.py @@ -94,10 +94,7 @@ def check_url(url: StarletteURL): optimade_path = f"{url.scheme}://{url.netloc}{url.path}"[len(base_url) :] match = re.match(r"^(?P/v[0-9]+(\.[0-9]+){0,2}).*", optimade_path) if match is not None: - for version_prefix in BASE_URL_PREFIXES.values(): - if match.group("version") == version_prefix: - break - else: + if match.group("version") not in BASE_URL_PREFIXES.values(): raise VersionNotSupported( detail=( f"The parsed versioned base URL {match.group('version')!r} from " From 50225732920af7186c85d80416967b5ea534efe5 Mon Sep 17 00:00:00 2001 From: Casper Welzel Andersen Date: Thu, 25 Mar 2021 17:46:26 +0100 Subject: [PATCH 3/3] Add test for bare versioned base URLs Ensure `version` is correctly set for test client. --- tests/server/middleware/test_versioned_url.py | 15 +++++++++++++-- tests/server/utils.py | 2 ++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/server/middleware/test_versioned_url.py b/tests/server/middleware/test_versioned_url.py index 11fb15da7..4463be7ab 100644 --- a/tests/server/middleware/test_versioned_url.py +++ b/tests/server/middleware/test_versioned_url.py @@ -1,5 +1,5 @@ """Test CheckWronglyVersionedBaseUrls middleware""" -import urllib +import urllib.parse import pytest @@ -63,7 +63,7 @@ def test_multiple_versions_in_path(both_clients): urllib.parse.urlparse(url) ) - # Test also that the a non-valid OPTIMADE version raises + # Test also that a non-valid OPTIMADE version raises url = f"{CONFIG.base_url}/v0/info" with pytest.raises(VersionNotSupported): CheckWronglyVersionedBaseUrls(both_clients.app).check_url( @@ -74,3 +74,14 @@ def test_multiple_versions_in_path(both_clients): CONFIG.base_url = org_base_url else: CONFIG.base_url = None + + +def test_versioned_base_urls(both_clients): + """Test the middleware does not wrongly catch requests to versioned base URLs""" + from optimade.server.config import CONFIG + from optimade.server.routers.utils import BASE_URL_PREFIXES + + for request in BASE_URL_PREFIXES.values(): + CheckWronglyVersionedBaseUrls(both_clients.app).check_url( + urllib.parse.urlparse(f"{CONFIG.base_url}{request}") + ) diff --git a/tests/server/utils.py b/tests/server/utils.py index 5aa9bb9bb..dd4707bbc 100644 --- a/tests/server/utils.py +++ b/tests/server/utils.py @@ -44,6 +44,8 @@ def __init__( if version: if not version.startswith("v") and not version.startswith("/v"): version = f"/v{version}" + if version.startswith("v"): + version = f"/{version}" if re.match(r"/v[0-9](.[0-9]){0,2}", version) is None: warnings.warn( f"Invalid version passed to client: {version!r}. "