From 585df2c2471e3b7e894396e7994582e49ae3b6b3 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Mon, 27 Nov 2023 20:58:19 +0530 Subject: [PATCH 1/9] Add multimap headers --- robyn/__init__.py | 5 +++-- robyn/robyn.pyi | 12 ++++++++++++ robyn/router.py | 3 ++- src/lib.rs | 2 ++ src/types/headers.rs | 44 ++++++++++++++++++++++++++++++++++++++++++++ src/types/mod.rs | 1 + 6 files changed, 64 insertions(+), 3 deletions(-) create mode 100644 src/types/headers.rs diff --git a/robyn/__init__.py b/robyn/__init__.py index 553d75753..adbeb03a4 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -23,6 +23,7 @@ get_version, jsonify, WebSocketConnector, + Headers ) from robyn.router import MiddlewareRouter, MiddlewareType, Router, WebSocketRouter from robyn.types import Directory, Header @@ -59,8 +60,8 @@ def __init__(self, file_object: str, config: Config = Config()) -> None: self.router = Router() self.middleware_router = MiddlewareRouter() self.web_socket_router = WebSocketRouter() - self.request_headers: List[Header] = [] # This needs a better type - self.response_headers: List[Header] = [] # This needs a better type + self.request_headers: Headers = Headers() + self.response_headers: Headers = Headers() self.directories: List[Directory] = [] self.event_handlers = {} self.exception_handler: Optional[Callable] = None diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index a99d0c274..f36cb8dd4 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -61,6 +61,18 @@ class Url: class Identity: claims: dict[str, str] +@dataclass +class Headers: + headers: dict[str, list[str]] + + def set(self, key: str, value: str) -> None: + pass + + def get(self, key: str, default: Optional[str]) -> Optional[str]: + pass + + + @dataclass class QueryParams: """ diff --git a/robyn/router.py b/robyn/router.py index b289a44ff..dc6227b2e 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -48,6 +48,7 @@ def _format_response( default_response_header: dict, ) -> Response: headers = {"Content-Type": "text/plain"} if not default_response_header else default_response_header + # we should create a header object here response = {} if isinstance(res, dict): status_code = res.get("status_code", status_codes.HTTP_200_OK) @@ -86,7 +87,7 @@ def add_route( exception_handler: Optional[Callable], default_response_headers: List[Header], ) -> Union[Callable, CoroutineType]: - response_headers = {d.key: d.val for d in default_response_headers} + response_headers = [( d.key: d.val ) for d in default_response_headers] @wraps(handler) async def async_inner_handler(*args): diff --git a/src/lib.rs b/src/lib.rs index 10cc9ffc2..9f69a888a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ use serde_json::Value; use pyo3::{exceptions::PyValueError, prelude::*}; use types::{ function_info::{FunctionInfo, MiddlewareType}, + headers::Headers, identity::Identity, multimap::QueryParams, request::PyRequest, @@ -55,6 +56,7 @@ pub fn robyn(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(jsonify, m)?)?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/types/headers.rs b/src/types/headers.rs new file mode 100644 index 000000000..c4424ee90 --- /dev/null +++ b/src/types/headers.rs @@ -0,0 +1,44 @@ +use log::debug; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use std::collections::HashMap; + +// Custom Multimap class +#[pyclass(name = "Headers")] +#[derive(Clone, Debug, Default)] +pub struct Headers { + pub headers: HashMap>, +} + +#[pymethods] +impl Headers { + #[new] + pub fn new(default_headers: Option<&PyDict>) -> Self { + match default_headers { + Some(default_headers) => { + let mut headers = HashMap::new(); + for (key, value) in default_headers { + let key = key.to_string(); + let value: Vec = value + .downcast::() + .unwrap() + .iter() + .map(|x| x.to_string()) + .collect(); + headers.insert(key, value); + } + Headers { headers } + } + None => Headers { + headers: HashMap::new(), + }, + } + } + + pub fn set_header(&mut self, key: String, value: String) { + self.headers + .entry(key) + .or_insert_with(Vec::new) + .push(value.to_lowercase()); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 53fc9656b..0af7a6b00 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -5,6 +5,7 @@ use pyo3::{ }; pub mod function_info; +pub mod headers; pub mod identity; pub mod multimap; pub mod request; From 130ba74bf33eb45fe7088c9207ea0e177f4a6643 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Thu, 30 Nov 2023 03:34:42 +0530 Subject: [PATCH 2/9] Add multimap headers --- robyn/__init__.py | 4 +-- robyn/robyn.pyi | 51 +++++++++++++++++++++++++++++------- robyn/router.py | 21 ++++++++++----- src/types/headers.rs | 61 ++++++++++++++++++++++++++++++++++++------- src/types/response.rs | 8 +++--- 5 files changed, 116 insertions(+), 29 deletions(-) diff --git a/robyn/__init__.py b/robyn/__init__.py index adbeb03a4..27505ef3e 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -142,10 +142,10 @@ def add_directory( self.directories.append(Directory(route, directory_path, show_files_listing, index_file)) def add_request_header(self, key: str, value: str) -> None: - self.request_headers.append(Header(key, value)) + self.request_headers.set(key, value) def add_response_header(self, key: str, value: str) -> None: - self.response_headers.append(Header(key, value)) + self.response_headers.set(key, value) def add_web_socket(self, endpoint: str, ws: WebSocket) -> None: self.web_socket_router.add_route(endpoint, ws) diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index f36cb8dd4..d51f3626d 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -61,15 +61,7 @@ class Url: class Identity: claims: dict[str, str] -@dataclass -class Headers: - headers: dict[str, list[str]] - - def set(self, key: str, value: str) -> None: - pass - def get(self, key: str, default: Optional[str]) -> Optional[str]: - pass @@ -163,6 +155,47 @@ class QueryParams: def __repr__(self) -> str: pass +class Headers: + + def __init__(self, default_headers: Optional[dict]) -> None: + pass + + + def set(self, key: str, value: str) -> None: + """ + Sets the value of the header with the given key. + If the key already exists, the value will be appended to the list of values. + + Args: + key (str): The key of the header + value (str): The value of the header + """ + pass + + def get(self, key: str, default: Optional[str]) -> Optional[str]: + """ + Gets the last value of the header with the given key. + + Args: + key (str): The key of the header + default (Optional[str]): The default value if the key does not exist + """ + pass + + def populate_from_dict(self, headers: dict[str, str]) -> None: + """ + Populates the headers from a dictionary. + + Args: + headers (dict[str, str]): The dictionary of headers + """ + pass + + + def is_empty(self) -> bool: + pass + + @dataclass class Request: """ @@ -207,7 +240,7 @@ class Response: """ status_code: int - headers: dict[str, str] + headers: Headers description: Union[str, bytes] response_type: Optional[str] = None file_path: Optional[str] = None diff --git a/robyn/router.py b/robyn/router.py index dc6227b2e..b5a06f7c0 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -6,7 +6,7 @@ from typing import Callable, Dict, List, NamedTuple, Union, Optional from robyn.authentication import AuthenticationHandler, AuthenticationNotConfiguredError -from robyn.robyn import FunctionInfo, HttpMethod, MiddlewareType, Request, Response +from robyn.robyn import FunctionInfo, Headers, HttpMethod, MiddlewareType, Request, Response from robyn import status_codes from robyn.ws import WebSocket @@ -45,14 +45,22 @@ def __init__(self) -> None: def _format_response( self, res: dict, - default_response_header: dict, + default_response_header: Headers, ) -> Response: - headers = {"Content-Type": "text/plain"} if not default_response_header else default_response_header + # TODO: Add support for custom headers + headers = default_response_header + + if headers.is_empty(): + headers.set("Content-Type", "text/plain") # we should create a header object here response = {} if isinstance(res, dict): + # this should change status_code = res.get("status_code", status_codes.HTTP_200_OK) - headers = res.get("headers", headers) + headers = res.get("headers", None) + if headers is not None: + headers.populate_from_dict(headers) + description = res.get("description", "") if not isinstance(status_code, int): @@ -65,6 +73,8 @@ def _format_response( elif isinstance(res, Response): response = res elif isinstance(res, bytes): + headers = Headers({"Content-Type": "application/octet-stream"}) + response = Response( status_code=status_codes.HTTP_200_OK, headers={"Content-Type": "application/octet-stream"}, @@ -85,9 +95,8 @@ def add_route( handler: Callable, is_const: bool, exception_handler: Optional[Callable], - default_response_headers: List[Header], + response_headers: Headers, ) -> Union[Callable, CoroutineType]: - response_headers = [( d.key: d.val ) for d in default_response_headers] @wraps(handler) async def async_inner_handler(*args): diff --git a/src/types/headers.rs b/src/types/headers.rs index c4424ee90..d21f01e91 100644 --- a/src/types/headers.rs +++ b/src/types/headers.rs @@ -19,13 +19,17 @@ impl Headers { let mut headers = HashMap::new(); for (key, value) in default_headers { let key = key.to_string(); - let value: Vec = value - .downcast::() - .unwrap() - .iter() - .map(|x| x.to_string()) - .collect(); - headers.insert(key, value); + + let new_value = value.downcast::(); + + if new_value.is_err() { + let value = value.to_string(); + headers.entry(key).or_insert_with(Vec::new).push(value); + } else { + let value: Vec = + new_value.unwrap().iter().map(|x| x.to_string()).collect(); + headers.entry(key).or_insert_with(Vec::new).extend(value); + } } Headers { headers } } @@ -35,10 +39,49 @@ impl Headers { } } - pub fn set_header(&mut self, key: String, value: String) { + pub fn set(&mut self, key: String, value: String) { self.headers - .entry(key) + .entry(key.to_lowercase()) .or_insert_with(Vec::new) .push(value.to_lowercase()); } + + pub fn get(&self, py: Python, key: String) -> Py { + match self.headers.get(&key.to_lowercase()) { + Some(values) => { + let py_values = PyList::new(py, values.iter().map(|value| value.to_object(py))); + py_values.into() + } + None => PyList::empty(py).into(), + } + } + + /// Returns all headers as a PyList of tuples. + pub fn get_headers(&self, py: Python) -> Py { + let headers_list = PyList::new( + py, + self.headers.iter().map(|(key, values)| { + let py_values = PyList::new(py, values); + (key.clone(), py_values.to_object(py)) + }), + ); + headers_list.into() + } + + pub fn populate_from_dict(&mut self, headers: &PyDict) { + for (key, value) in headers { + let key = key.to_string().to_lowercase(); + let value: Vec = value + .downcast::() + .unwrap() + .iter() + .map(|x| x.to_string()) + .collect(); + self.headers.insert(key, value); + } + } + + pub fn is_empty(&self) -> bool { + self.headers.is_empty() + } } diff --git a/src/types/response.rs b/src/types/response.rs index 4962fd0f6..8c804b88b 100644 --- a/src/types/response.rs +++ b/src/types/response.rs @@ -5,9 +5,11 @@ use actix_web::{HttpRequest, HttpResponse, HttpResponseBuilder, Responder}; use pyo3::{ exceptions::{PyIOError, PyValueError}, prelude::*, - types::{PyBytes, PyDict, PyString}, + types::{PyBytes, PyString}, }; +use super::headers::Headers; + use crate::io_helpers::{apply_hashmap_headers, read_file}; use crate::types::{check_description_type, get_description_from_pyobject}; @@ -80,7 +82,7 @@ pub struct PyResponse { #[pyo3(get)] pub response_type: String, #[pyo3(get, set)] - pub headers: Py, + pub headers: Headers, #[pyo3(get)] pub description: Py, #[pyo3(get)] @@ -94,7 +96,7 @@ impl PyResponse { pub fn new( py: Python, status_code: u16, - headers: Py, + headers: Headers, description: Py, ) -> PyResult { if description.downcast::(py).is_err() From f8006e4e40e3ed45fe764d0ed3a2fc38cbf45d96 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Fri, 1 Dec 2023 19:12:45 +0530 Subject: [PATCH 3/9] Fix many headers --- integration_tests/base_routes.py | 36 ++--- .../helpers/http_methods_helpers.py | 5 +- integration_tests/test_app.py | 25 ++-- integration_tests/test_authentication.py | 7 +- integration_tests/test_binary_output.py | 2 +- integration_tests/test_file_download.py | 3 +- integration_tests/test_middlewares.py | 15 +- integration_tests/test_unsupported_types.py | 7 +- robyn/__init__.py | 12 +- robyn/authentication.py | 4 +- robyn/processpool.py | 20 ++- robyn/robyn.pyi | 2 +- robyn/router.py | 8 +- src/executors/mod.rs | 2 + src/io_helpers/mod.rs | 14 +- src/server.rs | 43 +++--- src/types/headers.rs | 130 ++++++++++++++---- src/types/multimap.rs | 2 +- src/types/request.rs | 31 ++--- src/types/response.rs | 11 +- unit_tests/test_request_object.py | 8 +- 21 files changed, 245 insertions(+), 142 deletions(-) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index b1de33b0b..63f3c533e 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -15,6 +15,7 @@ WebSocketConnector, ) from robyn.authentication import AuthenticationHandler, BearerGetter, Identity +from robyn.robyn import Headers from robyn.templating import JinjaTemplate from integration_tests.views import SyncView, AsyncView @@ -111,20 +112,21 @@ def shutdown_handler(): @app.before_request() def global_before_request(request: Request): - request.headers["global_before"] = "global_before_request" + # request.headers["global_before"] = "global_before_request" + request.headers.set("global_before", "global_before_request") return request @app.after_request() def global_after_request(response: Response): - response.headers["global_after"] = "global_after_request" + response.headers.set("global_after", "global_after_request") return response @app.get("/sync/global/middlewares") def sync_global_middlewares(request: Request): assert "global_before" in request.headers - assert request.headers["global_before"] == "global_before_request" + assert request.headers.get("global_before") == "global_before_request" return "sync global middlewares" @@ -133,13 +135,13 @@ def sync_global_middlewares(request: Request): @app.before_request("/sync/middlewares") def sync_before_request(request: Request): - request.headers["before"] = "sync_before_request" + request.headers.set("before", "sync_before_request") return request @app.after_request("/sync/middlewares") def sync_after_request(response: Response): - response.headers["after"] = "sync_after_request" + response.headers.set("after", "sync_after_request") response.description = response.description + " after" return response @@ -147,20 +149,21 @@ def sync_after_request(response: Response): @app.get("/sync/middlewares") def sync_middlewares(request: Request): assert "before" in request.headers - assert request.headers["before"] == "sync_before_request" + assert request.headers.get("before") == "sync_before_request" assert request.ip_addr == "127.0.0.1" return "sync middlewares" @app.before_request("/async/middlewares") async def async_before_request(request: Request): - request.headers["before"] = "async_before_request" + request.headers.set("before", "async_before_request") return request @app.after_request("/async/middlewares") async def async_after_request(response: Response): - response.headers["after"] = "async_after_request" + # response.headers["after"] = "async_after_request" + response.headers.set("after", "async_after_request") response.description = response.description + " after" return response @@ -168,14 +171,15 @@ async def async_after_request(response: Response): @app.get("/async/middlewares") async def async_middlewares(request: Request): assert "before" in request.headers - assert request.headers["before"] == "async_before_request" + # assert request.headers["before"] == "async_before_request" + assert request.headers.get("before") == "async_before_request" assert request.ip_addr == "127.0.0.1" return "async middlewares" @app.before_request("/sync/middlewares/401") def sync_before_request_401(): - return Response(401, {}, "sync before request 401") + return Response(401, Headers({}), "sync before request 401") @app.get("/sync/middlewares/401") @@ -266,22 +270,22 @@ async def async_dict_const_get(): @app.get("/sync/response") def sync_response_get(): - return Response(200, {"sync": "response"}, "sync response get") + return Response(200, Headers({"sync": "response"}), "sync response get") @app.get("/async/response") async def async_response_get(): - return Response(200, {"async": "response"}, "async response get") + return Response(200, Headers({"async": "response"}), "async response get") @app.get("/sync/response/const", const=True) def sync_response_const_get(): - return Response(200, {"sync_const": "response"}, "sync response const get") + return Response(200, Headers({"sync_const": "response"}), "sync response const get") @app.get("/async/response/const", const=True) async def async_response_const_get(): - return Response(200, {"async_const": "response"}, "async response const get") + return Response(200, Headers({"async_const": "response"}), "async response const get") # Binary @@ -301,7 +305,7 @@ async def async_octet_get(): def sync_octet_response_get(): return Response( status_code=200, - headers={"Content-Type": "application/octet-stream"}, + headers=Headers({"Content-Type": "application/octet-stream"}), description="sync octet response", ) @@ -310,7 +314,7 @@ def sync_octet_response_get(): async def async_octet_response_get(): return Response( status_code=200, - headers={"Content-Type": "application/octet-stream"}, + headers=Headers({"Content-Type": "application/octet-stream"}), description="async octet response", ) diff --git a/integration_tests/helpers/http_methods_helpers.py b/integration_tests/helpers/http_methods_helpers.py index 3515ab9e3..0c4c7fd19 100644 --- a/integration_tests/helpers/http_methods_helpers.py +++ b/integration_tests/helpers/http_methods_helpers.py @@ -11,10 +11,9 @@ def check_response(response: requests.Response, expected_status_code: int): headers is not present in the response. """ assert response.status_code == expected_status_code - assert "global_after" in response.headers - assert response.headers["global_after"] == "global_after_request" + assert response.headers.get("global_after") == "global_after_request" assert "server" in response.headers - assert response.headers["server"] == "robyn" + assert response.headers.get("server") == "robyn" def get( diff --git a/integration_tests/test_app.py b/integration_tests/test_app.py index 17a449dfc..d2df7b0a5 100644 --- a/integration_tests/test_app.py +++ b/integration_tests/test_app.py @@ -1,6 +1,6 @@ from robyn import Robyn, ALLOW_CORS from robyn.events import Events -from robyn.types import Header +from robyn.robyn import Headers import pytest @@ -9,14 +9,14 @@ def test_add_request_header(): app = Robyn(__file__) app.add_request_header("server", "robyn") - assert app.request_headers == [Header(key="server", val="robyn")] + assert app.request_headers.get_headers() == Headers({"server": "robyn"}).get_headers() @pytest.mark.benchmark def test_add_response_header(): app = Robyn(__file__) app.add_response_header("content-type", "application/json") - assert app.response_headers == [Header(key="content-type", val="application/json")] + assert app.response_headers.get_headers() == Headers({"content-type": "application/json"}).get_headers() @pytest.mark.benchmark @@ -48,12 +48,13 @@ async def mock_shutdown_handler(): def test_allow_cors(): app = Robyn(__file__) ALLOW_CORS(app, ["*"]) - assert app.request_headers == [ - Header(key="Access-Control-Allow-Origin", val="*"), - Header( - key="Access-Control-Allow-Methods", - val="GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS", - ), - Header(key="Access-Control-Allow-Headers", val="Content-Type, Authorization"), - Header(key="Access-Control-Allow-Credentials", val="true"), - ] + + headers = Headers({}) + headers.set("Access-Control-Allow-Origin", "*") + headers.set( + "Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS", + ) + headers.set("Access-Control-Allow-Headers", "Content-Type, Authorization") + headers.set("Access-Control-Allow-Credentials", "true") + assert app.response_headers.get_headers() == headers.get_headers() diff --git a/integration_tests/test_authentication.py b/integration_tests/test_authentication.py index 34409ca67..92983bfee 100644 --- a/integration_tests/test_authentication.py +++ b/integration_tests/test_authentication.py @@ -19,7 +19,8 @@ def test_invalid_authentication_token(session, function_type: str): should_check_response=False, ) assert r.status_code == 401 - assert r.headers["WWW-Authenticate"] == "BearerGetter" + assert r.headers.get("WWW-Authenticate") == "BearerGetter" + @pytest.mark.benchmark @@ -31,7 +32,7 @@ def test_invalid_authentication_header(session, function_type: str): should_check_response=False, ) assert r.status_code == 401 - assert r.headers["WWW-Authenticate"] == "BearerGetter" + assert r.headers.get("WWW-Authenticate") == "BearerGetter" @pytest.mark.benchmark @@ -39,4 +40,4 @@ def test_invalid_authentication_header(session, function_type: str): def test_invalid_authentication_no_token(session, function_type: str): r = get(f"/{function_type}/auth", should_check_response=False) assert r.status_code == 401 - assert r.headers["WWW-Authenticate"] == "BearerGetter" + assert r.headers.get("WWW-Authenticate") == "BearerGetter" diff --git a/integration_tests/test_binary_output.py b/integration_tests/test_binary_output.py index fac931191..a8b14fa64 100644 --- a/integration_tests/test_binary_output.py +++ b/integration_tests/test_binary_output.py @@ -17,5 +17,5 @@ ) def test_binary_output(route: str, text: str, session): r = get(route) - assert r.headers["Content-Type"] == "application/octet-stream" + assert r.headers.get("Content-Type") == "application/octet-stream" assert r.text == text diff --git a/integration_tests/test_file_download.py b/integration_tests/test_file_download.py index 7baf74966..18047838e 100644 --- a/integration_tests/test_file_download.py +++ b/integration_tests/test_file_download.py @@ -6,5 +6,6 @@ @pytest.mark.parametrize("function_type", ["sync", "async"]) def test_file_download(function_type: str, session): r = get(f"/{function_type}/file/download") - assert r.headers["Content-Disposition"] == "attachment" + assert r.headers.get("Content-Disposition") == "attachment" + assert r.text == "This is a test file for the downloading purpose" diff --git a/integration_tests/test_middlewares.py b/integration_tests/test_middlewares.py index c118f97a6..f299e9a37 100644 --- a/integration_tests/test_middlewares.py +++ b/integration_tests/test_middlewares.py @@ -7,19 +7,22 @@ @pytest.mark.parametrize("function_type", ["sync", "async"]) def test_middlewares(function_type: str, session): r = get(f"/{function_type}/middlewares") + headers = r.headers # We do not want the request headers to be in the response - assert "before" not in r.headers - assert "after" in r.headers - assert r.headers["after"] == f"{function_type}_after_request" + assert headers.get("global_before") + assert headers.get("global_after") + + assert r.headers.get("after") == f"{function_type}_after_request" assert r.text == f"{function_type} middlewares after" @pytest.mark.benchmark def test_global_middleware(session): r = get("/sync/global/middlewares") - assert "global_before" not in r.headers - assert "global_after" in r.headers - assert r.headers["global_after"] == "global_after_request" + headers = r.headers + assert headers.get("global_before") + assert headers.get("global_after") + assert r.headers.get("global_after") == "global_after_request" assert r.text == "sync global middlewares" diff --git a/integration_tests/test_unsupported_types.py b/integration_tests/test_unsupported_types.py index bf29023fe..205969b5e 100644 --- a/integration_tests/test_unsupported_types.py +++ b/integration_tests/test_unsupported_types.py @@ -16,7 +16,7 @@ class A: ["OK", b"OK"], Response( status_code=200, - headers={}, + headers=None, description=b"OK", ), ] @@ -29,7 +29,7 @@ def test_bad_body_types(description): with pytest.raises(ValueError): _ = Response( status_code=200, - headers={}, + headers=None, description=description, ) @@ -38,6 +38,7 @@ def test_bad_body_types(description): def test_good_body_types(description): _ = Response( status_code=200, - headers={}, + headers=None, + description=description, ) diff --git a/robyn/__init__.py b/robyn/__init__.py index 27505ef3e..1cccf17a5 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -60,8 +60,8 @@ def __init__(self, file_object: str, config: Config = Config()) -> None: self.router = Router() self.middleware_router = MiddlewareRouter() self.web_socket_router = WebSocketRouter() - self.request_headers: Headers = Headers() - self.response_headers: Headers = Headers() + self.request_headers: Headers = Headers({}) + self.response_headers: Headers = Headers({}) self.directories: List[Directory] = [] self.event_handlers = {} self.exception_handler: Optional[Callable] = None @@ -412,13 +412,13 @@ def options(self, endpoint: str): def ALLOW_CORS(app: Robyn, origins: List[str]): """Allows CORS for the given origins for the entire router.""" for origin in origins: - app.add_request_header("Access-Control-Allow-Origin", origin) - app.add_request_header( + app.add_response_header("Access-Control-Allow-Origin", origin) + app.add_response_header( "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS", ) - app.add_request_header("Access-Control-Allow-Headers", "Content-Type, Authorization") - app.add_request_header("Access-Control-Allow-Credentials", "true") + app.add_response_header("Access-Control-Allow-Headers", "Content-Type, Authorization") + app.add_response_header("Access-Control-Allow-Credentials", "true") __all__ = [ diff --git a/robyn/authentication.py b/robyn/authentication.py index 6d40dc9b3..2879b93f6 100644 --- a/robyn/authentication.py +++ b/robyn/authentication.py @@ -1,7 +1,7 @@ from abc import ABC, abstractclassmethod, abstractmethod from typing import Optional -from robyn.robyn import Identity, Request, Response +from robyn.robyn import Headers, Identity, Request, Response from robyn.status_codes import HTTP_401_UNAUTHORIZED @@ -56,7 +56,7 @@ def __init__(self, token_getter: TokenGetter): @property def unauthorized_response(self) -> Response: return Response( - headers={"WWW-Authenticate": self.token_getter.scheme}, + headers=Headers({"WWW-Authenticate": self.token_getter.scheme}), description="Unauthorized", status_code=HTTP_401_UNAUTHORIZED, ) diff --git a/robyn/processpool.py b/robyn/processpool.py index f7e92a928..c303fb450 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -7,7 +7,7 @@ from robyn.logger import logger from robyn.events import Events -from robyn.robyn import FunctionInfo, Server, SocketHeld +from robyn.robyn import FunctionInfo, Headers, Server, SocketHeld from robyn.router import GlobalMiddleware, RouteMiddleware, Route from robyn.types import Directory, Header from robyn.ws import WebSocket @@ -17,7 +17,7 @@ def run_processes( url: str, port: int, directories: List[Directory], - request_headers: List[Header], + request_headers: Headers, routes: List[Route], global_middlewares: List[GlobalMiddleware], route_middlewares: List[RouteMiddleware], @@ -25,7 +25,7 @@ def run_processes( event_handlers: Dict[Events, FunctionInfo], workers: int, processes: int, - response_headers: List[Header], + response_headers: Headers, open_browser: bool, ) -> List[Process]: socket = SocketHeld(url, port) @@ -65,7 +65,7 @@ def terminating_signal_handler(_sig, _frame): def init_processpool( directories: List[Directory], - request_headers: List[Header], + request_headers: Headers, routes: List[Route], global_middlewares: List[GlobalMiddleware], route_middlewares: List[RouteMiddleware], @@ -74,7 +74,7 @@ def init_processpool( socket: SocketHeld, workers: int, processes: int, - response_headers: List[Header], + response_headers: Headers, ) -> List[Process]: process_pool = [] if sys.platform.startswith("win32"): @@ -134,7 +134,7 @@ def initialize_event_loop(): def spawn_process( directories: List[Directory], - request_headers: List[Header], + request_headers: Headers, routes: List[Route], global_middlewares: List[GlobalMiddleware], route_middlewares: List[RouteMiddleware], @@ -142,7 +142,7 @@ def spawn_process( event_handlers: Dict[Events, FunctionInfo], socket: SocketHeld, workers: int, - response_headers: List[Header], + response_headers: Headers, ): """ This function is called by the main process handler to create a server runtime. @@ -168,11 +168,9 @@ def spawn_process( for directory in directories: server.add_directory(*directory.as_list()) - for header in request_headers: - server.add_request_header(*header.as_list()) + server.set_request_headers(request_headers) - for header in response_headers: - server.add_response_header(*header.as_list()) + server.set_response_headers(response_headers) for route in routes: route_type, endpoint, function, is_const = route diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index d51f3626d..050bed6bd 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -240,7 +240,7 @@ class Response: """ status_code: int - headers: Headers + headers: Optional[Headers] description: Union[str, bytes] response_type: Optional[str] = None file_path: Optional[str] = None diff --git a/robyn/router.py b/robyn/router.py index b5a06f7c0..582e7dc24 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -48,6 +48,8 @@ def _format_response( default_response_header: Headers, ) -> Response: # TODO: Add support for custom headers + assert isinstance(default_response_header, Headers) + headers = default_response_header if headers.is_empty(): @@ -57,9 +59,9 @@ def _format_response( if isinstance(res, dict): # this should change status_code = res.get("status_code", status_codes.HTTP_200_OK) - headers = res.get("headers", None) + response_headers = res.get("headers", {}) if headers is not None: - headers.populate_from_dict(headers) + headers.populate_from_dict(response_headers) description = res.get("description", "") @@ -77,7 +79,7 @@ def _format_response( response = Response( status_code=status_codes.HTTP_200_OK, - headers={"Content-Type": "application/octet-stream"}, + headers=Headers( {"Content-Type": "application/octet-stream"}), description=res, ) else: diff --git a/src/executors/mod.rs b/src/executors/mod.rs index ffb37ee36..a92f8aea6 100644 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -80,6 +80,8 @@ pub async fn execute_http_function( })? .await?; + debug!("Function output: {:?}", output); + return Python::with_gil(|py| -> PyResult { output.extract(py) }); }; diff --git a/src/io_helpers/mod.rs b/src/io_helpers/mod.rs index ad5bb0d5c..bdeb36614 100644 --- a/src/io_helpers/mod.rs +++ b/src/io_helpers/mod.rs @@ -5,15 +5,17 @@ use std::io::Read; use actix_web::HttpResponseBuilder; use anyhow::Result; +use crate::types::headers::Headers; + // this should be something else // probably inside the submodule of the http router #[inline] -pub fn apply_hashmap_headers( - response: &mut HttpResponseBuilder, - headers: &HashMap, -) { - for (key, val) in headers.iter() { - response.insert_header((key.clone(), val.clone())); +pub fn apply_hashmap_headers(response: &mut HttpResponseBuilder, headers: &Headers) { + for mut iter in headers.headers.iter() { + let (key, values) = iter.pair(); + for value in values { + response.append_header((key.clone(), value.clone())); + } } } diff --git a/src/server.rs b/src/server.rs index 2ce7069cc..8dcafc271 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,6 +7,7 @@ use crate::routers::http_router::HttpRouter; use crate::routers::{middleware_router::MiddlewareRouter, web_socket_router::WebSocketRouter}; use crate::shared_socket::SocketHeld; use crate::types::function_info::{FunctionInfo, MiddlewareType}; +use crate::types::headers::Headers; use crate::types::request::Request; use crate::types::response::Response; use crate::types::HttpMethod; @@ -51,8 +52,8 @@ pub struct Server { const_router: Arc, websocket_router: Arc, middleware_router: Arc, - global_request_headers: Arc>, - global_response_headers: Arc>, + global_request_headers: Arc, + global_response_headers: Arc, directories: Arc>>, startup_handler: Option>, shutdown_handler: Option>, @@ -67,8 +68,8 @@ impl Server { const_router: Arc::new(ConstRouter::new()), websocket_router: Arc::new(WebSocketRouter::new()), middleware_router: Arc::new(MiddlewareRouter::new()), - global_request_headers: Arc::new(DashMap::new()), - global_response_headers: Arc::new(DashMap::new()), + global_request_headers: Arc::new(Headers::new(None)), + global_response_headers: Arc::new(Headers::new(None)), directories: Arc::new(RwLock::new(Vec::new())), startup_handler: None, shutdown_handler: None, @@ -255,27 +256,41 @@ impl Server { /// Adds a new request header to our concurrent hashmap /// this can be called after the server has started. pub fn add_request_header(&self, key: &str, value: &str) { - self.global_request_headers - .insert(key.to_string(), value.to_string()); + self.global_response_headers + .headers + .entry(key.to_string()) + .or_insert_with(Vec::new) + .push(value.to_string()); } /// Adds a new response header to our concurrent hashmap /// this can be called after the server has started. pub fn add_response_header(&self, key: &str, value: &str) { self.global_response_headers - .insert(key.to_string(), value.to_string()); + .headers + .entry(key.to_string()) + .or_insert_with(Vec::new) + .push(value.to_string()); } /// Removes a new request header to our concurrent hashmap /// this can be called after the server has started. pub fn remove_header(&self, key: &str) { - self.global_request_headers.remove(key); + self.global_request_headers.headers.remove(key); } /// Removes a new response header to our concurrent hashmap /// this can be called after the server has started. pub fn remove_response_header(&self, key: &str) { - self.global_response_headers.remove(key); + self.global_response_headers.headers.remove(key); + } + + pub fn set_request_headers(&mut self, headers: &Headers) { + self.global_request_headers = Arc::new(headers.clone()); + } + + pub fn set_response_headers(&mut self, headers: &Headers) { + self.global_response_headers = Arc::new(headers.clone()); } /// Add a new route to the routing tables @@ -375,8 +390,8 @@ async fn index( router: web::Data>, const_router: web::Data>, middleware_router: web::Data>, - global_request_headers: web::Data>>, - global_response_headers: web::Data>>, + global_request_headers: web::Data>, + global_response_headers: web::Data>, body: Bytes, req: HttpRequest, ) -> impl Responder { @@ -437,11 +452,7 @@ async fn index( Response::not_found(&request.headers) }; - response.headers.extend( - global_response_headers - .iter() - .map(|elt| (elt.key().clone(), elt.value().clone())), - ); + response.headers.extend(&global_response_headers); // After middleware // Global diff --git a/src/types/headers.rs b/src/types/headers.rs index d21f01e91..60b0de94c 100644 --- a/src/types/headers.rs +++ b/src/types/headers.rs @@ -1,13 +1,16 @@ +use actix_http::header::HeaderMap; +use actix_web::{web::Bytes, HttpRequest}; +use dashmap::DashMap; use log::debug; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; +use pyo3::types::{PyDict, PyIterator, PyList}; use std::collections::HashMap; // Custom Multimap class #[pyclass(name = "Headers")] #[derive(Clone, Debug, Default)] pub struct Headers { - pub headers: HashMap>, + pub headers: DashMap>, } #[pymethods] @@ -16,15 +19,17 @@ impl Headers { pub fn new(default_headers: Option<&PyDict>) -> Self { match default_headers { Some(default_headers) => { - let mut headers = HashMap::new(); + let mut headers = DashMap::new(); for (key, value) in default_headers { - let key = key.to_string(); + let key = key.to_string().to_lowercase(); let new_value = value.downcast::(); if new_value.is_err() { let value = value.to_string(); headers.entry(key).or_insert_with(Vec::new).push(value); + + // headers.entry(key).or_insert_with(Vec::new).push(value); } else { let value: Vec = new_value.unwrap().iter().map(|x| x.to_string()).collect(); @@ -34,7 +39,7 @@ impl Headers { Headers { headers } } None => Headers { - headers: HashMap::new(), + headers: DashMap::new(), }, } } @@ -46,7 +51,7 @@ impl Headers { .push(value.to_lowercase()); } - pub fn get(&self, py: Python, key: String) -> Py { + pub fn get_all(&self, py: Python, key: String) -> Py { match self.headers.get(&key.to_lowercase()) { Some(values) => { let py_values = PyList::new(py, values.iter().map(|value| value.to_object(py))); @@ -56,32 +61,111 @@ impl Headers { } } - /// Returns all headers as a PyList of tuples. - pub fn get_headers(&self, py: Python) -> Py { - let headers_list = PyList::new( - py, - self.headers.iter().map(|(key, values)| { - let py_values = PyList::new(py, values); - (key.clone(), py_values.to_object(py)) - }), - ); - headers_list.into() + pub fn get(&self, key: String) -> PyResult { + // return the last value + match self.headers.get(&key.to_lowercase()) { + Some(values) => Ok(values.last().unwrap().to_string()), + None => Err(pyo3::exceptions::PyKeyError::new_err(format!( + "KeyError: {}", + key + ))), + } + } + + pub fn get_headers(&self, py: Python) -> Py { + // return as a dict of lists + let dict = PyDict::new(py); + for iter in self.headers.iter() { + let (key, values) = iter.pair(); + let py_values = PyList::new(py, values.iter().map(|value| value.to_object(py))); + dict.set_item(key, py_values).unwrap(); + } + dict.into() + } + + pub fn contains(&self, key: String) -> bool { + self.headers.contains_key(&key.to_lowercase()) } pub fn populate_from_dict(&mut self, headers: &PyDict) { for (key, value) in headers { let key = key.to_string().to_lowercase(); - let value: Vec = value - .downcast::() - .unwrap() - .iter() - .map(|x| x.to_string()) - .collect(); - self.headers.insert(key, value); + let new_value = value.downcast::(); + + if new_value.is_err() { + let value = value.to_string(); + self.headers.entry(key).or_insert_with(Vec::new).push(value); + } else { + let value: Vec = new_value.unwrap().iter().map(|x| x.to_string()).collect(); + self.headers + .entry(key) + .or_insert_with(Vec::new) + .extend(value); + } } } pub fn is_empty(&self) -> bool { self.headers.is_empty() } + + fn __eq__(&self, other: &Headers) -> bool { + if self.headers.is_empty() && other.headers.is_empty() { + return true; + } + + if self.headers.len() != other.headers.len() { + return false; + } + + for iter in &self.headers { + let (key, values) = iter.pair(); + match other.headers.get(key) { + Some(other_values) => { + if values.len() != other_values.len() + || !values.iter().all(|v| other_values.contains(v)) + { + return false; + } + } + None => return false, + } + } + + true + } + + pub fn __repr__(&self) -> String { + format!("{:?}", self.headers) + } +} + +impl Headers { + pub fn remove(&mut self, key: &str) { + self.headers.remove(&key.to_lowercase()); + } + + pub fn extend(&mut self, headers: &Headers) { + for iter in headers.headers.iter() { + let (key, values) = iter.pair(); + let mut entry = self.headers.entry(key.clone()).or_default(); + entry.extend(values.clone()); + } + } + + pub fn from_actix_headers(req_headers: &HeaderMap) -> Self { + let mut headers = Headers::default(); + + for (key, value) in req_headers { + let key = key.to_string().to_lowercase(); + let value = value.to_str().unwrap().to_lowercase(); + headers + .headers + .entry(key) + .or_insert_with(Vec::new) + .push(value); + } + + headers + } } diff --git a/src/types/multimap.rs b/src/types/multimap.rs index b0f4e2a62..8715d3d5a 100644 --- a/src/types/multimap.rs +++ b/src/types/multimap.rs @@ -88,7 +88,7 @@ impl QueryParams { multimap } - pub fn from_dict(dict: &PyDict) -> Self { + pub fn from_py_dict(dict: &PyDict) -> Self { let mut multimap = QueryParams::new(); for (key, value) in dict.iter() { let key = key.extract::().unwrap(); diff --git a/src/types/request.rs b/src/types/request.rs index cf1c3e385..9287d7336 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -6,12 +6,12 @@ use std::collections::HashMap; use crate::types::{check_body_type, get_body_from_pyobject, Url}; -use super::{identity::Identity, multimap::QueryParams}; +use super::{headers::Headers, identity::Identity, multimap::QueryParams}; #[derive(Default, Debug, Clone, FromPyObject)] pub struct Request { pub query_params: QueryParams, - pub headers: HashMap, + pub headers: Headers, pub method: String, pub path_params: HashMap, // https://pyo3.rs/v0.19.2/function.html?highlight=from_py_#per-argument-options @@ -25,7 +25,7 @@ pub struct Request { impl ToPyObject for Request { fn to_object(&self, py: Python) -> PyObject { let query_params = self.query_params.clone(); - let headers = self.headers.clone().into_py(py).extract(py).unwrap(); + let headers = self.headers.clone(); let path_params = self.path_params.clone().into_py(py).extract(py).unwrap(); let body = match String::from_utf8(self.body.clone()) { Ok(s) => s.into_py(py), @@ -47,11 +47,7 @@ impl ToPyObject for Request { } impl Request { - pub fn from_actix_request( - req: &HttpRequest, - body: Bytes, - global_headers: &DashMap, - ) -> Self { + pub fn from_actix_request(req: &HttpRequest, body: Bytes, global_headers: &Headers) -> Self { let mut query_params: QueryParams = QueryParams::new(); if !req.query_string().is_empty() { let split = req.query_string().split('&'); @@ -63,16 +59,9 @@ impl Request { query_params.set(key, value); } } - let headers = req - .headers() - .iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string())) - .chain( - global_headers - .iter() - .map(|h| (h.key().clone(), h.value().clone())), - ) - .collect(); + + let mut headers = Headers::from_actix_headers(req.headers()); + headers.extend(global_headers); let url = Url::new( req.connection_info().scheme(), @@ -100,7 +89,7 @@ pub struct PyRequest { #[pyo3(get, set)] pub query_params: QueryParams, #[pyo3(get, set)] - pub headers: Py, + pub headers: Headers, #[pyo3(get, set)] pub path_params: Py, #[pyo3(get, set)] @@ -121,7 +110,7 @@ impl PyRequest { #[allow(clippy::too_many_arguments)] pub fn new( query_params: &PyDict, - headers: Py, + headers: Headers, path_params: Py, body: Py, method: String, @@ -129,7 +118,7 @@ impl PyRequest { identity: Option, ip_addr: Option, ) -> Self { - let query_params = QueryParams::from_dict(query_params); + let query_params = QueryParams::from_py_dict(query_params); Self { query_params, diff --git a/src/types/response.rs b/src/types/response.rs index 8c804b88b..047b70ee0 100644 --- a/src/types/response.rs +++ b/src/types/response.rs @@ -17,7 +17,7 @@ use crate::types::{check_description_type, get_description_from_pyobject}; pub struct Response { pub status_code: u16, pub response_type: String, - pub headers: HashMap, + pub headers: Headers, // https://pyo3.rs/v0.19.2/function.html?highlight=from_py_#per-argument-options #[pyo3(from_py_with = "get_description_from_pyobject")] pub description: Vec, @@ -36,7 +36,7 @@ impl Responder for Response { } impl Response { - pub fn not_found(headers: &HashMap) -> Self { + pub fn not_found(headers: &Headers) -> Self { Self { status_code: 404, response_type: "text".to_string(), @@ -46,7 +46,7 @@ impl Response { } } - pub fn internal_server_error(headers: &HashMap) -> Self { + pub fn internal_server_error(headers: &Headers) -> Self { Self { status_code: 500, response_type: "text".to_string(), @@ -96,7 +96,7 @@ impl PyResponse { pub fn new( py: Python, status_code: u16, - headers: Headers, + headers: Option, description: Py, ) -> PyResult { if description.downcast::(py).is_err() @@ -106,6 +106,9 @@ impl PyResponse { "Could not convert specified body to bytes", )); }; + + let headers = headers.unwrap_or_default(); + Ok(Self { status_code, // we should be handling based on headers but works for now diff --git a/unit_tests/test_request_object.py b/unit_tests/test_request_object.py index fccc7fa8a..beda577ba 100644 --- a/unit_tests/test_request_object.py +++ b/unit_tests/test_request_object.py @@ -1,4 +1,4 @@ -from robyn.robyn import Request, Url +from robyn.robyn import Headers, Request, Url def test_request_object(): @@ -9,7 +9,8 @@ def test_request_object(): ) request = Request( query_params={}, - headers={"Content-Type": "application/json"}, + # headers={"Content-Type": "application/json"}, + headers=Headers({"Content-Type": "application/json"}), path_params={}, body="", method="GET", @@ -20,5 +21,6 @@ def test_request_object(): assert request.url.scheme == "https" assert request.url.host == "localhost" - assert request.headers["Content-Type"] == "application/json" + print(request.headers.get("Content-Type")) + assert request.headers.get("Content-Type") == "application/json" assert request.method == "GET" From ce69d58fcd6190594f61fedb8d1f584b53f5f481 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Sat, 2 Dec 2023 01:25:53 +0530 Subject: [PATCH 4/9] need to fix final tests --- integration_tests/base_routes.py | 9 ++-- .../helpers/http_methods_helpers.py | 1 + integration_tests/test_app.py | 2 +- integration_tests/test_unsupported_types.py | 9 ++-- robyn/__init__.py | 8 +++- robyn/processpool.py | 4 +- robyn/robyn.pyi | 7 +-- src/server.rs | 44 ++++++++----------- src/types/headers.rs | 19 +++++++- src/types/request.rs | 6 +-- src/types/response.rs | 6 +-- 11 files changed, 63 insertions(+), 52 deletions(-) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 63f3c533e..886ee31ed 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -112,7 +112,6 @@ def shutdown_handler(): @app.before_request() def global_before_request(request: Request): - # request.headers["global_before"] = "global_before_request" request.headers.set("global_before", "global_before_request") return request @@ -125,7 +124,7 @@ def global_after_request(response: Response): @app.get("/sync/global/middlewares") def sync_global_middlewares(request: Request): - assert "global_before" in request.headers + assert request.headers.contains("global_before") assert request.headers.get("global_before") == "global_before_request" return "sync global middlewares" @@ -148,7 +147,7 @@ def sync_after_request(response: Response): @app.get("/sync/middlewares") def sync_middlewares(request: Request): - assert "before" in request.headers + assert request.headers.contains("before") assert request.headers.get("before") == "sync_before_request" assert request.ip_addr == "127.0.0.1" return "sync middlewares" @@ -170,7 +169,7 @@ async def async_after_request(response: Response): @app.get("/async/middlewares") async def async_middlewares(request: Request): - assert "before" in request.headers + assert request.headers.contains("before") # assert request.headers["before"] == "async_before_request" assert request.headers.get("before") == "async_before_request" assert request.ip_addr == "127.0.0.1" @@ -771,7 +770,7 @@ async def async_without_decorator(): def main(): - app.add_response_header("server", "robyn") + app.set_response_header("server", "robyn") app.add_directory( route="/test_dir", directory_path=os.path.join(current_file_path, "build"), diff --git a/integration_tests/helpers/http_methods_helpers.py b/integration_tests/helpers/http_methods_helpers.py index 0c4c7fd19..921375e72 100644 --- a/integration_tests/helpers/http_methods_helpers.py +++ b/integration_tests/helpers/http_methods_helpers.py @@ -11,6 +11,7 @@ def check_response(response: requests.Response, expected_status_code: int): headers is not present in the response. """ assert response.status_code == expected_status_code + print(response.headers) assert response.headers.get("global_after") == "global_after_request" assert "server" in response.headers assert response.headers.get("server") == "robyn" diff --git a/integration_tests/test_app.py b/integration_tests/test_app.py index d2df7b0a5..a2aaf8f57 100644 --- a/integration_tests/test_app.py +++ b/integration_tests/test_app.py @@ -8,7 +8,7 @@ @pytest.mark.benchmark def test_add_request_header(): app = Robyn(__file__) - app.add_request_header("server", "robyn") + app.set_request_header("server", "robyn") assert app.request_headers.get_headers() == Headers({"server": "robyn"}).get_headers() diff --git a/integration_tests/test_unsupported_types.py b/integration_tests/test_unsupported_types.py index 205969b5e..028595aca 100644 --- a/integration_tests/test_unsupported_types.py +++ b/integration_tests/test_unsupported_types.py @@ -1,6 +1,6 @@ import pytest -from robyn.robyn import Response +from robyn.robyn import Headers, Response class A: @@ -16,7 +16,7 @@ class A: ["OK", b"OK"], Response( status_code=200, - headers=None, + headers=Headers({}), description=b"OK", ), ] @@ -29,7 +29,7 @@ def test_bad_body_types(description): with pytest.raises(ValueError): _ = Response( status_code=200, - headers=None, + headers=Headers({}), description=description, ) @@ -38,7 +38,6 @@ def test_bad_body_types(description): def test_good_body_types(description): _ = Response( status_code=200, - headers=None, - + headers=Headers({}), description=description, ) diff --git a/robyn/__init__.py b/robyn/__init__.py index 1cccf17a5..50ea54906 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -142,9 +142,15 @@ def add_directory( self.directories.append(Directory(route, directory_path, show_files_listing, index_file)) def add_request_header(self, key: str, value: str) -> None: - self.request_headers.set(key, value) + self.request_headers.append(key, value) def add_response_header(self, key: str, value: str) -> None: + self.response_headers.append(key, value) + + def set_request_header(self, key: str, value: str) -> None: + self.request_headers.set(key, value) + + def set_response_header(self, key: str, value: str) -> None: self.response_headers.set(key, value) def add_web_socket(self, endpoint: str, ws: WebSocket) -> None: diff --git a/robyn/processpool.py b/robyn/processpool.py index c303fb450..265f31d0b 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -168,9 +168,9 @@ def spawn_process( for directory in directories: server.add_directory(*directory.as_list()) - server.set_request_headers(request_headers) + server.apply_request_headers(request_headers) - server.set_response_headers(response_headers) + server.apply_response_headers(response_headers) for route in routes: route_type, endpoint, function, is_const = route diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index 050bed6bd..a261d7c80 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -240,7 +240,7 @@ class Response: """ status_code: int - headers: Optional[Headers] + headers: Headers description: Union[str, bytes] response_type: Optional[str] = None file_path: Optional[str] = None @@ -256,10 +256,11 @@ class Server: index_file: Optional[str], ) -> None: pass - def add_request_header(self, key: str, value: str) -> None: + def apply_request_header(self, key: str, value: str) -> None: pass - def add_response_header(self, key: str, value: str) -> None: + def apply_response_header(self, key: str, value: str) -> None: pass + def add_route( self, route_type: HttpMethod, diff --git a/src/server.rs b/src/server.rs index 8dcafc271..8281dfe8a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -253,26 +253,6 @@ impl Server { }); } - /// Adds a new request header to our concurrent hashmap - /// this can be called after the server has started. - pub fn add_request_header(&self, key: &str, value: &str) { - self.global_response_headers - .headers - .entry(key.to_string()) - .or_insert_with(Vec::new) - .push(value.to_string()); - } - - /// Adds a new response header to our concurrent hashmap - /// this can be called after the server has started. - pub fn add_response_header(&self, key: &str, value: &str) { - self.global_response_headers - .headers - .entry(key.to_string()) - .or_insert_with(Vec::new) - .push(value.to_string()); - } - /// Removes a new request header to our concurrent hashmap /// this can be called after the server has started. pub fn remove_header(&self, key: &str) { @@ -285,11 +265,11 @@ impl Server { self.global_response_headers.headers.remove(key); } - pub fn set_request_headers(&mut self, headers: &Headers) { + pub fn apply_request_headers(&mut self, headers: &Headers) { self.global_request_headers = Arc::new(headers.clone()); } - pub fn set_response_headers(&mut self, headers: &Headers) { + pub fn apply_response_headers(&mut self, headers: &Headers) { self.global_response_headers = Arc::new(headers.clone()); } @@ -395,6 +375,9 @@ async fn index( body: Bytes, req: HttpRequest, ) -> impl Responder { + debug!("Global Request Headers: {:?}", global_request_headers); + debug!("Global Response Headers: {:?}", global_response_headers); + let mut request = Request::from_actix_request(&req, body, &global_request_headers); // Before middleware @@ -409,7 +392,7 @@ async fn index( request.path_params = route_params; } for before_middleware in before_middlewares { - request = match execute_middleware_function(&request, &before_middleware).await { + request = match execute_middleware_function(&mut request, &before_middleware).await { Ok(MiddlewareReturn::Request(r)) => r, Ok(MiddlewareReturn::Response(r)) => { // If a before middleware returns a response, we abort the request and return the response @@ -452,8 +435,12 @@ async fn index( Response::not_found(&request.headers) }; + debug!("OG Response : {:?}", response); + response.headers.extend(&global_response_headers); + debug!("Extended Response : {:?}", response); + // After middleware // Global let mut after_middlewares = @@ -465,12 +452,17 @@ async fn index( after_middlewares.push(function); } for after_middleware in after_middlewares { - response = match execute_middleware_function(&response, &after_middleware).await { + response = match execute_middleware_function(&mut response, &after_middleware).await { Ok(MiddlewareReturn::Request(_)) => { error!("After middleware returned a request"); return Response::internal_server_error(&request.headers); } - Ok(MiddlewareReturn::Response(r)) => r, + Ok(MiddlewareReturn::Response(r)) => { + let response = r; + + debug!("Response returned: {:?}", response); + response + } Err(e) => { error!( "Error while executing after middleware function for endpoint `{}`: {}", @@ -482,7 +474,7 @@ async fn index( }; } - debug!("Response: {:?}", response); + debug!("Response returned: {:?}", response); response } diff --git a/src/types/headers.rs b/src/types/headers.rs index 60b0de94c..1bc4313d2 100644 --- a/src/types/headers.rs +++ b/src/types/headers.rs @@ -45,10 +45,19 @@ impl Headers { } pub fn set(&mut self, key: String, value: String) { + debug!("Setting header {} to {}", key, value); + self.headers + .insert(key.to_lowercase(), vec![value.to_lowercase()]); + debug!("This is new self: {:?}", self); + } + + pub fn append(&mut self, key: String, value: String) { + debug!("Setting header {} to {}", key, value); self.headers .entry(key.to_lowercase()) .or_insert_with(Vec::new) .push(value.to_lowercase()); + debug!("This is new self: {:?}", self); } pub fn get_all(&self, py: Python, key: String) -> Py { @@ -64,7 +73,11 @@ impl Headers { pub fn get(&self, key: String) -> PyResult { // return the last value match self.headers.get(&key.to_lowercase()) { - Some(values) => Ok(values.last().unwrap().to_string()), + Some(iter) => { + let (_, values) = iter.pair(); + let last_value = values.last().unwrap(); + Ok(last_value.to_string()) + } None => Err(pyo3::exceptions::PyKeyError::new_err(format!( "KeyError: {}", key @@ -84,6 +97,8 @@ impl Headers { } pub fn contains(&self, key: String) -> bool { + debug!("Checking if header {} exists", key); + debug!("Headers: {:?}", self.headers); self.headers.contains_key(&key.to_lowercase()) } @@ -154,7 +169,7 @@ impl Headers { } pub fn from_actix_headers(req_headers: &HeaderMap) -> Self { - let mut headers = Headers::default(); + let headers = Headers::default(); for (key, value) in req_headers { let key = key.to_string().to_lowercase(); diff --git a/src/types/request.rs b/src/types/request.rs index 9287d7336..439c3a7d2 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -25,7 +25,7 @@ pub struct Request { impl ToPyObject for Request { fn to_object(&self, py: Python) -> PyObject { let query_params = self.query_params.clone(); - let headers = self.headers.clone(); + let headers: Py = self.headers.clone().into_py(py).extract(py).unwrap(); let path_params = self.path_params.clone().into_py(py).extract(py).unwrap(); let body = match String::from_utf8(self.body.clone()) { Ok(s) => s.into_py(py), @@ -89,7 +89,7 @@ pub struct PyRequest { #[pyo3(get, set)] pub query_params: QueryParams, #[pyo3(get, set)] - pub headers: Headers, + pub headers: Py, #[pyo3(get, set)] pub path_params: Py, #[pyo3(get, set)] @@ -110,7 +110,7 @@ impl PyRequest { #[allow(clippy::too_many_arguments)] pub fn new( query_params: &PyDict, - headers: Headers, + headers: Py, path_params: Py, body: Py, method: String, diff --git a/src/types/response.rs b/src/types/response.rs index 047b70ee0..ea77d0db6 100644 --- a/src/types/response.rs +++ b/src/types/response.rs @@ -82,7 +82,7 @@ pub struct PyResponse { #[pyo3(get)] pub response_type: String, #[pyo3(get, set)] - pub headers: Headers, + pub headers: Py, #[pyo3(get)] pub description: Py, #[pyo3(get)] @@ -96,7 +96,7 @@ impl PyResponse { pub fn new( py: Python, status_code: u16, - headers: Option, + headers: Py, description: Py, ) -> PyResult { if description.downcast::(py).is_err() @@ -107,8 +107,6 @@ impl PyResponse { )); }; - let headers = headers.unwrap_or_default(); - Ok(Self { status_code, // we should be handling based on headers but works for now From ef65fc6d12f64f7b73596cec9ab786815f264cdd Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Sat, 2 Dec 2023 03:08:34 +0530 Subject: [PATCH 5/9] Fix request headers --- integration_tests/test_middlewares.py | 2 -- robyn/__init__.py | 1 - robyn/authentication.py | 5 ++++- robyn/router.py | 27 +++++++++++---------------- robyn/templating.py | 4 ++-- src/io_helpers/mod.rs | 3 ++- src/server.rs | 22 ++++++++++------------ src/types/headers.rs | 13 ++++--------- src/types/response.rs | 18 ++++++++++++++---- 9 files changed, 47 insertions(+), 48 deletions(-) diff --git a/integration_tests/test_middlewares.py b/integration_tests/test_middlewares.py index f299e9a37..9833d54d7 100644 --- a/integration_tests/test_middlewares.py +++ b/integration_tests/test_middlewares.py @@ -9,7 +9,6 @@ def test_middlewares(function_type: str, session): r = get(f"/{function_type}/middlewares") headers = r.headers # We do not want the request headers to be in the response - assert headers.get("global_before") assert headers.get("global_after") assert r.headers.get("after") == f"{function_type}_after_request" @@ -20,7 +19,6 @@ def test_middlewares(function_type: str, session): def test_global_middleware(session): r = get("/sync/global/middlewares") headers = r.headers - assert headers.get("global_before") assert headers.get("global_after") assert r.headers.get("global_after") == "global_after_request" assert r.text == "sync global middlewares" diff --git a/robyn/__init__.py b/robyn/__init__.py index 50ea54906..efb50d251 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -107,7 +107,6 @@ def add_route( handler, is_const, self.exception_handler, - self.response_headers, ) logger.info("Added route %s %s", route_type, endpoint) diff --git a/robyn/authentication.py b/robyn/authentication.py index 2879b93f6..8cd4bd988 100644 --- a/robyn/authentication.py +++ b/robyn/authentication.py @@ -79,7 +79,10 @@ class BearerGetter(TokenGetter): @classmethod def get_token(cls, request: Request) -> Optional[str]: - authorization_header = request.headers.get("authorization") + if request.headers.contains("authorization") : + authorization_header = request.headers.get("authorization") + else: + authorization_header = None if not authorization_header or not authorization_header.startswith("Bearer "): return None diff --git a/robyn/router.py b/robyn/router.py index 582e7dc24..28217fb6d 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -45,23 +45,20 @@ def __init__(self) -> None: def _format_response( self, res: dict, - default_response_header: Headers, ) -> Response: # TODO: Add support for custom headers - assert isinstance(default_response_header, Headers) + headers = Headers({"Content-Type": "text/plain"}) - headers = default_response_header - - if headers.is_empty(): - headers.set("Content-Type", "text/plain") # we should create a header object here response = {} if isinstance(res, dict): # this should change status_code = res.get("status_code", status_codes.HTTP_200_OK) - response_headers = res.get("headers", {}) - if headers is not None: - headers.populate_from_dict(response_headers) + headers = res.get("headers", {}) + headers = Headers(headers) + if not headers.contains("Content-Type"): + headers.set("Content-Type", "text/plain") + description = res.get("description", "") @@ -76,10 +73,9 @@ def _format_response( response = res elif isinstance(res, bytes): headers = Headers({"Content-Type": "application/octet-stream"}) - response = Response( status_code=status_codes.HTTP_200_OK, - headers=Headers( {"Content-Type": "application/octet-stream"}), + headers=headers, description=res, ) else: @@ -97,7 +93,6 @@ def add_route( handler: Callable, is_const: bool, exception_handler: Optional[Callable], - response_headers: Headers, ) -> Union[Callable, CoroutineType]: @wraps(handler) @@ -105,14 +100,14 @@ async def async_inner_handler(*args): try: response = self._format_response( await handler(*args), - response_headers, + # response_headers, ) except Exception as err: if exception_handler is None: raise response = self._format_response( exception_handler(err), - response_headers, + # response_headers, ) return response @@ -121,14 +116,14 @@ def inner_handler(*args): try: response = self._format_response( handler(*args), - response_headers, + # response_headers, ) except Exception as err: if exception_handler is None: raise response = self._format_response( exception_handler(err), - response_headers, + # response_headers, ) return response diff --git a/robyn/templating.py b/robyn/templating.py index 881fecf33..0fa8ffabd 100644 --- a/robyn/templating.py +++ b/robyn/templating.py @@ -2,7 +2,7 @@ from robyn import status_codes -from .robyn import Response +from .robyn import Headers, Response from jinja2 import Environment, FileSystemLoader @@ -25,7 +25,7 @@ def render_template(self, template_name, **kwargs) -> Response: return Response( status_code=status_codes.HTTP_200_OK, description=rendered_template, - headers={"Content-Type": "text/html; charset=utf-8"}, + headers=Headers({"Content-Type": "text/html; charset=utf-8"}), ) diff --git a/src/io_helpers/mod.rs b/src/io_helpers/mod.rs index bdeb36614..9a8dc1ac5 100644 --- a/src/io_helpers/mod.rs +++ b/src/io_helpers/mod.rs @@ -4,6 +4,7 @@ use std::io::Read; use actix_web::HttpResponseBuilder; use anyhow::Result; +use log::debug; use crate::types::headers::Headers; @@ -11,7 +12,7 @@ use crate::types::headers::Headers; // probably inside the submodule of the http router #[inline] pub fn apply_hashmap_headers(response: &mut HttpResponseBuilder, headers: &Headers) { - for mut iter in headers.headers.iter() { + for iter in headers.headers.iter() { let (key, values) = iter.pair(); for value in values { response.append_header((key.clone(), value.clone())); diff --git a/src/server.rs b/src/server.rs index 8281dfe8a..6ec51c766 100644 --- a/src/server.rs +++ b/src/server.rs @@ -375,9 +375,6 @@ async fn index( body: Bytes, req: HttpRequest, ) -> impl Responder { - debug!("Global Request Headers: {:?}", global_request_headers); - debug!("Global Response Headers: {:?}", global_response_headers); - let mut request = Request::from_actix_request(&req, body, &global_request_headers); // Before middleware @@ -404,7 +401,7 @@ async fn index( req.uri().path(), get_traceback(e.downcast_ref::().unwrap()) ); - return Response::internal_server_error(&request.headers); + return Response::internal_server_error(None); } }; } @@ -420,19 +417,20 @@ async fn index( req.uri().path(), ) { request.path_params = route_params; - execute_http_function(&request, &function) - .await - .unwrap_or_else(|e| { + match execute_http_function(&request, &function).await { + Ok(r) => r, + Err(e) => { error!( "Error while executing route function for endpoint `{}`: {}", req.uri().path(), get_traceback(&e) ); - Response::internal_server_error(&request.headers) - }) + Response::internal_server_error(None) + } + } } else { - Response::not_found(&request.headers) + Response::not_found(None) }; debug!("OG Response : {:?}", response); @@ -455,7 +453,7 @@ async fn index( response = match execute_middleware_function(&mut response, &after_middleware).await { Ok(MiddlewareReturn::Request(_)) => { error!("After middleware returned a request"); - return Response::internal_server_error(&request.headers); + return Response::internal_server_error(Some(&response.headers)); } Ok(MiddlewareReturn::Response(r)) => { let response = r; @@ -469,7 +467,7 @@ async fn index( req.uri().path(), get_traceback(e.downcast_ref::().unwrap()) ); - return Response::internal_server_error(&request.headers); + return Response::internal_server_error(Some(&response.headers)); } }; } diff --git a/src/types/headers.rs b/src/types/headers.rs index 1bc4313d2..35665c225 100644 --- a/src/types/headers.rs +++ b/src/types/headers.rs @@ -19,7 +19,7 @@ impl Headers { pub fn new(default_headers: Option<&PyDict>) -> Self { match default_headers { Some(default_headers) => { - let mut headers = DashMap::new(); + let headers = DashMap::new(); for (key, value) in default_headers { let key = key.to_string().to_lowercase(); @@ -28,8 +28,6 @@ impl Headers { if new_value.is_err() { let value = value.to_string(); headers.entry(key).or_insert_with(Vec::new).push(value); - - // headers.entry(key).or_insert_with(Vec::new).push(value); } else { let value: Vec = new_value.unwrap().iter().map(|x| x.to_string()).collect(); @@ -46,9 +44,7 @@ impl Headers { pub fn set(&mut self, key: String, value: String) { debug!("Setting header {} to {}", key, value); - self.headers - .insert(key.to_lowercase(), vec![value.to_lowercase()]); - debug!("This is new self: {:?}", self); + self.headers.insert(key.to_lowercase(), vec![value]); } pub fn append(&mut self, key: String, value: String) { @@ -56,8 +52,7 @@ impl Headers { self.headers .entry(key.to_lowercase()) .or_insert_with(Vec::new) - .push(value.to_lowercase()); - debug!("This is new self: {:?}", self); + .push(value); } pub fn get_all(&self, py: Python, key: String) -> Py { @@ -173,7 +168,7 @@ impl Headers { for (key, value) in req_headers { let key = key.to_string().to_lowercase(); - let value = value.to_str().unwrap().to_lowercase(); + let value = value.to_str().unwrap().to_string(); headers .headers .entry(key) diff --git a/src/types/response.rs b/src/types/response.rs index ea77d0db6..18d34e2ef 100644 --- a/src/types/response.rs +++ b/src/types/response.rs @@ -36,21 +36,31 @@ impl Responder for Response { } impl Response { - pub fn not_found(headers: &Headers) -> Self { + pub fn not_found(headers: Option<&Headers>) -> Self { + let headers = match headers { + Some(headers) => headers.clone(), + None => Headers::new(None), + }; + Self { status_code: 404, response_type: "text".to_string(), - headers: headers.clone(), + headers, description: "Not found".to_owned().into_bytes(), file_path: None, } } - pub fn internal_server_error(headers: &Headers) -> Self { + pub fn internal_server_error(headers: Option<&Headers>) -> Self { + let headers = match headers { + Some(headers) => headers.clone(), + None => Headers::new(None), + }; + Self { status_code: 500, response_type: "text".to_string(), - headers: headers.clone(), + headers, description: "Internal server error".to_owned().into_bytes(), file_path: None, } From 28d8059a81677eb7799a623b334c6f07eabdc549 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Sat, 2 Dec 2023 03:40:28 +0530 Subject: [PATCH 6/9] Fix request headers --- integration_tests/base_routes.py | 2 -- robyn/robyn.pyi | 20 +++++++++++++------- robyn/router.py | 7 ------- src/types/request.rs | 4 +--- unit_tests/test_request_object.py | 5 ++--- 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 886ee31ed..193b7e62b 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -161,7 +161,6 @@ async def async_before_request(request: Request): @app.after_request("/async/middlewares") async def async_after_request(response: Response): - # response.headers["after"] = "async_after_request" response.headers.set("after", "async_after_request") response.description = response.description + " after" return response @@ -170,7 +169,6 @@ async def async_after_request(response: Response): @app.get("/async/middlewares") async def async_middlewares(request: Request): assert request.headers.contains("before") - # assert request.headers["before"] == "async_before_request" assert request.headers.get("before") == "async_before_request" assert request.ip_addr == "127.0.0.1" return "async middlewares" diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index a261d7c80..14fec95a8 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -65,7 +65,6 @@ class Identity: -@dataclass class QueryParams: """ The query params object passed to the route handler. @@ -74,8 +73,6 @@ class QueryParams: queries (dict[str, list[str]]): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"} """ - queries: dict[str, list[str]] - def set(self, key: str, value: str) -> None: """ Sets the value of the query parameter with the given key. @@ -172,13 +169,12 @@ class Headers: """ pass - def get(self, key: str, default: Optional[str]) -> Optional[str]: + def get(self, key: str) -> Optional[str]: """ Gets the last value of the header with the given key. Args: key (str): The key of the header - default (Optional[str]): The default value if the key does not exist """ pass @@ -191,6 +187,16 @@ class Headers: """ pass + def contains(self, key: str) -> bool: + """ + Returns: + True if the headers contain the key, False otherwise + + Args: + key (str): The key of the header + """ + pass + def is_empty(self) -> bool: pass @@ -203,7 +209,7 @@ class Request: Attributes: query_params (QueryParams): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"} - headers (dict[str, str]): The headers of the request. e.g. {"Content-Type": "application/json"} + headers Headers: The headers of the request. e.g. Headers({"Content-Type": "application/json"}) params (dict[str, str]): The parameters of the request. e.g. /user/:id -> {"id": "123"} body (Union[str, bytes]): The body of the request. If the request is a JSON, it will be a dict. method (str): The method of the request. e.g. GET, POST, PUT, DELETE @@ -211,7 +217,7 @@ class Request: """ query_params: QueryParams - headers: dict[str, str] + headers: Headers path_params: dict[str, str] body: Union[str, bytes] method: str diff --git a/robyn/router.py b/robyn/router.py index 28217fb6d..13dd12e64 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -46,10 +46,8 @@ def _format_response( self, res: dict, ) -> Response: - # TODO: Add support for custom headers headers = Headers({"Content-Type": "text/plain"}) - # we should create a header object here response = {} if isinstance(res, dict): # this should change @@ -59,7 +57,6 @@ def _format_response( if not headers.contains("Content-Type"): headers.set("Content-Type", "text/plain") - description = res.get("description", "") if not isinstance(status_code, int): @@ -100,14 +97,12 @@ async def async_inner_handler(*args): try: response = self._format_response( await handler(*args), - # response_headers, ) except Exception as err: if exception_handler is None: raise response = self._format_response( exception_handler(err), - # response_headers, ) return response @@ -116,14 +111,12 @@ def inner_handler(*args): try: response = self._format_response( handler(*args), - # response_headers, ) except Exception as err: if exception_handler is None: raise response = self._format_response( exception_handler(err), - # response_headers, ) return response diff --git a/src/types/request.rs b/src/types/request.rs index 439c3a7d2..fce8ac074 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -109,7 +109,7 @@ impl PyRequest { #[new] #[allow(clippy::too_many_arguments)] pub fn new( - query_params: &PyDict, + query_params: QueryParams, headers: Py, path_params: Py, body: Py, @@ -118,8 +118,6 @@ impl PyRequest { identity: Option, ip_addr: Option, ) -> Self { - let query_params = QueryParams::from_py_dict(query_params); - Self { query_params, headers, diff --git a/unit_tests/test_request_object.py b/unit_tests/test_request_object.py index beda577ba..f1f46cd88 100644 --- a/unit_tests/test_request_object.py +++ b/unit_tests/test_request_object.py @@ -1,4 +1,4 @@ -from robyn.robyn import Headers, Request, Url +from robyn.robyn import Headers, QueryParams, Request, Url def test_request_object(): @@ -8,8 +8,7 @@ def test_request_object(): path="/user", ) request = Request( - query_params={}, - # headers={"Content-Type": "application/json"}, + query_params=QueryParams(), headers=Headers({"Content-Type": "application/json"}), path_params={}, body="", From e89b762f26a6237b9632b98ec3c6cbae27cea113 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Sat, 2 Dec 2023 03:43:19 +0530 Subject: [PATCH 7/9] fix formatting --- integration_tests/test_authentication.py | 1 - robyn/__init__.py | 11 +---------- robyn/authentication.py | 4 ++-- robyn/robyn.pyi | 8 -------- robyn/router.py | 1 - src/io_helpers/mod.rs | 2 -- src/server.rs | 1 - src/types/headers.rs | 4 +--- src/types/request.rs | 1 - src/types/response.rs | 2 -- 10 files changed, 4 insertions(+), 31 deletions(-) diff --git a/integration_tests/test_authentication.py b/integration_tests/test_authentication.py index 92983bfee..34a33f608 100644 --- a/integration_tests/test_authentication.py +++ b/integration_tests/test_authentication.py @@ -22,7 +22,6 @@ def test_invalid_authentication_token(session, function_type: str): assert r.headers.get("WWW-Authenticate") == "BearerGetter" - @pytest.mark.benchmark @pytest.mark.parametrize("function_type", ["sync", "async"]) def test_invalid_authentication_header(session, function_type: str): diff --git a/robyn/__init__.py b/robyn/__init__.py index efb50d251..72718d9cd 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -15,16 +15,7 @@ from robyn.logger import logger from robyn.processpool import run_processes from robyn.responses import serve_file, serve_html -from robyn.robyn import ( - FunctionInfo, - HttpMethod, - Request, - Response, - get_version, - jsonify, - WebSocketConnector, - Headers -) +from robyn.robyn import FunctionInfo, HttpMethod, Request, Response, get_version, jsonify, WebSocketConnector, Headers from robyn.router import MiddlewareRouter, MiddlewareType, Router, WebSocketRouter from robyn.types import Directory, Header from robyn import status_codes diff --git a/robyn/authentication.py b/robyn/authentication.py index 8cd4bd988..785121b75 100644 --- a/robyn/authentication.py +++ b/robyn/authentication.py @@ -79,8 +79,8 @@ class BearerGetter(TokenGetter): @classmethod def get_token(cls, request: Request) -> Optional[str]: - if request.headers.contains("authorization") : - authorization_header = request.headers.get("authorization") + if request.headers.contains("authorization"): + authorization_header = request.headers.get("authorization") else: authorization_header = None diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index 14fec95a8..01e570a97 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -61,10 +61,6 @@ class Url: class Identity: claims: dict[str, str] - - - - class QueryParams: """ The query params object passed to the route handler. @@ -153,11 +149,9 @@ class QueryParams: pass class Headers: - def __init__(self, default_headers: Optional[dict]) -> None: pass - def set(self, key: str, value: str) -> None: """ Sets the value of the header with the given key. @@ -197,11 +191,9 @@ class Headers: """ pass - def is_empty(self) -> bool: pass - @dataclass class Request: """ diff --git a/robyn/router.py b/robyn/router.py index 13dd12e64..62ac9f51e 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -91,7 +91,6 @@ def add_route( is_const: bool, exception_handler: Optional[Callable], ) -> Union[Callable, CoroutineType]: - @wraps(handler) async def async_inner_handler(*args): try: diff --git a/src/io_helpers/mod.rs b/src/io_helpers/mod.rs index 9a8dc1ac5..914e21b46 100644 --- a/src/io_helpers/mod.rs +++ b/src/io_helpers/mod.rs @@ -1,10 +1,8 @@ -use std::collections::HashMap; use std::fs::File; use std::io::Read; use actix_web::HttpResponseBuilder; use anyhow::Result; -use log::debug; use crate::types::headers::Headers; diff --git a/src/server.rs b/src/server.rs index 6ec51c766..ab6d4c017 100644 --- a/src/server.rs +++ b/src/server.rs @@ -26,7 +26,6 @@ use actix_files::Files; use actix_http::KeepAlive; use actix_web::web::Bytes; use actix_web::*; -use dashmap::DashMap; // pyO3 module use log::{debug, error}; diff --git a/src/types/headers.rs b/src/types/headers.rs index 35665c225..e1a24761c 100644 --- a/src/types/headers.rs +++ b/src/types/headers.rs @@ -1,10 +1,8 @@ use actix_http::header::HeaderMap; -use actix_web::{web::Bytes, HttpRequest}; use dashmap::DashMap; use log::debug; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyIterator, PyList}; -use std::collections::HashMap; +use pyo3::types::{PyDict, PyList}; // Custom Multimap class #[pyclass(name = "Headers")] diff --git a/src/types/request.rs b/src/types/request.rs index fce8ac074..98ed788d6 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -1,5 +1,4 @@ use actix_web::{web::Bytes, HttpRequest}; -use dashmap::DashMap; use pyo3::{exceptions::PyValueError, prelude::*, types::PyDict, types::PyString}; use serde_json::Value; use std::collections::HashMap; diff --git a/src/types/response.rs b/src/types/response.rs index 18d34e2ef..1eaeee100 100644 --- a/src/types/response.rs +++ b/src/types/response.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use actix_http::{body::BoxBody, StatusCode}; use actix_web::{HttpRequest, HttpResponse, HttpResponseBuilder, Responder}; use pyo3::{ From 82935ee9f5bded71638b6358ee5e4ebef2f601c9 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Sat, 2 Dec 2023 05:23:31 +0530 Subject: [PATCH 8/9] docs --- .../api_reference/getting_started.mdx | 45 ++++++++++++++++++- .../api_reference/middlewares.mdx | 6 +-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/docs_src/src/pages/documentation/api_reference/getting_started.mdx b/docs_src/src/pages/documentation/api_reference/getting_started.mdx index 59123a4ae..30f3b6cb0 100644 --- a/docs_src/src/pages/documentation/api_reference/getting_started.mdx +++ b/docs_src/src/pages/documentation/api_reference/getting_started.mdx @@ -457,6 +457,27 @@ Or setting the Headers globally *per* router. + + +`add_response_header` appends the header to the list of headers, while `set_response_header` replaces the header if it exists. + + + + + + + ```python {{ title: 'untyped' }} + app.set_response_header("content-type", "application/json") + ``` + + ```python {{title: 'typed'}} + app.set_response_header("content-type", "application/json") + ``` + + + + + ## Request Headers @@ -480,7 +501,7 @@ Either, by using the `headers` field in the `Request` object: print("These are the request headers: ", headers) - headers["modified"] = ["modified_value"] + headers.set("modified", "modified_value") print("These are the modified request headers: ", headers) @@ -496,7 +517,7 @@ Either, by using the `headers` field in the `Request` object: print("These are the request headers: ", headers) - headers["modified"] = ["modified_value"] + headers.set("modified", "modified_value") print("These are the modified request headers: ", headers) @@ -527,6 +548,26 @@ Or by using the global Request Headers: + + +`add_request_header` appends the header to the list of headers, while `set_request_header` replaces the header if it exists. + + + + + + + ```python {{ title: 'untyped' }} + app.set_request_header("server", "robyn") + ``` + + ```python {{title: 'typed'}} + app.set_request_header("server", "robyn") + ``` + + + + --- diff --git a/docs_src/src/pages/documentation/api_reference/middlewares.mdx b/docs_src/src/pages/documentation/api_reference/middlewares.mdx index 5d01dbaac..65ac2dfe1 100644 --- a/docs_src/src/pages/documentation/api_reference/middlewares.mdx +++ b/docs_src/src/pages/documentation/api_reference/middlewares.mdx @@ -89,7 +89,7 @@ Batman was excited to learn that he could add events as functions as well as dec @app.after_request("/") def hello_after_request(response: Response): - response.headers["after"] = "sync_after_request" + response.headers.set("after", "sync_after_request"") print(response) ``` @@ -98,12 +98,12 @@ Batman was excited to learn that he could add events as functions as well as dec @app.before_request("/") async def hello_before_request(request): - request.headers["before"] = "sync_before_request" + request.headers.set("before", "sync_before_request") print(request) @app.after_request("/") def hello_after_request(response): - response.headers["after"] = "sync_after_request" + response.headers.set("after", "sync_after_request"") print(response) ``` From e381a3520d9272ae6a5d489fbd6fffe7617ecfd7 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Sat, 2 Dec 2023 07:09:24 +0530 Subject: [PATCH 9/9] remove Header class --- integration_tests/test_middlewares.py | 4 ++-- robyn/__init__.py | 2 +- robyn/processpool.py | 2 +- robyn/router.py | 1 - robyn/types.py | 9 --------- 5 files changed, 4 insertions(+), 14 deletions(-) diff --git a/integration_tests/test_middlewares.py b/integration_tests/test_middlewares.py index 9833d54d7..172a5007b 100644 --- a/integration_tests/test_middlewares.py +++ b/integration_tests/test_middlewares.py @@ -9,8 +9,8 @@ def test_middlewares(function_type: str, session): r = get(f"/{function_type}/middlewares") headers = r.headers # We do not want the request headers to be in the response - assert headers.get("global_after") - + assert not headers.get("before") + assert headers.get("after") assert r.headers.get("after") == f"{function_type}_after_request" assert r.text == f"{function_type} middlewares after" diff --git a/robyn/__init__.py b/robyn/__init__.py index 72718d9cd..5c95a97c7 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -17,7 +17,7 @@ from robyn.responses import serve_file, serve_html from robyn.robyn import FunctionInfo, HttpMethod, Request, Response, get_version, jsonify, WebSocketConnector, Headers from robyn.router import MiddlewareRouter, MiddlewareType, Router, WebSocketRouter -from robyn.types import Directory, Header +from robyn.types import Directory from robyn import status_codes from robyn.ws import WebSocket diff --git a/robyn/processpool.py b/robyn/processpool.py index 265f31d0b..6e0042d58 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -9,7 +9,7 @@ from robyn.events import Events from robyn.robyn import FunctionInfo, Headers, Server, SocketHeld from robyn.router import GlobalMiddleware, RouteMiddleware, Route -from robyn.types import Directory, Header +from robyn.types import Directory from robyn.ws import WebSocket diff --git a/robyn/router.py b/robyn/router.py index 62ac9f51e..9e8475827 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -10,7 +10,6 @@ from robyn import status_codes from robyn.ws import WebSocket -from robyn.types import Header class Route(NamedTuple): diff --git a/robyn/types.py b/robyn/types.py index f5fd350bf..b7ed7d30f 100644 --- a/robyn/types.py +++ b/robyn/types.py @@ -16,12 +16,3 @@ def as_list(self): self.show_files_listing, self.index_file, ] - - -@dataclass -class Header: - key: str - val: str - - def as_list(self): - return [self.key, self.val]