Skip to content

Commit

Permalink
complete di
Browse files Browse the repository at this point in the history
  • Loading branch information
sansyrox committed Nov 9, 2023
1 parent 7b3f63b commit e583304
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 30 deletions.
27 changes: 19 additions & 8 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@
app = Robyn(__file__)
websocket = WebSocket(app, "/web_socket")


# Creating a new WebSocket app to test json handling + to serve an example to future users of this lib
# while the original "raw" web_socket is used with benchmark tests
websocket_json = WebSocket(app, "/web_socket_json")

websocket_di = WebSocket(app, "/web_socket_di")

websocket_di.inject_global(GLOBAL_DEPENDENCY="GLOBAL DEPENDENCY")
websocket_di.inject(ROUTER_DEPENDENCY="ROUTER DEPENDENCY")

current_file_path = pathlib.Path(__file__).parent.resolve()
jinja_template = JinjaTemplate(os.path.join(current_file_path, "templates"))

Expand Down Expand Up @@ -55,7 +61,7 @@ async def jsonws_message(ws, msg: str) -> str:


@websocket.on("message")
async def message(ws: WebSocketConnector, msg: str) -> str:
async def message(ws: WebSocketConnector, msg: str, global_dependencies) -> str:
global websocket_state
websocket_id = ws.id
state = websocket_state[websocket_id]
Expand All @@ -69,7 +75,6 @@ async def message(ws: WebSocketConnector, msg: str) -> str:
elif state == 2:
resp = "*chika* *chika* Slim Shady."
websocket_state[websocket_id] = (state + 1) % 3

return resp


Expand All @@ -93,6 +98,12 @@ def jsonws_connect():
return "Hello world, from ws"


@websocket_di.on("connect")
async def di_message(global_dependencies, router_dependencies):
return global_dependencies["GLOBAL_DEPENDENCY"] + " "+ router_dependencies["ROUTER_DEPENDENCY"]



# ===== Lifecycle handlers =====


Expand All @@ -112,8 +123,8 @@ def shutdown_handler():

@app.before_request()
def global_before_request(request: Request):
# request.headers["global_before"] = "global_before_request"
# return request
request.headers["global_before"] = "global_before_request"
return request


@app.after_request()
Expand Down Expand Up @@ -190,11 +201,11 @@ def sync_middlewares_401():

# Hello world

@app.get("/")
async def hello_world(router_dependencies):
return f"Hello, world! {router_dependencies['ROUTER_DEPENDENCY']}"

app.inject(RouterDependency="Router Dependency")

@app.get("/")
async def hello_world(request):
return f"Hello, world!"

@app.get("/sync/str")
def sync_str_get():
Expand Down
12 changes: 12 additions & 0 deletions integration_tests/test_web_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,15 @@ def test_web_socket_json(session):
resp = json.loads(ws.recv())
assert resp["resp"] == "*chika* *chika* Slim Shady."
assert resp["msg"] == msg

def test_websocket_di(session):
"""
Not using this as the benchmark test since this involves JSON marshalling/unmarshalling
"""

msg = "GLOBAL_DEPENDENCY ROUTER DEPENDENCY"

ws = create_connection(f"{BASE_URL}/web_socket_di")
assert ws.recv() == msg

6 changes: 4 additions & 2 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ def add_route(

""" We will add the status code here only
"""
injected_dependencies = self.dependencies.get_dependency_map(self)

if auth_required:
self.middleware_router.add_auth_middleware(endpoint)(handler)

if isinstance(route_type, str):
http_methods = {
"GET": HttpMethod.GET,
Expand All @@ -107,7 +110,6 @@ def add_route(
}
route_type = http_methods[route_type]

injected_dependencies = self.dependencies.get_dependency_map(self)

add_route_response = self.router.add_route(
route_type=route_type,
Expand Down Expand Up @@ -155,7 +157,7 @@ def after_request(self, endpoint: Optional[str] = None) -> Callable[..., None]:
:param endpoint str|None: endpoint to server the route. If None, the middleware will be applied to all the routes.
"""

dependency_map = self.dependencies.get_dependency_map(self)
return self.middleware_router.add_middleware(MiddlewareType.AFTER_REQUEST, endpoint)

def add_directory(
Expand Down
47 changes: 30 additions & 17 deletions robyn/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def add_route(

@wraps(handler)
async def async_inner_handler(*args, **kwargs):
print("This is the args", args)
print("This is the kwargs", kwargs)
try:
response = self._format_response(
await handler(*args, **kwargs),
Expand Down Expand Up @@ -132,8 +134,7 @@ def inner_handler(*args, **kwargs):
number_of_params = len(signature(handler).parameters)
# these are the arguments
params = dict(inspect.signature(handler).parameters)
print("This is the params", params)
print("This is the injected dependencies", injected_dependencies)



new_injected_dependencies = {}
Expand All @@ -158,28 +159,35 @@ def get_routes(self) -> List[Route]:


class MiddlewareRouter(BaseRouter):
def __init__(self) -> None:
def __init__(self, dependencies: DependencyMap = DependencyMap()) -> None:
super().__init__()
self.global_middlewares: List[GlobalMiddleware] = []
self.route_middlewares: List[RouteMiddleware] = []
self.authentication_handler: Optional[AuthenticationHandler] = None
self.dependencies = dependencies

def set_authentication_handler(self, authentication_handler: AuthenticationHandler):
self.authentication_handler = authentication_handler

def add_route(self, middleware_type: MiddlewareType, endpoint: str, handler: Callable, injected_dependencies: Optional[dict]) -> Callable:
def add_route(self, middleware_type: MiddlewareType, endpoint: str, handler: Callable, injected_dependencies: dict) -> Callable:

# add a docstring here
params = dict(inspect.signature(handler).parameters)
number_of_params = len(params)

# need to do something here
dependency_map = {}

new_injected_dependencies = {}
for dependency in injected_dependencies:
if dependency in params:
new_injected_dependencies[dependency] = injected_dependencies[dependency]
else:
logging.warning(f"Dependency {dependency} is not used in the middleware handler {handler.__name__}")

print("This is new injected dependencies", new_injected_dependencies)


function = FunctionInfo(handler, iscoroutinefunction(handler), number_of_params, {}, {})

function = FunctionInfo(handler, iscoroutinefunction(handler), number_of_params, params, new_injected_dependencies)
self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function))
return handler

Expand All @@ -188,26 +196,31 @@ def add_auth_middleware(self, endpoint: str):
This method adds an authentication middleware to the specified endpoint.
"""

def inner(handler):
injected_dependencies = {}

def decorator(handler):
@wraps(handler)
def inner_handler(request: Request, *args):
if not self.authentication_handler:
raise AuthenticationNotConfiguredError()
identity = self.authentication_handler.authenticate(request)
if identity is None:
return self.authentication_handler.unauthorized_response
request.identity = identity
return request
return handler(request, *args)

self.add_route(MiddlewareType.BEFORE_REQUEST, endpoint, inner_handler, None)
self.add_route(MiddlewareType.BEFORE_REQUEST, endpoint, inner_handler, injected_dependencies)
return inner_handler

return inner
return decorator


# These inner functions are basically a wrapper around the closure(decorator) being returned.
# They take a handler, convert it into a closure and return the arguments.
# Arguments are returned as they could be modified by the middlewares.
def add_middleware(self, middleware_type: MiddlewareType, endpoint: Optional[str]) -> Callable[..., None]:
# no dependency injection here
injected_dependencies = {}

def inner(handler):
@wraps(handler)
Expand All @@ -220,9 +233,9 @@ def inner_handler(*args, **kwargs):

if endpoint is not None:
if iscoroutinefunction(handler):
self.add_route(middleware_type, endpoint, async_inner_handler, None)
self.add_route(middleware_type, endpoint, async_inner_handler, injected_dependencies)
else:
self.add_route(middleware_type, endpoint, inner_handler, None)
self.add_route(middleware_type, endpoint, inner_handler, injected_dependencies)
else:
params = dict(inspect.signature(handler).parameters)

Expand All @@ -234,8 +247,8 @@ def inner_handler(*args, **kwargs):
async_inner_handler,
True,
len(params),
{},
{},
params,
injected_dependencies,
),
)
)
Expand All @@ -247,8 +260,8 @@ def inner_handler(*args, **kwargs):
inner_handler,
False,
len(params),
{},
{},
params,
injected_dependencies,
),
)
)
Expand Down
16 changes: 16 additions & 0 deletions robyn/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,19 @@ def inner(handler):

return inner

def inject(self, **kwargs):
"""
Injects the dependencies for the route
:param kwargs dict: the dependencies to be injected
"""
self.dependencies.add_router_dependency(self, **kwargs)

def inject_global(self, **kwargs):
"""
Injects the dependencies for the global routes
Ideally, this function should be a global function
:param kwargs dict: the dependencies to be injected
"""
self.dependencies.add_global_dependency(**kwargs)
11 changes: 10 additions & 1 deletion src/executors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,16 @@ where
handler.call((), Some(kwargs))
}
}
2..=u8::MAX => handler.call((function_args,), Some(kwargs)),
2 => {
if function.args.as_ref(py).get_item("request").is_some()
|| function.args.as_ref(py).get_item("response").is_some()
{
handler.call((function_args,), Some(kwargs))
} else {
handler.call((), Some(kwargs))
}
}
3..=u8::MAX => handler.call((function_args,), Some(kwargs)),
}
}

Expand Down
14 changes: 12 additions & 2 deletions src/executors/web_socket_executors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ fn get_function_output<'a>(
handler.call((), Some(kwargs))
}
}
// this is done to accommodate any future params
2 => {
if args.get_item("ws").is_some() && args.get_item("msg").is_some() {
handler.call1((ws.clone(), fn_msg.unwrap_or_default()))
Expand All @@ -43,7 +42,18 @@ fn get_function_output<'a>(
handler.call((), Some(kwargs))
}
}
3_u8..=u8::MAX => handler.call((ws.clone(), fn_msg.unwrap_or_default()), Some(kwargs)),
3 => {
if args.get_item("ws").is_some() && args.get_item("msg").is_some() {
handler.call((ws.clone(), fn_msg.unwrap_or_default()), Some(kwargs))
} else if args.get_item("ws").is_some() {
handler.call((ws.clone(),), Some(kwargs))
} else if args.get_item("msg").is_some() {
handler.call((fn_msg.unwrap_or_default(),), Some(kwargs))
} else {
handler.call((), Some(kwargs))
}
}
4_u8..=u8::MAX => handler.call((ws.clone(), fn_msg.unwrap_or_default()), Some(kwargs)),
}
}

Expand Down

0 comments on commit e583304

Please sign in to comment.