Skip to content

Commit

Permalink
Fix middleware handler and remove build_middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Oct 27, 2023
1 parent e47d248 commit dab592c
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 63 deletions.
24 changes: 0 additions & 24 deletions esmerald/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -2304,24 +2304,6 @@ def build_routes_exception_handlers(

return exception_handlers

def build_routes_middleware(
self, route: "RouteParent", middlewares: Optional[List["Middleware"]] = None
) -> List["Middleware"]:
"""
Builds the middleware stack from the top to the bottom of the routes.
The includes are an exception as they are treated as an independent ASGI
application and therefore handles their own middlewares independently.
"""
if not middlewares:
middlewares = []

if isinstance(route, Include):
app = getattr(route, "app", None)
if app and isinstance(app, (Esmerald, ChildEsmerald)):
return middlewares
return middlewares

def build_user_middleware_stack(self) -> List["StarletteMiddleware"]:
"""
Configures all the passed settings
Expand All @@ -2333,7 +2315,6 @@ def build_user_middleware_stack(self) -> List["StarletteMiddleware"]:
It evaluates the middleware passed into the routes from bottom up
"""
user_middleware = []
handlers_middleware: List["Middleware"] = []

if self.allowed_hosts:
user_middleware.append(
Expand All @@ -2351,11 +2332,6 @@ def build_user_middleware_stack(self) -> List["StarletteMiddleware"]:
StarletteMiddleware(SessionMiddleware, **self.session_config.model_dump())
)

handlers_middleware += self.router.middleware
for route in self.routes or []:
handlers_middleware.extend(self.build_routes_middleware(route))

self._middleware += handlers_middleware
for middleware in self._middleware or []:
if isinstance(middleware, StarletteMiddleware):
user_middleware.append(middleware)
Expand Down
3 changes: 2 additions & 1 deletion esmerald/routing/apis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def get_filtered_handler(self) -> List[str]:
for handler_name in filtered_handlers:
for base in self.__class__.__bases__:
if handler_name not in dir(base) and isinstance(
getattr(self, handler_name), (HTTPHandler, WebSocketHandler, WebhookHandler)
getattr(self, handler_name),
(HTTPHandler, WebSocketHandler, WebhookHandler),
):
route_handlers.append(handler_name)
return route_handlers
Expand Down
33 changes: 6 additions & 27 deletions esmerald/routing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,8 @@ async def hello(self) -> str:
handler=route_handler,
name=route_handler.path,
middleware=value.middleware,
interceptors=value.interceptors,
include_in_schema=value.include_in_schema,
permissions=value.permissions,
exception_handlers=value.exception_handlers,
)
Expand Down Expand Up @@ -1203,7 +1205,10 @@ async def handle(self, scope: "Scope", receive: "Receive", send: "Send") -> None
)
await response(scope, receive, send)

def __call__(self, fn: "AnyCallable") -> "HTTPHandler":
def __call__(
self,
fn: "AnyCallable",
) -> "HTTPHandler":
self.fn = fn
self.endpoint = fn
self.validate_handler()
Expand Down Expand Up @@ -1882,13 +1887,7 @@ async def another(request: Request) -> str:
if routes:
routes = self.resolve_route_path_handler(routes)

# Build the middleware from the routes
routes_middleware: List["Middleware"] = []
for route in routes or []:
routes_middleware = cast("List[Middleware]", self.build_routes_middleware(route))

# Add the middleware to the include
self.middleware += routes_middleware
include_middleware: Sequence["Middleware"] = []

for _middleware in self.middleware:
Expand Down Expand Up @@ -1919,26 +1918,6 @@ def resolve_app_parent(self, app: Optional[Any]) -> Optional[Any]:
app.parent = self
return app

def build_routes_middleware(
self, route: "RouteParent", middlewares: Optional[Sequence["Middleware"]] = None
) -> Sequence["Middleware"]:
"""
Builds the middleware stack from the top to the bottom of the routes.
"""
from esmerald import ChildEsmerald, Esmerald

if not middlewares:
middlewares = []

if isinstance(route, Include):
app = getattr(route, "app", None)
if app and isinstance(app, (Esmerald, ChildEsmerald)):
return middlewares

if isinstance(route, (Gateway, WebSocketGateway)):
middlewares.extend(route.middleware)
return middlewares

def resolve_route_path_handler(
self, routes: Sequence[Union["APIGateHandler", "Include"]]
) -> List[Union["Gateway", "WebSocketGateway", "Include"]]:
Expand Down
110 changes: 105 additions & 5 deletions tests/middleware/complex/test_complex.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from starlette.types import ASGIApp

from esmerald import Gateway, Request
from esmerald import Gateway, Include, Request
from esmerald.conf import settings
from esmerald.config.jwt import JWTConfig
from esmerald.contrib.auth.edgy.base_user import User as EdgyUser
Expand Down Expand Up @@ -62,16 +62,66 @@ async def update_user(self, request: Request) -> str:


def test_can_access_endpoint(test_app_client_factory):
with create_client(routes=[Gateway(handler=UserView)]) as client:
response = client.post("/users")
with create_client(routes=[Gateway("/v1", handler=UserView)]) as client:
response = client.post("/v1/users")

assert response.status_code == 201

response = client.get("/users")
response = client.get("/v1/users")

assert response.status_code == 401

response = client.put("/users/2")
response = client.put("/v1/users/2")

assert response.status_code == 401


def test_can_access_endpoint_with_include(test_app_client_factory):
with create_client(
routes=[
Include(
routes=[
Gateway("/v1", handler=UserView),
],
)
]
) as client:
response = client.post("/v1/users")

assert response.status_code == 201

response = client.get("/v1/users")

assert response.status_code == 401

response = client.put("/v1/users/2")

assert response.status_code == 401


def test_can_access_endpoint_with_include_nested(test_app_client_factory):
with create_client(
routes=[
Include(
routes=[
Include(
routes=[
Gateway("/v1", handler=UserView),
],
)
]
)
]
) as client:
response = client.post("/v1/users")

assert response.status_code == 201

response = client.get("/v1/users")

assert response.status_code == 401

response = client.put("/v1/users/2")

assert response.status_code == 401

Expand Down Expand Up @@ -110,3 +160,53 @@ def test_can_access_endpoint_blocked(test_app_client_factory):
response = client.put("/users/2")

assert response.status_code == 401


def test_can_access_endpoint_blocked_with_include(test_app_client_factory):
with create_client(
routes=[
Include(
routes=[
Gateway(handler=AnotherUserView),
],
)
]
) as client:
response = client.post("/users")

assert response.status_code == 401

response = client.get("/users")

assert response.status_code == 401

response = client.put("/users/2")

assert response.status_code == 401


def test_can_access_endpoint_blocked_with_include_nested(test_app_client_factory):
with create_client(
routes=[
Include(
routes=[
Include(
routes=[
Gateway(handler=AnotherUserView),
],
)
]
)
]
) as client:
response = client.post("/users")

assert response.status_code == 401

response = client.get("/users")

assert response.status_code == 401

response = client.put("/users/2")

assert response.status_code == 401
13 changes: 7 additions & 6 deletions tests/middleware/test_middleware_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def my_handler(self) -> None:
with create_client(
routes=[
Include(
path="/",
routes=[Gateway(path="/", handler=MyController)],
middleware=[create_test_middleware(2), create_test_middleware(3)],
),
Expand All @@ -216,8 +215,7 @@ def my_handler(self) -> None:
create_test_middleware(1),
],
) as client:
client.get("/router/controller/handler")

client.get("/controller/handler")
assert results == [0, 1, 2, 3, 4, 5, 6, 7]


Expand Down Expand Up @@ -269,7 +267,7 @@ def my_handler(self) -> None:
create_test_middleware(1),
],
) as client:
client.get("/router/controller/handler")
client.get("/controller/handler")

assert results == [0, 1, 2, 3, 4, 5, 6, 7]

Expand Down Expand Up @@ -371,7 +369,7 @@ def my_handler(self) -> None:
create_test_middleware(1),
],
) as client:
client.get("/routes/controller/handler")
client.get("/controller/handler")

assert results == [0, 1, 2, 3, 4, 5, 6, 7]

Expand Down Expand Up @@ -467,7 +465,10 @@ def my_handler(self) -> None:
routes=[
Include(
routes=[Include(app=child_esmerald)],
middleware=[create_test_middleware(2), create_test_middleware(3)],
middleware=[
create_test_middleware(2),
create_test_middleware(3),
],
)
],
)
Expand Down

0 comments on commit dab592c

Please sign in to comment.