Skip to content

Commit

Permalink
feat: Add multimap headers (#704)
Browse files Browse the repository at this point in the history
* Add multimap headers

* fix formatting

* docs
  • Loading branch information
sansyrox committed Dec 2, 2023
1 parent 806fc8a commit 81aac86
Show file tree
Hide file tree
Showing 27 changed files with 468 additions and 209 deletions.
45 changes: 43 additions & 2 deletions docs_src/src/pages/documentation/api_reference/getting_started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,27 @@ Or setting the Headers globally *per* router.
</Col>
</Row>

<Row>
<Col>
`add_response_header` appends the header to the list of headers, while `set_response_header` replaces the header if it exists.
</Col>
<Col>


<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
app.set_response_header("content-type", "application/json")
```

```python {{title: 'typed'}}
app.set_response_header("content-type", "application/json")
```
</CodeGroup>
</Col>
</Row>



## Request Headers

Expand All @@ -480,7 +501,7 @@ Either, by using the `headers` field in the `Request` object:

print("These are the request headers: ", headers)

headers["modified"] = ["modified_value"]
headers.set("modified", "modified_value")

print("These are the modified request headers: ", headers)

Expand All @@ -496,7 +517,7 @@ Either, by using the `headers` field in the `Request` object:

print("These are the request headers: ", headers)

headers["modified"] = ["modified_value"]
headers.set("modified", "modified_value")

print("These are the modified request headers: ", headers)

Expand Down Expand Up @@ -527,6 +548,26 @@ Or by using the global Request Headers:
</Col>
</Row>

<Row>
<Col>
`add_request_header` appends the header to the list of headers, while `set_request_header` replaces the header if it exists.
</Col>
<Col sticky>


<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
app.set_request_header("server", "robyn")
```

```python {{title: 'typed'}}
app.set_request_header("server", "robyn")
```
</CodeGroup>
</Col>
</Row>


---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Batman was excited to learn that he could add events as functions as well as dec

@app.after_request("/")
def hello_after_request(response: Response):
response.headers["after"] = "sync_after_request"
response.headers.set("after", "sync_after_request"")
print(response)
```

Expand All @@ -98,12 +98,12 @@ Batman was excited to learn that he could add events as functions as well as dec

@app.before_request("/")
async def hello_before_request(request):
request.headers["before"] = "sync_before_request"
request.headers.set("before", "sync_before_request")
print(request)

@app.after_request("/")
def hello_after_request(response):
response.headers["after"] = "sync_after_request"
response.headers.set("after", "sync_after_request"")
print(response)
```

Expand Down
41 changes: 21 additions & 20 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
WebSocketConnector,
)
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
from robyn.robyn import Headers
from robyn.templating import JinjaTemplate

from integration_tests.views import SyncView, AsyncView
Expand Down Expand Up @@ -111,20 +112,20 @@ def shutdown_handler():

@app.before_request()
def global_before_request(request: Request):
request.headers["global_before"] = "global_before_request"
request.headers.set("global_before", "global_before_request")
return request


@app.after_request()
def global_after_request(response: Response):
response.headers["global_after"] = "global_after_request"
response.headers.set("global_after", "global_after_request")
return response


@app.get("/sync/global/middlewares")
def sync_global_middlewares(request: Request):
assert "global_before" in request.headers
assert request.headers["global_before"] == "global_before_request"
assert request.headers.contains("global_before")
assert request.headers.get("global_before") == "global_before_request"
return "sync global middlewares"


Expand All @@ -133,49 +134,49 @@ def sync_global_middlewares(request: Request):

@app.before_request("/sync/middlewares")
def sync_before_request(request: Request):
request.headers["before"] = "sync_before_request"
request.headers.set("before", "sync_before_request")
return request


@app.after_request("/sync/middlewares")
def sync_after_request(response: Response):
response.headers["after"] = "sync_after_request"
response.headers.set("after", "sync_after_request")
response.description = response.description + " after"
return response


@app.get("/sync/middlewares")
def sync_middlewares(request: Request):
assert "before" in request.headers
assert request.headers["before"] == "sync_before_request"
assert request.headers.contains("before")
assert request.headers.get("before") == "sync_before_request"
assert request.ip_addr == "127.0.0.1"
return "sync middlewares"


@app.before_request("/async/middlewares")
async def async_before_request(request: Request):
request.headers["before"] = "async_before_request"
request.headers.set("before", "async_before_request")
return request


@app.after_request("/async/middlewares")
async def async_after_request(response: Response):
response.headers["after"] = "async_after_request"
response.headers.set("after", "async_after_request")
response.description = response.description + " after"
return response


@app.get("/async/middlewares")
async def async_middlewares(request: Request):
assert "before" in request.headers
assert request.headers["before"] == "async_before_request"
assert request.headers.contains("before")
assert request.headers.get("before") == "async_before_request"
assert request.ip_addr == "127.0.0.1"
return "async middlewares"


@app.before_request("/sync/middlewares/401")
def sync_before_request_401():
return Response(401, {}, "sync before request 401")
return Response(401, Headers({}), "sync before request 401")


@app.get("/sync/middlewares/401")
Expand Down Expand Up @@ -266,22 +267,22 @@ async def async_dict_const_get():

@app.get("/sync/response")
def sync_response_get():
return Response(200, {"sync": "response"}, "sync response get")
return Response(200, Headers({"sync": "response"}), "sync response get")


@app.get("/async/response")
async def async_response_get():
return Response(200, {"async": "response"}, "async response get")
return Response(200, Headers({"async": "response"}), "async response get")


@app.get("/sync/response/const", const=True)
def sync_response_const_get():
return Response(200, {"sync_const": "response"}, "sync response const get")
return Response(200, Headers({"sync_const": "response"}), "sync response const get")


@app.get("/async/response/const", const=True)
async def async_response_const_get():
return Response(200, {"async_const": "response"}, "async response const get")
return Response(200, Headers({"async_const": "response"}), "async response const get")


# Binary
Expand All @@ -301,7 +302,7 @@ async def async_octet_get():
def sync_octet_response_get():
return Response(
status_code=200,
headers={"Content-Type": "application/octet-stream"},
headers=Headers({"Content-Type": "application/octet-stream"}),
description="sync octet response",
)

Expand All @@ -310,7 +311,7 @@ def sync_octet_response_get():
async def async_octet_response_get():
return Response(
status_code=200,
headers={"Content-Type": "application/octet-stream"},
headers=Headers({"Content-Type": "application/octet-stream"}),
description="async octet response",
)

Expand Down Expand Up @@ -767,7 +768,7 @@ async def async_without_decorator():


def main():
app.add_response_header("server", "robyn")
app.set_response_header("server", "robyn")
app.add_directory(
route="/test_dir",
directory_path=os.path.join(current_file_path, "build"),
Expand Down
6 changes: 3 additions & 3 deletions integration_tests/helpers/http_methods_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def check_response(response: requests.Response, expected_status_code: int):
headers is not present in the response.
"""
assert response.status_code == expected_status_code
assert "global_after" in response.headers
assert response.headers["global_after"] == "global_after_request"
print(response.headers)
assert response.headers.get("global_after") == "global_after_request"
assert "server" in response.headers
assert response.headers["server"] == "robyn"
assert response.headers.get("server") == "robyn"


def get(
Expand Down
27 changes: 14 additions & 13 deletions integration_tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from robyn import Robyn, ALLOW_CORS
from robyn.events import Events
from robyn.types import Header
from robyn.robyn import Headers

import pytest


@pytest.mark.benchmark
def test_add_request_header():
app = Robyn(__file__)
app.add_request_header("server", "robyn")
assert app.request_headers == [Header(key="server", val="robyn")]
app.set_request_header("server", "robyn")
assert app.request_headers.get_headers() == Headers({"server": "robyn"}).get_headers()


@pytest.mark.benchmark
def test_add_response_header():
app = Robyn(__file__)
app.add_response_header("content-type", "application/json")
assert app.response_headers == [Header(key="content-type", val="application/json")]
assert app.response_headers.get_headers() == Headers({"content-type": "application/json"}).get_headers()


@pytest.mark.benchmark
Expand Down Expand Up @@ -48,12 +48,13 @@ async def mock_shutdown_handler():
def test_allow_cors():
app = Robyn(__file__)
ALLOW_CORS(app, ["*"])
assert app.request_headers == [
Header(key="Access-Control-Allow-Origin", val="*"),
Header(
key="Access-Control-Allow-Methods",
val="GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS",
),
Header(key="Access-Control-Allow-Headers", val="Content-Type, Authorization"),
Header(key="Access-Control-Allow-Credentials", val="true"),
]

headers = Headers({})
headers.set("Access-Control-Allow-Origin", "*")
headers.set(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS",
)
headers.set("Access-Control-Allow-Headers", "Content-Type, Authorization")
headers.set("Access-Control-Allow-Credentials", "true")
assert app.response_headers.get_headers() == headers.get_headers()
6 changes: 3 additions & 3 deletions integration_tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_invalid_authentication_token(session, function_type: str):
should_check_response=False,
)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
Expand All @@ -31,12 +31,12 @@ def test_invalid_authentication_header(session, function_type: str):
should_check_response=False,
)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_no_token(session, function_type: str):
r = get(f"/{function_type}/auth", should_check_response=False)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"
assert r.headers.get("WWW-Authenticate") == "BearerGetter"
2 changes: 1 addition & 1 deletion integration_tests/test_binary_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
)
def test_binary_output(route: str, text: str, session):
r = get(route)
assert r.headers["Content-Type"] == "application/octet-stream"
assert r.headers.get("Content-Type") == "application/octet-stream"
assert r.text == text
3 changes: 2 additions & 1 deletion integration_tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_file_download(function_type: str, session):
r = get(f"/{function_type}/file/download")
assert r.headers["Content-Disposition"] == "attachment"
assert r.headers.get("Content-Disposition") == "attachment"

assert r.text == "This is a test file for the downloading purpose"
13 changes: 7 additions & 6 deletions integration_tests/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_middlewares(function_type: str, session):
r = get(f"/{function_type}/middlewares")
headers = r.headers
# We do not want the request headers to be in the response
assert "before" not in r.headers
assert "after" in r.headers
assert r.headers["after"] == f"{function_type}_after_request"
assert not headers.get("before")
assert headers.get("after")
assert r.headers.get("after") == f"{function_type}_after_request"
assert r.text == f"{function_type} middlewares after"


@pytest.mark.benchmark
def test_global_middleware(session):
r = get("/sync/global/middlewares")
assert "global_before" not in r.headers
assert "global_after" in r.headers
assert r.headers["global_after"] == "global_after_request"
headers = r.headers
assert headers.get("global_after")
assert r.headers.get("global_after") == "global_after_request"
assert r.text == "sync global middlewares"


Expand Down
Loading

0 comments on commit 81aac86

Please sign in to comment.