Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multimap headers #704

Merged
merged 9 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions docs_src/src/pages/documentation/api_reference/getting_started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,27 @@ Or setting the Headers globally *per* router.
</Col>
</Row>

<Row>
<Col>
`add_response_header` appends the header to the list of headers, while `set_response_header` replaces the header if it exists.
</Col>
<Col>


<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
app.set_response_header("content-type", "application/json")
```

```python {{title: 'typed'}}
app.set_response_header("content-type", "application/json")
```
</CodeGroup>
</Col>
</Row>



## Request Headers

Expand All @@ -480,7 +501,7 @@ Either, by using the `headers` field in the `Request` object:

print("These are the request headers: ", headers)

headers["modified"] = ["modified_value"]
headers.set("modified", "modified_value")

print("These are the modified request headers: ", headers)

Expand All @@ -496,7 +517,7 @@ Either, by using the `headers` field in the `Request` object:

print("These are the request headers: ", headers)

headers["modified"] = ["modified_value"]
headers.set("modified", "modified_value")

print("These are the modified request headers: ", headers)

Expand Down Expand Up @@ -527,6 +548,26 @@ Or by using the global Request Headers:
</Col>
</Row>

<Row>
<Col>
`add_request_header` appends the header to the list of headers, while `set_request_header` replaces the header if it exists.
</Col>
<Col sticky>


<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
app.set_request_header("server", "robyn")
```

```python {{title: 'typed'}}
app.set_request_header("server", "robyn")
```
</CodeGroup>
</Col>
</Row>


---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Batman was excited to learn that he could add events as functions as well as dec

@app.after_request("/")
def hello_after_request(response: Response):
response.headers["after"] = "sync_after_request"
response.headers.set("after", "sync_after_request"")
print(response)
```

Expand All @@ -98,12 +98,12 @@ Batman was excited to learn that he could add events as functions as well as dec

@app.before_request("/")
async def hello_before_request(request):
request.headers["before"] = "sync_before_request"
request.headers.set("before", "sync_before_request")
print(request)

@app.after_request("/")
def hello_after_request(response):
response.headers["after"] = "sync_after_request"
response.headers.set("after", "sync_after_request"")
print(response)
```

Expand Down
41 changes: 21 additions & 20 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
WebSocketConnector,
)
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
from robyn.robyn import Headers
from robyn.templating import JinjaTemplate

from integration_tests.views import SyncView, AsyncView
Expand Down Expand Up @@ -111,20 +112,20 @@ def shutdown_handler():

@app.before_request()
def global_before_request(request: Request):
request.headers["global_before"] = "global_before_request"
request.headers.set("global_before", "global_before_request")
return request


@app.after_request()
def global_after_request(response: Response):
response.headers["global_after"] = "global_after_request"
response.headers.set("global_after", "global_after_request")
return response


@app.get("/sync/global/middlewares")
def sync_global_middlewares(request: Request):
assert "global_before" in request.headers
assert request.headers["global_before"] == "global_before_request"
assert request.headers.contains("global_before")
assert request.headers.get("global_before") == "global_before_request"
return "sync global middlewares"


Expand All @@ -133,49 +134,49 @@ def sync_global_middlewares(request: Request):

@app.before_request("/sync/middlewares")
def sync_before_request(request: Request):
request.headers["before"] = "sync_before_request"
request.headers.set("before", "sync_before_request")
return request


@app.after_request("/sync/middlewares")
def sync_after_request(response: Response):
response.headers["after"] = "sync_after_request"
response.headers.set("after", "sync_after_request")
response.description = response.description + " after"
return response


@app.get("/sync/middlewares")
def sync_middlewares(request: Request):
assert "before" in request.headers
assert request.headers["before"] == "sync_before_request"
assert request.headers.contains("before")
assert request.headers.get("before") == "sync_before_request"
assert request.ip_addr == "127.0.0.1"
return "sync middlewares"


@app.before_request("/async/middlewares")
async def async_before_request(request: Request):
request.headers["before"] = "async_before_request"
request.headers.set("before", "async_before_request")
return request


@app.after_request("/async/middlewares")
async def async_after_request(response: Response):
response.headers["after"] = "async_after_request"
response.headers.set("after", "async_after_request")
response.description = response.description + " after"
return response


@app.get("/async/middlewares")
async def async_middlewares(request: Request):
assert "before" in request.headers
assert request.headers["before"] == "async_before_request"
assert request.headers.contains("before")
assert request.headers.get("before") == "async_before_request"
assert request.ip_addr == "127.0.0.1"
return "async middlewares"


@app.before_request("/sync/middlewares/401")
def sync_before_request_401():
return Response(401, {}, "sync before request 401")
return Response(401, Headers({}), "sync before request 401")


@app.get("/sync/middlewares/401")
Expand Down Expand Up @@ -266,22 +267,22 @@ async def async_dict_const_get():

@app.get("/sync/response")
def sync_response_get():
return Response(200, {"sync": "response"}, "sync response get")
return Response(200, Headers({"sync": "response"}), "sync response get")


@app.get("/async/response")
async def async_response_get():
return Response(200, {"async": "response"}, "async response get")
return Response(200, Headers({"async": "response"}), "async response get")


@app.get("/sync/response/const", const=True)
def sync_response_const_get():
return Response(200, {"sync_const": "response"}, "sync response const get")
return Response(200, Headers({"sync_const": "response"}), "sync response const get")


@app.get("/async/response/const", const=True)
async def async_response_const_get():
return Response(200, {"async_const": "response"}, "async response const get")
return Response(200, Headers({"async_const": "response"}), "async response const get")


# Binary
Expand All @@ -301,7 +302,7 @@ async def async_octet_get():
def sync_octet_response_get():
return Response(
status_code=200,
headers={"Content-Type": "application/octet-stream"},
headers=Headers({"Content-Type": "application/octet-stream"}),
description="sync octet response",
)

Expand All @@ -310,7 +311,7 @@ def sync_octet_response_get():
async def async_octet_response_get():
return Response(
status_code=200,
headers={"Content-Type": "application/octet-stream"},
headers=Headers({"Content-Type": "application/octet-stream"}),
description="async octet response",
)

Expand Down Expand Up @@ -767,7 +768,7 @@ async def async_without_decorator():


def main():
app.add_response_header("server", "robyn")
app.set_response_header("server", "robyn")
app.add_directory(
route="/test_dir",
directory_path=os.path.join(current_file_path, "build"),
Expand Down
6 changes: 3 additions & 3 deletions integration_tests/helpers/http_methods_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def check_response(response: requests.Response, expected_status_code: int):
headers is not present in the response.
"""
assert response.status_code == expected_status_code
assert "global_after" in response.headers
assert response.headers["global_after"] == "global_after_request"
print(response.headers)
assert response.headers.get("global_after") == "global_after_request"
assert "server" in response.headers
assert response.headers["server"] == "robyn"
assert response.headers.get("server") == "robyn"


def get(
Expand Down
27 changes: 14 additions & 13 deletions integration_tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from robyn import Robyn, ALLOW_CORS
from robyn.events import Events
from robyn.types import Header
from robyn.robyn import Headers

import pytest


@pytest.mark.benchmark
def test_add_request_header():
app = Robyn(__file__)
app.add_request_header("server", "robyn")
assert app.request_headers == [Header(key="server", val="robyn")]
app.set_request_header("server", "robyn")
assert app.request_headers.get_headers() == Headers({"server": "robyn"}).get_headers()


@pytest.mark.benchmark
def test_add_response_header():
app = Robyn(__file__)
app.add_response_header("content-type", "application/json")
assert app.response_headers == [Header(key="content-type", val="application/json")]
assert app.response_headers.get_headers() == Headers({"content-type": "application/json"}).get_headers()


@pytest.mark.benchmark
Expand Down Expand Up @@ -48,12 +48,13 @@ async def mock_shutdown_handler():
def test_allow_cors():
app = Robyn(__file__)
ALLOW_CORS(app, ["*"])
assert app.request_headers == [
Header(key="Access-Control-Allow-Origin", val="*"),
Header(
key="Access-Control-Allow-Methods",
val="GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS",
),
Header(key="Access-Control-Allow-Headers", val="Content-Type, Authorization"),
Header(key="Access-Control-Allow-Credentials", val="true"),
]

headers = Headers({})
headers.set("Access-Control-Allow-Origin", "*")
headers.set(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS",
)
headers.set("Access-Control-Allow-Headers", "Content-Type, Authorization")
headers.set("Access-Control-Allow-Credentials", "true")
assert app.response_headers.get_headers() == headers.get_headers()
6 changes: 3 additions & 3 deletions integration_tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_invalid_authentication_token(session, function_type: str):
should_check_response=False,
)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
Expand All @@ -31,12 +31,12 @@ def test_invalid_authentication_header(session, function_type: str):
should_check_response=False,
)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_no_token(session, function_type: str):
r = get(f"/{function_type}/auth", should_check_response=False)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"
assert r.headers.get("WWW-Authenticate") == "BearerGetter"
2 changes: 1 addition & 1 deletion integration_tests/test_binary_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
)
def test_binary_output(route: str, text: str, session):
r = get(route)
assert r.headers["Content-Type"] == "application/octet-stream"
assert r.headers.get("Content-Type") == "application/octet-stream"
assert r.text == text
3 changes: 2 additions & 1 deletion integration_tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_file_download(function_type: str, session):
r = get(f"/{function_type}/file/download")
assert r.headers["Content-Disposition"] == "attachment"
assert r.headers.get("Content-Disposition") == "attachment"

assert r.text == "This is a test file for the downloading purpose"
13 changes: 7 additions & 6 deletions integration_tests/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_middlewares(function_type: str, session):
r = get(f"/{function_type}/middlewares")
headers = r.headers
# We do not want the request headers to be in the response
assert "before" not in r.headers
assert "after" in r.headers
assert r.headers["after"] == f"{function_type}_after_request"
assert not headers.get("before")
assert headers.get("after")
assert r.headers.get("after") == f"{function_type}_after_request"
assert r.text == f"{function_type} middlewares after"


@pytest.mark.benchmark
def test_global_middleware(session):
r = get("/sync/global/middlewares")
assert "global_before" not in r.headers
assert "global_after" in r.headers
assert r.headers["global_after"] == "global_after_request"
headers = r.headers
assert headers.get("global_after")
assert r.headers.get("global_after") == "global_after_request"
assert r.text == "sync global middlewares"


Expand Down
Loading
Loading